Skip to content

Conversation

@wsmoses
Copy link
Member

@wsmoses wsmoses commented Dec 3, 2025

No description provided.

@codecov
Copy link

codecov bot commented Dec 3, 2025

Codecov Report

❌ Patch coverage is 50.29586% with 84 lines in your changes missing coverage. Please review.
✅ Project coverage is 67.54%. Comparing base (7cdcc1c) to head (22eb451).
⚠️ Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
src/compiler.jl 41.17% 60 Missing ⚠️
src/llvm/transforms.jl 75.55% 11 Missing ⚠️
src/errors.jl 0.00% 9 Missing ⚠️
src/compiler/optimize.jl 69.23% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2832      +/-   ##
==========================================
- Coverage   67.77%   67.54%   -0.23%     
==========================================
  Files          58       58              
  Lines       20913    21049     +136     
==========================================
+ Hits        14173    14217      +44     
- Misses       6740     6832      +92     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@github-actions
Copy link
Contributor

github-actions bot commented Dec 3, 2025

Benchmark Results

main 22eb451... main / 22eb451...
basics/make_zero/namedtuple 0.0537 ± 0.0024 μs 0.0539 ± 0.0025 μs 0.997 ± 0.064
basics/make_zero/struct 0.288 ± 0.0079 μs 0.264 ± 0.0054 μs 1.09 ± 0.037
basics/overhead 4.35 ± 0.009 ns 4.03 ± 0.01 ns 1.08 ± 0.0035
basics/remake_zero!/namedtuple 0.237 ± 0.0062 μs 0.236 ± 0.011 μs 1 ± 0.054
basics/remake_zero!/struct 0.233 ± 0.0079 μs 0.242 ± 0.0093 μs 0.965 ± 0.05
fold_broadcast/multidim_sum_bcast/1D 10.3 ± 0.21 μs 10.4 ± 1.8 μs 0.984 ± 0.17
fold_broadcast/multidim_sum_bcast/2D 10.3 ± 0.21 μs 12.1 ± 0.29 μs 0.85 ± 0.027
time_to_load 1.07 ± 0.016 s 1.03 ± 0.0038 s 1.04 ± 0.016

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/20066125750/artifacts/4811571053.

@github-actions
Copy link
Contributor

github-actions bot commented Dec 3, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/compiler.jl b/src/compiler.jl
index bf7e62cc..d3ab49c7 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -3534,7 +3534,7 @@ function create_abi_wrapper(
                     ],
                 )
                 ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
-		extract_struct_into!(builder, ptr, eval)
+                extract_struct_into!(builder, ptr, eval)
                 returnNum += 1
                 if i == 3 && shadow_init
                     shadows = LLVM.Value[]
@@ -3576,7 +3576,7 @@ function create_abi_wrapper(
                     ],
                 )
                 ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
-		extract_struct_into!(builder, ptr, eval)
+                extract_struct_into!(builder, ptr, eval)
                 returnNum += 1
             end
         end
@@ -3676,7 +3676,7 @@ function create_abi_wrapper(
                 ],
             )
             ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
-	    extract_struct_into!(builder, ptr, eval)
+            extract_struct_into!(builder, ptr, eval)
         end
         @assert count_Sret == numLLVMReturns
     else
@@ -3693,7 +3693,7 @@ function create_abi_wrapper(
                             makeInstanceOf(builder, sret_types[returnNum+1])
                         end,
                     )
-	    	    extract_struct_into!(
+                    extract_struct_into!(
                         builder,
                         inbounds_gep!(
                             builder,
@@ -3719,7 +3719,7 @@ function create_abi_wrapper(
                 isboxed = GPUCompiler.deserves_argbox(T′)
                 if !isboxed
                     eval = extract_value!(builder, val, returnNum)
-	    	    extract_struct_into!(
+                    extract_struct_into!(
                         builder,
                         inbounds_gep!(
                             builder,
@@ -3817,17 +3817,18 @@ function to_llvm(lst::Vector{Cuint})
     vals = LLVM.Value[]
     push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
     for i in lst
-       push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
+        push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
     end
     return vals
 end
-    
+
 function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, sret::LLVM.Value, root_ty::LLVM.LLVMType, rootRet::Union{LLVM.Value, Nothing}, direction::SRetRootMovement; must_cache::Bool = false)
         count = 0
         todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
 	    Cuint[],
             jltype,
-        )]
+        ),
+    ]
 
 	extracted = LLVM.Value[]
 
@@ -3971,7 +3972,8 @@ function copy_floats_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::
     todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
 	    Cuint[],
         jltype,
-    )]
+        ),
+    ]
 
 	extracted = LLVM.Value[]
 
