Skip to content
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5, 0.6"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.16"
Enzyme_jll = "0.0.226"
Enzyme_jll = "0.0.229"
GPUArraysCore = "0.1.6, 0.2"
GPUCompiler = "1.6.2"
LLVM = "9.1"
Expand Down
186 changes: 150 additions & 36 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3534,7 +3534,7 @@ function create_abi_wrapper(
],
)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
si = store!(builder, eval, ptr)
extract_struct_into!(builder, ptr, eval)
returnNum += 1
if i == 3 && shadow_init
shadows = LLVM.Value[]
Expand Down Expand Up @@ -3576,7 +3576,7 @@ function create_abi_wrapper(
],
)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
si = store!(builder, eval, ptr)
extract_struct_into!(builder, ptr, eval)
returnNum += 1
end
end
Expand Down Expand Up @@ -3676,7 +3676,7 @@ function create_abi_wrapper(
],
)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
si = store!(builder, eval, ptr)
extract_struct_into!(builder, ptr, eval)
end
@assert count_Sret == numLLVMReturns
else
Expand All @@ -3693,9 +3693,8 @@ function create_abi_wrapper(
makeInstanceOf(builder, sret_types[returnNum+1])
end,
)
store!(
extract_struct_into!(
builder,
eval,
inbounds_gep!(
builder,
jltype,
Expand All @@ -3708,6 +3707,7 @@ function create_abi_wrapper(
),
],
),
eval,
)
returnNum += 1
end
Expand All @@ -3719,9 +3719,8 @@ function create_abi_wrapper(
isboxed = GPUCompiler.deserves_argbox(T′)
if !isboxed
eval = extract_value!(builder, val, returnNum)
store!(
extract_struct_into!(
builder,
eval,
inbounds_gep!(
builder,
jltype,
Expand All @@ -3732,6 +3731,7 @@ function create_abi_wrapper(
LLVM.ConstantInt(LLVM.IntType(32), activeNum),
],
),
eval,
)
returnNum += 1
end
Expand All @@ -3750,7 +3750,6 @@ function create_abi_wrapper(
ret!(builder)
end

# make sure that arguments are rooted if necessary
reinsert_gcmarker!(llvm_f)
if LLVM.API.LLVMVerifyFunction(llvm_f, LLVM.API.LLVMReturnStatusAction) != 0
msg = sprint() do io
Expand Down Expand Up @@ -3814,20 +3813,21 @@ end
RootAndSRetPointerToValue = 5,
)

function to_llvm(lst::Vector{Cuint})
vals = LLVM.Value[]
push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
for i in lst
push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
end
return vals
end

function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, sret::LLVM.Value, root_ty::LLVM.LLVMType, rootRet::Union{LLVM.Value, Nothing}, direction::SRetRootMovement; must_cache::Bool = false)
count = 0
todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
Cuint[],
jltype,
)]
function to_llvm(lst::Vector{Cuint})
vals = LLVM.Value[]
push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
for i in lst
push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
end
return vals
end

extracted = LLVM.Value[]

Expand Down Expand Up @@ -3972,14 +3972,6 @@ function copy_floats_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::
Cuint[],
jltype,
)]
function to_llvm(lst::Vector{Cuint})
vals = LLVM.Value[]
push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
for i in lst
push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
end
return vals
end

extracted = LLVM.Value[]

Expand All @@ -3991,8 +3983,8 @@ function copy_floats_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::
end

if isa(ty, LLVM.FloatingPointType)
dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloc")
srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloc")
dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloccf")
srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloccf")
val = load!(builder, ty, srcloc)
st = store!(builder, val, dstloc)
continue
Expand Down Expand Up @@ -4035,16 +4027,17 @@ function extract_nonjlvalues_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMTyp
Cuint[],
jltype,
)]
function to_llvm(lst::Vector{Cuint})
vals = LLVM.Value[]
push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
for i in lst
push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
end
return vals
end

extracted = LLVM.Value[]
extracted = LLVM.Value[]

if addrspace(value_type(dst)) == 10
PT2 = if LLVM.is_opaque(value_type(dst))
LLVM.PointerType(11)
else
LLVM.PointerType(eltype(value_type(dst)), 11)
end
dst = addrspacecast!(builder, PT2, dst)
end

while length(todo) != 0
path, ty = popfirst!(todo)
Expand Down Expand Up @@ -4082,14 +4075,135 @@ function extract_nonjlvalues_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMTyp
continue
end

dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloc")
dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocnjl")
val = Enzyme.API.e_extract_value!(builder, src, path)
st = store!(builder, val, dstloc)
end

return nothing
end

function extract_struct_into!(builder::LLVM.IRBuilder, dst::LLVM.Value, src::LLVM.Value)
count = 0
jltype = value_type(src)
todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
Cuint[],
jltype,
)]

extracted = LLVM.Value[]

if addrspace(value_type(dst)) == 10
PT2 = if LLVM.is_opaque(value_type(dst))
LLVM.PointerType(11)
else
LLVM.PointerType(eltype(value_type(dst)), 11)
end
dst = addrspacecast!(builder, PT2, dst)
end

while length(todo) != 0
path, ty = popfirst!(todo)

if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
for i = 1:length(ty)
npath = copy(path)
push!(npath, i - 1)
push!(todo, (npath, eltype(ty)))
end
continue
end

if isa(ty, LLVM.VectorType) && any_jltypes(ty)
for i = 1:size(ty)
npath = copy(path)
push!(npath, i - 1)
push!(todo, (npath, eltype(ty)))
end
continue
end

