-
Notifications
You must be signed in to change notification settings - Fork 82
Perform sret structuring #2832
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Perform sret structuring #2832
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #2832 +/- ##
==========================================
- Coverage 67.77% 67.54% -0.23%
==========================================
Files 58 58
Lines 20913 21049 +136
==========================================
+ Hits 14173 14217 +44
- Misses 6740 6832 +92
☔ View full report in Codecov by Sentry.
🚀 New features to boost your workflow:
|
Benchmark Results
Benchmark Plots
A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/20066125750/artifacts/4811571053.
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/src/compiler.jl b/src/compiler.jl
index bf7e62cc..d3ab49c7 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -3534,7 +3534,7 @@ function create_abi_wrapper(
],
)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
- extract_struct_into!(builder, ptr, eval)
+ extract_struct_into!(builder, ptr, eval)
returnNum += 1
if i == 3 && shadow_init
shadows = LLVM.Value[]
@@ -3576,7 +3576,7 @@ function create_abi_wrapper(
],
)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
- extract_struct_into!(builder, ptr, eval)
+ extract_struct_into!(builder, ptr, eval)
returnNum += 1
end
end
@@ -3676,7 +3676,7 @@ function create_abi_wrapper(
],
)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
- extract_struct_into!(builder, ptr, eval)
+ extract_struct_into!(builder, ptr, eval)
end
@assert count_Sret == numLLVMReturns
else
@@ -3693,7 +3693,7 @@ function create_abi_wrapper(
makeInstanceOf(builder, sret_types[returnNum+1])
end,
)
- extract_struct_into!(
+ extract_struct_into!(
builder,
inbounds_gep!(
builder,
@@ -3719,7 +3719,7 @@ function create_abi_wrapper(
isboxed = GPUCompiler.deserves_argbox(T′)
if !isboxed
eval = extract_value!(builder, val, returnNum)
- extract_struct_into!(
+ extract_struct_into!(
builder,
inbounds_gep!(
builder,
@@ -3817,17 +3817,18 @@ function to_llvm(lst::Vector{Cuint})
vals = LLVM.Value[]
push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
for i in lst
- push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
+ push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
end
return vals
end
-
+
function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, sret::LLVM.Value, root_ty::LLVM.LLVMType, rootRet::Union{LLVM.Value, Nothing}, direction::SRetRootMovement; must_cache::Bool = false)
count = 0
todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
Cuint[],
jltype,
- )]
+ ),
+ ]
extracted = LLVM.Value[]
@@ -3971,7 +3972,8 @@ function copy_floats_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::
todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
Cuint[],
jltype,
- )]
+ ),
+ ]
extracted = LLVM.Value[]
@@ -3983,8 +3985,8 @@ function copy_floats_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::
end
if isa(ty, LLVM.FloatingPointType)
- dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloccf")
- srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloccf")
+ dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloccf")
+ srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloccf")
val = load!(builder, ty, srcloc)
st = store!(builder, val, dstloc)
continue
@@ -4026,17 +4028,18 @@ function extract_nonjlvalues_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMTyp
todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
Cuint[],
jltype,
- )]
+ ),
+ ]
extracted = LLVM.Value[]
-
+
if addrspace(value_type(dst)) == 10
- PT2 = if LLVM.is_opaque(value_type(dst))
- LLVM.PointerType(11)
- else
- LLVM.PointerType(eltype(value_type(dst)), 11)
- end
- dst = addrspacecast!(builder, PT2, dst)
+ PT2 = if LLVM.is_opaque(value_type(dst))
+ LLVM.PointerType(11)
+ else
+ LLVM.PointerType(eltype(value_type(dst)), 11)
+ end
+ dst = addrspacecast!(builder, PT2, dst)
end
while length(todo) != 0
@@ -4075,7 +4078,7 @@ function extract_nonjlvalues_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMTyp
continue
end
- dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocnjl")
+ dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocnjl")
val = Enzyme.API.e_extract_value!(builder, src, path)
st = store!(builder, val, dstloc)
end
@@ -4086,122 +4089,126 @@ end
function extract_struct_into!(builder::LLVM.IRBuilder, dst::LLVM.Value, src::LLVM.Value)
count = 0
jltype = value_type(src)
- todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
- Cuint[],
- jltype,
- )]
+ todo = Tuple{Vector{Cuint}, LLVM.LLVMType}[
+ (
+ Cuint[],
+ jltype,
+ ),
+ ]
extracted = LLVM.Value[]
-
+
if addrspace(value_type(dst)) == 10
- PT2 = if LLVM.is_opaque(value_type(dst))
- LLVM.PointerType(11)
- else
- LLVM.PointerType(eltype(value_type(dst)), 11)
- end
- dst = addrspacecast!(builder, PT2, dst)
+ PT2 = if LLVM.is_opaque(value_type(dst))
+ LLVM.PointerType(11)
+ else
+ LLVM.PointerType(eltype(value_type(dst)), 11)
+ end
+ dst = addrspacecast!(builder, PT2, dst)
end
while length(todo) != 0
- path, ty = popfirst!(todo)
+ path, ty = popfirst!(todo)
- if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
- for i = 1:length(ty)
- npath = copy(path)
- push!(npath, i - 1)
- push!(todo, (npath, eltype(ty)))
- end
- continue
+ if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
+ for i in 1:length(ty)
+ npath = copy(path)
+ push!(npath, i - 1)
+ push!(todo, (npath, eltype(ty)))
end
+ continue
+ end
- if isa(ty, LLVM.VectorType) && any_jltypes(ty)
- for i = 1:size(ty)
- npath = copy(path)
- push!(npath, i - 1)
- push!(todo, (npath, eltype(ty)))
- end
- continue
+ if isa(ty, LLVM.VectorType) && any_jltypes(ty)
+ for i in 1:size(ty)
+ npath = copy(path)
+ push!(npath, i - 1)
+ push!(todo, (npath, eltype(ty)))
end
+ continue
+ end
- if isa(ty, LLVM.StructType) && any_jltypes(ty)
- for (i, t) in enumerate(LLVM.elements(ty))
- npath = copy(path)
- push!(npath, i - 1)
- push!(todo, (npath, t))
- end
- continue
+ if isa(ty, LLVM.StructType) && any_jltypes(ty)
+ for (i, t) in enumerate(LLVM.elements(ty))
+ npath = copy(path)
+ push!(npath, i - 1)
+ push!(todo, (npath, t))
end
-
- dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocsi")
- val = length(path) == 0 ? src : Enzyme.API.e_extract_value!(builder, src, path)
- st = store!(builder, val, dstloc)
+ continue
end
- return nothing
+ dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocsi")
+ val = length(path) == 0 ? src : Enzyme.API.e_extract_value!(builder, src, path)
+ st = store!(builder, val, dstloc)
+ end
+
+ return nothing
end
function copy_struct_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::LLVM.Value, src::LLVM.Value)
count = 0
- todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
- Cuint[],
- jltype,
- )]
+ todo = Tuple{Vector{Cuint}, LLVM.LLVMType}[
+ (
+ Cuint[],
+ jltype,
+ ),
+ ]
extracted = LLVM.Value[]
-
+
if addrspace(value_type(dst)) == 10
- PT2 = if LLVM.is_opaque(value_type(dst))
- LLVM.PointerType(11)
- else
- LLVM.PointerType(eltype(value_type(dst)), 11)
- end
- dst = addrspacecast!(builder, PT2, dst)
+ PT2 = if LLVM.is_opaque(value_type(dst))
+ LLVM.PointerType(11)
+ else
+ LLVM.PointerType(eltype(value_type(dst)), 11)
+ end
+ dst = addrspacecast!(builder, PT2, dst)
end
-
+
if addrspace(value_type(src)) == 10
- PT2 = if LLVM.is_opaque(value_type(src))
- LLVM.PointerType(11)
- else
- LLVM.PointerType(eltype(value_type(src)), 11)
- end
- src = addrspacecast!(builder, src, PT2)
+ PT2 = if LLVM.is_opaque(value_type(src))
+ LLVM.PointerType(11)
+ else
+ LLVM.PointerType(eltype(value_type(src)), 11)
+ end
+ src = addrspacecast!(builder, src, PT2)
end
while length(todo) != 0
- path, ty = popfirst!(todo)
+ path, ty = popfirst!(todo)
- if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
- for i = 1:length(ty)
- npath = copy(path)
- push!(npath, i - 1)
- push!(todo, (npath, eltype(ty)))
- end
- continue
+ if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
+ for i in 1:length(ty)
+ npath = copy(path)
+ push!(npath, i - 1)
+ push!(todo, (npath, eltype(ty)))
end
+ continue
+ end
- if isa(ty, LLVM.VectorType) && any_jltypes(ty)
- for i = 1:size(ty)
- npath = copy(path)
- push!(npath, i - 1)
- push!(todo, (npath, eltype(ty)))
- end
- continue
+ if isa(ty, LLVM.VectorType) && any_jltypes(ty)
+ for i in 1:size(ty)
+ npath = copy(path)
+ push!(npath, i - 1)
+ push!(todo, (npath, eltype(ty)))
end
+ continue
+ end
- if isa(ty, LLVM.StructType) && any_jltypes(ty)
- for (i, t) in enumerate(LLVM.elements(ty))
- npath = copy(path)
- push!(npath, i - 1)
- push!(todo, (npath, t))
- end
- continue
+ if isa(ty, LLVM.StructType) && any_jltypes(ty)
+ for (i, t) in enumerate(LLVM.elements(ty))
+ npath = copy(path)
+ push!(npath, i - 1)
+ push!(todo, (npath, t))
end
-
+ continue
+ end
+
dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloccs")
srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloccs")
val = load!(builder, ty, srcloc)
st = store!(builder, val, dstloc)
- end
+ end
return nothing
end
diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl
index 713e8440..a5d17dbe 100644
--- a/src/compiler/optimize.jl
+++ b/src/compiler/optimize.jl
@@ -321,7 +321,7 @@ function addMachinePasses!(mpm::LLVM.NewPMPassManager)
add!(mpm, AddressSanitizerPass())
end
end
- add!(mpm, NewPMFunctionPassManager()) do fpm
+ return add!(mpm, NewPMFunctionPassManager()) do fpm
add!(fpm, DemoteFloat16Pass())
add!(fpm, GVNPass())
end
@@ -379,9 +379,9 @@ const DumpPostCallConv = Ref(false)
function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true)
addr13NoAlias(mod)
-
+
removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
-
+
memcpy_sret_split!(mod)
# if we did the move_sret_tofrom_roots, we will have loaded out of the sret, then stored into the rooted.
@@ -392,7 +392,7 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool
# GVN actually forwards
@dispose pb = NewPMPassBuilder() begin
registerEnzymeAndPassPipeline!(pb)
- add!(pb, SimpleGVNPass())
+ add!(pb, SimpleGVNPass())
run!(pb, mod, tm)
end
if DumpPreCallConv[]
diff --git a/src/errors.jl b/src/errors.jl
index 7b18ec26..f27ea8cc 100644
--- a/src/errors.jl
+++ b/src/errors.jl
@@ -1187,16 +1187,18 @@ function julia_error(
data2 = LLVM.Value(data2)
fn = LLVM.Function(LLVM.API.LLVMGetParamParent(data2::LLVM.Argument))
@static if VERSION < v"1.11"
- sretkind = LLVM.kind(if LLVM.version().major >= 12
- LLVM.TypeAttribute("sret", LLVM.Int32Type())
- else
- LLVM.EnumAttribute("sret")
- end)
- if occursin("Could not find use of stored value", msg) && length(parameters(fn)) >= 1 && any(LLVM.kind(attr) == sretkind for attr in collect(LLVM.parameter_attributes(fn, 1)))
- return C_NULL
- end
+ sretkind = LLVM.kind(
+ if LLVM.version().major >= 12
+ LLVM.TypeAttribute("sret", LLVM.Int32Type())
+ else
+ LLVM.EnumAttribute("sret")
+ end
+ )
+ if occursin("Could not find use of stored value", msg) && length(parameters(fn)) >= 1 && any(LLVM.kind(attr) == sretkind for attr in collect(LLVM.parameter_attributes(fn, 1)))
+ return C_NULL
+ end
end
- msgN = sprint() do io::IO
+ msgN = sprint() do io::IO
print(io, msg)
println(io)
println(io, "Fn = ", string(fn))
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index e6e2d2be..8c21bf80 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -480,62 +480,65 @@ end
function memcpy_sret_split!(mod::LLVM.Module)
dl = datalayout(mod)
ctx = context(mod)
- sretkind = LLVM.kind(if LLVM.version().major >= 12
- LLVM.TypeAttribute("sret", LLVM.Int32Type())
- else
- LLVM.EnumAttribute("sret")
- end)
+ sretkind = LLVM.kind(
+ if LLVM.version().major >= 12
+ LLVM.TypeAttribute("sret", LLVM.Int32Type())
+ else
+ LLVM.EnumAttribute("sret")
+ end
+ )
for f in functions(mod)
if length(blocks(f)) == 0
- continue
- end
- if length(parameters(f)) == 0
- continue
- end
- sty = nothing
- for attr in collect(LLVM.parameter_attributes(f, 1))
- if LLVM.kind(attr) == sretkind
- sty = LLVM.value(attr)
- break
- end
- end
- if sty === nothing
- continue
- end
- tracked = CountTrackedPointers(sty)
- if tracked.all || tracked.count == 0
- continue
- end
- todo = LLVM.CallInst[]
- for bb in blocks(f)
+ continue
+ end
+ if length(parameters(f)) == 0
+ continue
+ end
+ sty = nothing
+ for attr in collect(LLVM.parameter_attributes(f, 1))
+ if LLVM.kind(attr) == sretkind
+ sty = LLVM.value(attr)
+ break
+ end
+ end
+ if sty === nothing
+ continue
+ end
+ tracked = CountTrackedPointers(sty)
+ if tracked.all || tracked.count == 0
+ continue
+ end
+ todo = LLVM.CallInst[]
+ for bb in blocks(f)
for cur in instructions(bb)
- if isa(cur, LLVM.CallInst) &&
- isa(LLVM.called_operand(cur), LLVM.Function)
- intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(cur))
- if intr == LLVM.Intrinsic("llvm.memcpy").id
- dst, _ = get_base_and_offset(operands(cur)[1]; offsetAllowed = false)
- if isa(dst, LLVM.Argument) && parameters(f)[1] == dst
- if isa(operands(cur)[3], LLVM.ConstantInt) && LLVM.sizeof(dl, sty) == convert(Int, operands(cur)[3])
- push!(todo, cur)
- end
- end
+ if isa(cur, LLVM.CallInst) &&
+ isa(LLVM.called_operand(cur), LLVM.Function)
+ intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(cur))
+ if intr == LLVM.Intrinsic("llvm.memcpy").id
+ dst, _ = get_base_and_offset(operands(cur)[1]; offsetAllowed = false)
+ if isa(dst, LLVM.Argument) && parameters(f)[1] == dst
+ if isa(operands(cur)[3], LLVM.ConstantInt) && LLVM.sizeof(dl, sty) == convert(Int, operands(cur)[3])
+ push!(todo, cur)
+ end
end
end
- end
- end
- for cur in todo
- B = IRBuilder()
- position!(B, cur)
- dst, _ = get_base_and_offset(operands(cur)[1]; offsetAllowed = false)
- src, _ = get_base_and_offset(operands(cur)[2]; offsetAllowed = false)
- if !LLVM.is_opaque(value_type(dst)) && eltype(value_type(dst)) != eltype(value_type(src))
- src = pointercast!(B, src, LLVM.PointerType(eltype(value_type(dst)), addrspace(value_type(src))), "memcpy_sret_split_pointercast")
- end
- copy_struct_into!(B, sty, dst, src)
- LLVM.API.LLVMInstructionEraseFromParent(cur)
+ end
+ end
+ end
+ for cur in todo
+ B = IRBuilder()
+ position!(B, cur)
+ dst, _ = get_base_and_offset(operands(cur)[1]; offsetAllowed = false)
+ src, _ = get_base_and_offset(operands(cur)[2]; offsetAllowed = false)
+ if !LLVM.is_opaque(value_type(dst)) && eltype(value_type(dst)) != eltype(value_type(src))
+ src = pointercast!(B, src, LLVM.PointerType(eltype(value_type(dst)), addrspace(value_type(src))), "memcpy_sret_split_pointercast")
+ end
+ copy_struct_into!(B, sty, dst, src)
+ LLVM.API.LLVMInstructionEraseFromParent(cur)
end
end
+ return
end
# If there is a phi node of a decayed value, Enzyme may need to cache it |
| function memcpy_sret_split!(mod::LLVM.Module) | ||
| dl = datalayout(mod) | ||
| ctx = context(mod) | ||
| sretkind = LLVM.kind(if LLVM.version().major >= 12 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the LLVM version check there to be extra safe? Technically, the oldest supported version at this point is 15 (the one shipped with Julia v1.10), so you could just inline the check here and below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, we just use that everywhere else at the moment. I suppose we can just do the latter at this point, though.
Newer version of `DifferentiationInterfaceTest` is compatible with Julia v1.12.
No description provided.