@@ -3983,8 +3985,8 @@ function copy_floats_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::
             end
 
             if isa(ty, LLVM.FloatingPointType)
-		dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloccf")
-		srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloccf")
+            dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloccf")
+            srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloccf")
                 val = load!(builder, ty, srcloc)
                 st = store!(builder, val, dstloc)
                 continue
@@ -4026,17 +4028,18 @@ function extract_nonjlvalues_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMTyp
     todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
 	    Cuint[],
         jltype,
-    )]
+        ),
+    ]
 
     extracted = LLVM.Value[]
-	
+
     if addrspace(value_type(dst)) == 10
-       PT2 = if LLVM.is_opaque(value_type(dst))
-	   LLVM.PointerType(11)
-       else
-	   LLVM.PointerType(eltype(value_type(dst)), 11)
-       end
-       dst = addrspacecast!(builder, PT2, dst)
+        PT2 = if LLVM.is_opaque(value_type(dst))
+            LLVM.PointerType(11)
+        else
+            LLVM.PointerType(eltype(value_type(dst)), 11)
+        end
+        dst = addrspacecast!(builder, PT2, dst)
     end
 
     while length(todo) != 0
@@ -4075,7 +4078,7 @@ function extract_nonjlvalues_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMTyp
                 continue
             end
 		
-	    dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocnjl")
+        dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocnjl")
             val = Enzyme.API.e_extract_value!(builder, src, path)
 	    st = store!(builder, val, dstloc)
         end
@@ -4086,122 +4089,126 @@ end
 function extract_struct_into!(builder::LLVM.IRBuilder, dst::LLVM.Value, src::LLVM.Value)
     count = 0
     jltype = value_type(src)
-    todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
-	    Cuint[],
-        jltype,
-    )]
+    todo = Tuple{Vector{Cuint}, LLVM.LLVMType}[
+        (
+            Cuint[],
+            jltype,
+        ),
+    ]
 
     extracted = LLVM.Value[]
-	
+
     if addrspace(value_type(dst)) == 10
-       PT2 = if LLVM.is_opaque(value_type(dst))
-	   LLVM.PointerType(11)
-       else
-	   LLVM.PointerType(eltype(value_type(dst)), 11)
-       end
-       dst = addrspacecast!(builder, PT2, dst)
+        PT2 = if LLVM.is_opaque(value_type(dst))
+            LLVM.PointerType(11)
+        else
+            LLVM.PointerType(eltype(value_type(dst)), 11)
+        end
+        dst = addrspacecast!(builder, PT2, dst)
     end
 
     while length(todo) != 0
-            path, ty = popfirst!(todo)
+        path, ty = popfirst!(todo)
 
-            if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
-                for i = 1:length(ty)
-                    npath = copy(path)
-                    push!(npath, i - 1)
-                    push!(todo, (npath, eltype(ty)))
-                end
-                continue
+        if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
+            for i in 1:length(ty)
+                npath = copy(path)
+                push!(npath, i - 1)
+                push!(todo, (npath, eltype(ty)))
             end
+            continue
+        end
 
-            if isa(ty, LLVM.VectorType) && any_jltypes(ty)
-                for i = 1:size(ty)
-                    npath = copy(path)
-                    push!(npath, i - 1)
-                    push!(todo, (npath, eltype(ty)))
-                end
-                continue
+        if isa(ty, LLVM.VectorType) && any_jltypes(ty)
+            for i in 1:size(ty)
+                npath = copy(path)
+                push!(npath, i - 1)
+                push!(todo, (npath, eltype(ty)))
             end
+            continue
+        end
 