if isa(ty, LLVM.StructType) && any_jltypes(ty)
for (i, t) in enumerate(LLVM.elements(ty))
npath = copy(path)
push!(npath, i - 1)
push!(todo, (npath, t))
end
continue
end

dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocsi")
val = length(path) == 0 ? src : Enzyme.API.e_extract_value!(builder, src, path)
st = store!(builder, val, dstloc)
end

return nothing
end

function copy_struct_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::LLVM.Value, src::LLVM.Value)
count = 0
todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
Cuint[],
jltype,
)]

extracted = LLVM.Value[]

if addrspace(value_type(dst)) == 10
PT2 = if LLVM.is_opaque(value_type(dst))
LLVM.PointerType(11)
else
LLVM.PointerType(eltype(value_type(dst)), 11)
end
dst = addrspacecast!(builder, PT2, dst)
end

if addrspace(value_type(src)) == 10
PT2 = if LLVM.is_opaque(value_type(src))
LLVM.PointerType(11)
else
LLVM.PointerType(eltype(value_type(src)), 11)
end
src = addrspacecast!(builder, src, PT2)
end

while length(todo) != 0
path, ty = popfirst!(todo)

if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
for i = 1:length(ty)
npath = copy(path)
push!(npath, i - 1)
push!(todo, (npath, eltype(ty)))
end
continue
end

if isa(ty, LLVM.VectorType) && any_jltypes(ty)
for i = 1:size(ty)
npath = copy(path)
push!(npath, i - 1)
push!(todo, (npath, eltype(ty)))
end
continue
end

if isa(ty, LLVM.StructType) && any_jltypes(ty)
for (i, t) in enumerate(LLVM.elements(ty))
npath = copy(path)
push!(npath, i - 1)
push!(todo, (npath, t))
end
continue
end

dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloccs")
srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloccs")
val = load!(builder, ty, srcloc)
st = store!(builder, val, dstloc)
end
return nothing
end

# Modified from GPUCompiler/src/irgen.jl:365 lower_byval
function lower_convention(
Expand Down
33 changes: 33 additions & 0 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ end
LLVM.@function_pass "jl-inst-simplify" JLInstSimplifyPass
LLVM.@module_pass "preserve-nvvm" PreserveNVVMPass
LLVM.@module_pass "preserve-nvvm-end" PreserveNVVMEndPass
LLVM.@module_pass "simple-gvn" SimpleGVNPass

const RunAttributor = Ref(VERSION < v"1.12")

Expand Down Expand Up @@ -298,12 +299,29 @@ function addOptimizationPasses!(mpm::LLVM.NewPMPassManager)
end
end

if VERSION < v"1.14.0-DEV.61"
import Libdl
const RUN_ASAN_PASS = any(contains("libclang_rt.asan"), Libdl.dllist())
end

function addMachinePasses!(mpm::LLVM.NewPMPassManager)
add!(mpm, NewPMFunctionPassManager()) do fpm
if VERSION < v"1.12.0-DEV.1390"
add!(fpm, CombineMulAddPass())
end
add!(fpm, DivRemPairsPass())
add!(fpm, AnnotationRemarksPass())
end
@static if VERSION >= v"1.14.0-DEV.61"
if Base.JLOptions().target_sanitize_address
add!(mpm, AddressSanitizerPass())
end
else
if RUN_ASAN_PASS
add!(mpm, AddressSanitizerPass())
end
end
add!(mpm, NewPMFunctionPassManager()) do fpm
add!(fpm, DemoteFloat16Pass())
add!(fpm, GVNPass())
end
Expand Down Expand Up @@ -361,7 +379,22 @@ const DumpPostCallConv = Ref(false)

function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true)
addr13NoAlias(mod)

removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)


memcpy_sret_split!(mod)
# if we did the move_sret_tofrom_roots, we will have loaded out of the sret, then stored into the rooted.
# we should forward the value we actually stored [fixing the sret to therefore be writeonly and also ensuring
# we can find the root store from the jlvaluet]
# Instcombine breaks apart struct stores into individual components
run!(InstCombinePass(), mod)
# GVN actually forwards
@dispose pb = NewPMPassBuilder() begin
registerEnzymeAndPassPipeline!(pb)
add!(pb, SimpleGVNPass())
run!(pb, mod, tm)
end
if DumpPreCallConv[]
API.EnzymeDumpModuleRef(mod.ref)
end
Expand Down
16 changes: 13 additions & 3 deletions src/errors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1184,11 +1184,21 @@ function julia_error(
elseif errtype == API.ET_InternalError
throw(EnzymeInternalError(msg, ir, bt))
elseif errtype == API.ET_GCRewrite
msgN = sprint() do io::IO
data2 = LLVM.Value(data2)
fn = LLVM.Function(LLVM.API.LLVMGetParamParent(data2::LLVM.Argument))
@static if VERSION < v"1.11"
sretkind = LLVM.kind(if LLVM.version().major >= 12
LLVM.TypeAttribute("sret", LLVM.Int32Type())
else
LLVM.EnumAttribute("sret")
end)
if occursin("Could not find use of stored value", msg) && length(parameters(fn)) >= 1 && any(LLVM.kind(attr) == sretkind for attr in collect(LLVM.parameter_attributes(fn, 1)))
return C_NULL
end
end
msgN = sprint() do io::IO
print(io, msg)
println(io)
data2 = LLVM.Value(data2)
fn = LLVM.Function(LLVM.API.LLVMGetParamParent(data2::LLVM.Argument))
println(io, "Fn = ", string(fn))
println(io, "arg = ", string(data2::LLVM.Argument))
if data !== C_NULL
Expand Down
Loading
Loading