-            if isa(ty, LLVM.StructType) && any_jltypes(ty)
-                for (i, t) in enumerate(LLVM.elements(ty))
-                    npath = copy(path)
-                    push!(npath, i - 1)
-                    push!(todo, (npath, t))
-                end
-                continue
+        if isa(ty, LLVM.StructType) && any_jltypes(ty)
+            for (i, t) in enumerate(LLVM.elements(ty))
+                npath = copy(path)
+                push!(npath, i - 1)
+                push!(todo, (npath, t))
             end
-		
-	    dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocsi")
-	    val = length(path) == 0 ? src : Enzyme.API.e_extract_value!(builder, src, path)
-	    st = store!(builder, val, dstloc)
+            continue
         end
 
-	return nothing
+        dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocsi")
+        val = length(path) == 0 ? src : Enzyme.API.e_extract_value!(builder, src, path)
+        st = store!(builder, val, dstloc)
+    end
+
+    return nothing
 end
 
 function copy_struct_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::LLVM.Value, src::LLVM.Value)
     count = 0
-    todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
-        Cuint[],
-        jltype,
-    )]
+    todo = Tuple{Vector{Cuint}, LLVM.LLVMType}[
+        (
+            Cuint[],
+            jltype,
+        ),
+    ]
 
     extracted = LLVM.Value[]
-	
+
     if addrspace(value_type(dst)) == 10
-       PT2 = if LLVM.is_opaque(value_type(dst))
-	   LLVM.PointerType(11)
-       else
-	   LLVM.PointerType(eltype(value_type(dst)), 11)
-       end
-       dst = addrspacecast!(builder, PT2, dst)
+        PT2 = if LLVM.is_opaque(value_type(dst))
+            LLVM.PointerType(11)
+        else
+            LLVM.PointerType(eltype(value_type(dst)), 11)
+        end
+        dst = addrspacecast!(builder, PT2, dst)
     end
-    
+
     if addrspace(value_type(src)) == 10
-       PT2 = if LLVM.is_opaque(value_type(src))
-	   LLVM.PointerType(11)
-       else
-	   LLVM.PointerType(eltype(value_type(src)), 11)
-       end
-       src = addrspacecast!(builder, src, PT2)
+        PT2 = if LLVM.is_opaque(value_type(src))
+            LLVM.PointerType(11)
+        else
+            LLVM.PointerType(eltype(value_type(src)), 11)
+        end
+        src = addrspacecast!(builder, src, PT2)
     end
 
     while length(todo) != 0
-            path, ty = popfirst!(todo)
+        path, ty = popfirst!(todo)
 
-            if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
-                for i = 1:length(ty)
-                    npath = copy(path)
-                    push!(npath, i - 1)
-                    push!(todo, (npath, eltype(ty)))
-                end
-                continue
+        if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
+            for i in 1:length(ty)
+                npath = copy(path)
+                push!(npath, i - 1)
+                push!(todo, (npath, eltype(ty)))
             end
+            continue
+        end
 
-            if isa(ty, LLVM.VectorType) && any_jltypes(ty)
-                for i = 1:size(ty)
-                    npath = copy(path)
-                    push!(npath, i - 1)
-                    push!(todo, (npath, eltype(ty)))
-                end
-                continue
+        if isa(ty, LLVM.VectorType) && any_jltypes(ty)
+            for i in 1:size(ty)
+                npath = copy(path)
+                push!(npath, i - 1)
+                push!(todo, (npath, eltype(ty)))
             end
+            continue
+        end
 
-            if isa(ty, LLVM.StructType) && any_jltypes(ty)
-                for (i, t) in enumerate(LLVM.elements(ty))
-                    npath = copy(path)
-                    push!(npath, i - 1)
-                    push!(todo, (npath, t))
-                end
-                continue
+        if isa(ty, LLVM.StructType) && any_jltypes(ty)
+            for (i, t) in enumerate(LLVM.elements(ty))
+                npath = copy(path)
+                push!(npath, i - 1)
+                push!(todo, (npath, t))
             end
-        
+            continue
+        end
+
         dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloccs")
         srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloccs")
         val = load!(builder, ty, srcloc)
         st = store!(builder, val, dstloc)
-        end
+    end
     return nothing
 end
 
diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl
index 713e8440..a5d17dbe 100644
--- a/src/compiler/optimize.jl
+++ b/src/compiler/optimize.jl
@@ -321,7 +321,7 @@ function addMachinePasses!(mpm::LLVM.NewPMPassManager)
             add!(mpm, AddressSanitizerPass())
         end
     end
-    add!(mpm, NewPMFunctionPassManager()) do fpm
+    return add!(mpm, NewPMFunctionPassManager()) do fpm
         add!(fpm, DemoteFloat16Pass())
         add!(fpm, GVNPass())              
     end
@@ -379,9 +379,9 @@ const DumpPostCallConv = Ref(false)
 
 function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true)
     addr13NoAlias(mod)
-    
+
     removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
-    
+
 
     memcpy_sret_split!(mod)
     # if we did the move_sret_tofrom_roots, we will have loaded out of the sret, then stored into the rooted.
@@ -392,7 +392,7 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool
     # GVN actually forwards
     @dispose pb = NewPMPassBuilder() begin
         registerEnzymeAndPassPipeline!(pb)
-    	add!(pb, SimpleGVNPass())
+        add!(pb, SimpleGVNPass())
         run!(pb, mod, tm)
     end
     if DumpPreCallConv[]
diff --git a/src/errors.jl b/src/errors.jl
index 7b18ec26..f27ea8cc 100644
--- a/src/errors.jl
+++ b/src/errors.jl
@@ -1187,16 +1187,18 @@ function julia_error(
         data2 = LLVM.Value(data2)
         fn = LLVM.Function(LLVM.API.LLVMGetParamParent(data2::LLVM.Argument))
         @static if VERSION < v"1.11"
-	    sretkind = LLVM.kind(if LLVM.version().major >= 12
-		LLVM.TypeAttribute("sret", LLVM.Int32Type())
-	    else
-		LLVM.EnumAttribute("sret")
-	    end)
-	    if occursin("Could not find use of stored value", msg) && length(parameters(fn)) >= 1 && any(LLVM.kind(attr) == sretkind for attr in collect(LLVM.parameter_attributes(fn, 1)))
-		return C_NULL
-	    end
+            sretkind = LLVM.kind(
+                if LLVM.version().major >= 12
+                    LLVM.TypeAttribute("sret", LLVM.Int32Type())
+                else
+                    LLVM.EnumAttribute("sret")
+                end
+            )
+            if occursin("Could not find use of stored value", msg) && length(parameters(fn)) >= 1 && any(LLVM.kind(attr) == sretkind for attr in collect(LLVM.parameter_attributes(fn, 1)))
+                return C_NULL
+            end
         end
-	msgN = sprint() do io::IO
+        msgN = sprint() do io::IO
             print(io, msg)
             println(io)
             println(io, "Fn = ", string(fn))
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index e6e2d2be..8c21bf80 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -480,62 +480,65 @@ end
 function memcpy_sret_split!(mod::LLVM.Module)
     dl = datalayout(mod)
     ctx = context(mod)
-	    sretkind = LLVM.kind(if LLVM.version().major >= 12
-                LLVM.TypeAttribute("sret", LLVM.Int32Type())
-            else
-                LLVM.EnumAttribute("sret")
-            end)
+    sretkind = LLVM.kind(
+        if LLVM.version().major >= 12
+            LLVM.TypeAttribute("sret", LLVM.Int32Type())
+        else
+            LLVM.EnumAttribute("sret")
+        end
+    )
     for f in functions(mod)
 
         if length(blocks(f)) == 0
-	    continue
-	end
-	if length(parameters(f)) == 0
-	    continue
-	end
-	sty = nothing
-	for attr in collect(LLVM.parameter_attributes(f, 1))
-	    if LLVM.kind(attr) == sretkind
-		 sty = LLVM.value(attr)
-		 break
-	    end
-	end
-	if sty === nothing
-	    continue
-	end
-	tracked = CountTrackedPointers(sty)
-	if tracked.all || tracked.count == 0
-	    continue
-	end
-	todo = LLVM.CallInst[]
-	for bb in blocks(f)
+            continue
+        end
+        if length(parameters(f)) == 0
+            continue
+        end
+        sty = nothing
+        for attr in collect(LLVM.parameter_attributes(f, 1))
+            if LLVM.kind(attr) == sretkind
+                sty = LLVM.value(attr)
+                break
+            end
+        end
+        if sty === nothing
+            continue
+        end
+        tracked = CountTrackedPointers(sty)
+        if tracked.all || tracked.count == 0
+            continue
+        end
+        todo = LLVM.CallInst[]
+        for bb in blocks(f)
             for cur in instructions(bb)
-                    if isa(cur, LLVM.CallInst) &&
-                       isa(LLVM.called_operand(cur), LLVM.Function)
-                        intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(cur))
-			if intr == LLVM.Intrinsic("llvm.memcpy").id
-			    dst, _ = get_base_and_offset(operands(cur)[1]; offsetAllowed = false)
-			    if isa(dst, LLVM.Argument) && parameters(f)[1] == dst
-			    if isa(operands(cur)[3], LLVM.ConstantInt) && LLVM.sizeof(dl, sty) == convert(Int, operands(cur)[3])
-				push!(todo, cur)
-			    end
-			    end
+                if isa(cur, LLVM.CallInst) &&
+                        isa(LLVM.called_operand(cur), LLVM.Function)
+                    intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(cur))
+                    if intr == LLVM.Intrinsic("llvm.memcpy").id
+                        dst, _ = get_base_and_offset(operands(cur)[1]; offsetAllowed = false)
+                        if isa(dst, LLVM.Argument) && parameters(f)[1] == dst
+                            if isa(operands(cur)[3], LLVM.ConstantInt) && LLVM.sizeof(dl, sty) == convert(Int, operands(cur)[3])
+                                push!(todo, cur)
+                            end
                         end
                     end
-	    end
-	end
-	for cur in todo
-	      B = IRBuilder()
-	      position!(B, cur)
-	      dst, _ = get_base_and_offset(operands(cur)[1]; offsetAllowed = false)
-	      src, _ = get_base_and_offset(operands(cur)[2]; offsetAllowed = false)
-	      if !LLVM.is_opaque(value_type(dst)) && eltype(value_type(dst)) != eltype(value_type(src))
-	          src = pointercast!(B, src, LLVM.PointerType(eltype(value_type(dst)), addrspace(value_type(src))), "memcpy_sret_split_pointercast")
-	      end
-	      copy_struct_into!(B, sty, dst, src)
-	      LLVM.API.LLVMInstructionEraseFromParent(cur)
+                end
+            end
+        end
+        for cur in todo
+            B = IRBuilder()
+            position!(B, cur)
+            dst, _ = get_base_and_offset(operands(cur)[1]; offsetAllowed = false)
+            src, _ = get_base_and_offset(operands(cur)[2]; offsetAllowed = false)
+            if !LLVM.is_opaque(value_type(dst)) && eltype(value_type(dst)) != eltype(value_type(src))
+                src = pointercast!(B, src, LLVM.PointerType(eltype(value_type(dst)), addrspace(value_type(src))), "memcpy_sret_split_pointercast")
+            end
+            copy_struct_into!(B, sty, dst, src)
+            LLVM.API.LLVMInstructionEraseFromParent(cur)
         end
     end
+    return
 end
 
 # If there is a phi node of a decayed value, Enzyme may need to cache it

function memcpy_sret_split!(mod::LLVM.Module)
dl = datalayout(mod)
ctx = context(mod)
sretkind = LLVM.kind(if LLVM.version().major >= 12
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the LLVM version check to be extra safe? Technically, oldest supported version at this point is 15 (the one coming with Julia v1.10), so you could just inline the check here and below.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah we just use that everywhere else atm. I suppose we can just do the latter tho at this point

@wsmoses
Copy link
Member Author

wsmoses commented Dec 5, 2025

@wsmoses
Copy link
Member Author

wsmoses commented Dec 5, 2025

@wsmoses wsmoses merged commit 3fb0ab3 into main Dec 9, 2025
52 of 56 checks passed
@wsmoses wsmoses deleted the sstruct branch December 9, 2025 15:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants