2 changes: 1 addition & 1 deletion Project.toml
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.16"
Reactant_jll = "0.0.251"
Reactant_jll = "0.0.252"
ScopedValues = "1.3.0"
Scratch = "1.2"
Sockets = "1.10"
41 changes: 24 additions & 17 deletions src/Compiler.jl
@@ -1395,7 +1395,8 @@ function __get_compile_options_and_kwargs(;
end

function compile_mlir(f, args; client=nothing, kwargs...)
-    backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
+    client = client !== nothing ? client : XLA.default_backend()
+    backend = XLA.platform_name(client)

if backend == "CUDA"
backend = "GPU"
@@ -1414,6 +1415,7 @@ function compile_mlir(f, args; client=nothing, kwargs...)
compile_options;
backend,
runtime=XLA.runtime(client),
+        client,
kwargs...,
)
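
Note on the hunks above: `compile_mlir` now resolves the client once, up front, and threads it through to `compile_mlir!`, so the platform lookup, the runtime lookup, and the new device-property query all see the same client. A minimal sketch of that pattern; `resolve_client` is a hypothetical helper for illustration, not part of the API:

```julia
# Hypothetical helper illustrating the client-resolution pattern above.
# `client` may be `nothing` (use the process-wide default) or a concrete client.
function resolve_client(client)
    client = client !== nothing ? client : XLA.default_backend()
    backend = XLA.platform_name(client)
    # A CUDA client is compiled under the generic "GPU" backend name.
    backend == "CUDA" && (backend = "GPU")
    return client, backend
end
```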

@@ -1430,11 +1432,9 @@ end

const PartitionKA = Ref{Bool}(true)

-const cubinChip = Ref{String}("sm_60")
-const cubinFormat = Ref{String}("bin")
const cuindexBitWidth = Ref{Int}(32)
+const cubinFormat = Ref{String}("bin")
const cuOptLevel = Ref{Int}(2)
-const cuWarpSize = Ref{Int}(32)
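
With `cubinChip` and `cuWarpSize` removed, the remaining `Ref`s are the only process-global CUDA knobs; the chip string is now derived per device at compile time. A hedged usage sketch, assuming these `Ref`s stay reachable under `Reactant.Compiler` as shown, and with example values that are assumptions rather than documented ranges:

```julia
# Hypothetical tuning before compilation; the defaults are those in the diff above.
Reactant.Compiler.cuOptLevel[] = 3        # default 2
Reactant.Compiler.cuindexBitWidth[] = 64  # default 32
```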

# Whatever the relevant highest version from our LLVM is within NVPTX.td
# Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
@@ -1580,8 +1580,11 @@ function compile_mlir!(
backend="gpu",
runtime::Union{Val{:PJRT},Val{:IFRT}},
legalize_stablehlo_to_mhlo::Bool=false,
+    client=nothing,
kwargs...,
)
+    client = client !== nothing ? client : XLA.default_backend()
+
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues

@@ -1655,25 +1658,27 @@
else
jit = "lower-jit{openmp=$(OpenMP[]) backend=cpu},symbol-dce"
end
-elseif DEBUG_KERNEL[]
-    curesulthandler = dlsym(
-        Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
-    )
-    @assert curesulthandler !== nothing
-    curesulthandler = Base.reinterpret(UInt, curesulthandler)
+else
    kern = if is_raising
        "lower-kernel{backend=cpu},symbol-dce,canonicalize"
    else
        "lower-kernel,canonicalize"
    end
-    jit = "lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
-else
-    kern = if is_raising
-        "lower-kernel{backend=cpu},symbol-dce,canonicalize"
+
+    device_properties = XLA.device_properties(XLA.default_device(client))
+    cubinChip = "sm_$(device_properties.major)$(device_properties.minor)"
+
+    if DEBUG_KERNEL[]
+        curesulthandler = dlsym(
+            Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
+        )
+        @assert curesulthandler !== nothing
+        curesulthandler = Base.reinterpret(UInt, curesulthandler)
+        extra_lowerjit_options = "debug=true cuResultHandlerPtr=$curesulthandler "
    else
-        "lower-kernel,canonicalize"
+        extra_lowerjit_options = ""
    end
-    jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
+    jit = "lower-jit{$(extra_lowerjit_options)cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
end

recognize_comms = true
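
The restructuring above collapses two nearly identical `lower-jit` strings into one template: the debug branch only contributes an `extra_lowerjit_options` prefix, and `cubinChip` is now a local string derived from the active device instead of a global `Ref`. A condensed sketch of the resulting assembly, assuming the surrounding `compile_mlir!` scope (`client`, `DEBUG_KERNEL`, the `cu*` Refs, `toolkit`):

```julia
# Condensed sketch of the pipeline-string assembly after this change.
props = XLA.device_properties(XLA.default_device(client))
cubinChip = "sm_$(props.major)$(props.minor)"    # e.g. "sm_80" on an A100

extra_lowerjit_options = if DEBUG_KERNEL[]
    handler = Base.reinterpret(
        UInt, dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult")
    )
    "debug=true cuResultHandlerPtr=$handler "    # note the trailing space
else
    ""
end
jit = "lower-jit{$(extra_lowerjit_options)cuOptLevel=$(cuOptLevel[]) cubinChip=$(cubinChip) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
```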
@@ -3477,7 +3482,8 @@ function compile_xla(
context_gc_vector[ctx] = Vector{Union{TracedRArray,TracedRNumber}}(undef, 0)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

-    backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
+    client = client !== nothing ? client : XLA.default_backend()
+    backend = XLA.platform_name(client)

if backend == "CUDA"
backend = "GPU"
@@ -3498,6 +3504,7 @@
compile_options;
backend,
runtime=XLA.runtime(client),
+        client,
kwargs...,
)

72 changes: 72 additions & 0 deletions src/xla/Device.jl
@@ -11,6 +11,7 @@ function device_kind end
function default_memory end
function memories end
function is_addressable end
+function get_local_hardware_id end

"""
device_ordinal(device::Device)
@@ -29,3 +30,74 @@ end
function is_addressable(device::AbstractDevice)
return device ∈ addressable_devices(client(device))
end

# Keep in sync with API.cpp
struct DeviceProperties
total_global_mem::Csize_t
shared_mem_per_block::Csize_t
regs_per_block::Cint
warp_size::Cint
max_threads_per_block::Cint
max_threads_dim::NTuple{3,Cint}
max_grid_size::NTuple{3,Cint}
total_const_mem::Csize_t
major::Cint
minor::Cint
multi_processor_count::Cint
can_map_host_memory::Cint
l2_cache_size::Cint
max_threads_per_multiprocessor::Cint
end
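
Because this struct must mirror the C layout field-for-field ("Keep in sync with API.cpp"), it has to remain a flat isbits type: that is what makes passing a `Ref{DeviceProperties}` through a `Ptr{Cvoid}` in the ccall below well-defined. A quick sanity check one could run — an assumption-level sketch, not part of the PR:

```julia
# DeviceProperties must stay isbits for the Ptr{Cvoid}-based @ccall to be safe.
@assert isbitstype(DeviceProperties)
# NTuple{3,Cint} is the Julia spelling of a C `int field[3]` array member.
@assert fieldtype(DeviceProperties, :max_threads_dim) === NTuple{3,Cint}
```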

const DEVICE_PROPERTIES_CACHE = Dict{Tuple{Int,String},DeviceProperties}()

"""
device_properties(device::AbstractDevice)

Get a struct containing device properties. Which fields are populated depends on the
underlying device implementation.
"""
function device_properties(device::AbstractDevice)
pname = platform_name(client(device))
local_hardware_id = get_local_hardware_id(device)

if haskey(DEVICE_PROPERTIES_CACHE, (local_hardware_id, pname))
return DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)]
end

jldevprops = Ref{DeviceProperties}()
if pname == "cuda"
GC.@preserve jldevprops begin
@ccall MLIR.API.mlir_c.ReactantCudaDeviceGetProperties(
jldevprops::Ptr{Cvoid}, local_hardware_id::Cint
)::Cvoid
end
else
@warn "`device_properties` not implemented for platform: $(pname)" maxlog = 1
end
DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)] = jldevprops[]
return jldevprops[]
end
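
Usage sketch: properties are fetched once per `(local_hardware_id, platform)` pair and memoized, so repeated queries skip the C shim. Assuming an initialized CUDA client:

```julia
# Hypothetical usage with an active CUDA client.
dev = XLA.default_device(client)
props = XLA.device_properties(dev)            # first call hits the C shim
@assert XLA.device_properties(dev) === props  # second call is served from the cache
sm = "sm_$(props.major)$(props.minor)"        # the string the compiler uses for cubinChip
```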

function Base.show(io::IO, ::MIME"text/plain", props::DeviceProperties)
return print(
io,
"""
DeviceProperties
----------------
Total Global Mem: $(_format_bytes(props.total_global_mem))
Shared Mem Per Block: $(_format_bytes(props.shared_mem_per_block))
Regs Per Block: $(props.regs_per_block)
Warp Size: $(props.warp_size)
Max Threads Per Block: $(props.max_threads_per_block)
Max Threads Dim: $(props.max_threads_dim)
Max Grid Size: $(props.max_grid_size)
Total Const Mem: $(_format_bytes(props.total_const_mem))
Version: $(VersionNumber(props.major, props.minor))
Multi Processor Count: $(props.multi_processor_count)
Can Map Host Memory: $(props.can_map_host_memory)
L2 Cache Size: $(props.l2_cache_size)
Max Threads Per Multiprocessor: $(props.max_threads_per_multiprocessor)
""",
)
end
8 changes: 8 additions & 0 deletions src/xla/IFRT/Device.jl
@@ -31,6 +31,14 @@ function XLA.get_local_device_id(::Device)
return error("Not implemented for ifrt devices")
end

function XLA.get_local_hardware_id(device::Device)
GC.@preserve device begin
return @ccall MLIR.API.mlir_c.ifrt_DeviceGetLocalHardwareId(
device.device::Ptr{Cvoid}
)::Cint
end
end

function XLA.default_memory(device::Device)
GC.@preserve device begin
return Memory(
8 changes: 8 additions & 0 deletions src/xla/PJRT/Device.jl
@@ -33,6 +33,14 @@ function XLA.get_local_device_id(device::Device)
end
end

function XLA.get_local_hardware_id(device::Device)
GC.@preserve device begin
return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalHardwareId(
device.device::Ptr{Cvoid}
)::Cint
end
end
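
The PJRT and IFRT implementations are identical apart from the C symbol: `GC.@preserve` pins the Julia wrapper while the raw pointer it owns crosses the ccall boundary. A hedged template for some future backend — the type name and C symbol here are placeholders, not real API:

```julia
# Placeholder sketch: how a hypothetical backend `Foo` would plug in.
function XLA.get_local_hardware_id(device::FooDevice)
    GC.@preserve device begin
        return @ccall MLIR.API.mlir_c.foo_DeviceGetLocalHardwareId(
            device.device::Ptr{Cvoid}
        )::Cint
    end
end
```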

function XLA.is_addressable(device::Device)
GC.@preserve device begin
return @ccall MLIR.API.mlir_c.pjrt_device_is_addressable(
2 changes: 1 addition & 1 deletion src/xla/Stats.jl
@@ -13,7 +13,7 @@ struct JLAllocatorStats
peak_pool_bytes::Int64
end

-_format_bytes(x) = Base.format_bytes(x)
+_format_bytes(x) = x < 0 ? nothing : Base.format_bytes(x)
_format_bytes(x::Nothing) = x
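
The new guard treats negative byte counts as "not reported" — several `DeviceProperties` fields can plausibly come back negative on platforms where the query is unimplemented — rather than letting `Base.format_bytes` render a nonsense negative size. Illustrative values, assuming standard `Base.format_bytes` behavior:

```julia
_format_bytes(1_073_741_824)  # "1.000 GiB" (1024^3 bytes)
_format_bytes(-1)             # nothing — negative sentinel, left unformatted
_format_bytes(nothing)        # nothing passes straight through
```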

"""
9 changes: 0 additions & 9 deletions src/xla/XLA.jl
@@ -234,15 +234,6 @@ for runtime in (:PJRT, :IFRT)
)
state.clients["cuda"] = gpu
state.default_client = gpu

-# set values for cuda. This is being done here since we need cuda
-# to be initialized before we can use it. initializing the devices
-# implicitly initializes cuda.
-cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
-cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
-Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)"
-
-Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32
catch e
println(stdout, e)
end
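
For context on this deletion: the removed block eagerly probed the first CUDA device at client initialization and stashed the results in the `Reactant.Compiler` globals. That responsibility moves to compile time, where the same information is read per device via `device_properties` (see src/xla/Device.jl above). Roughly, the replacement path is:

```julia
# Sketch: what replaces the deleted init-time globals (per-device, on demand).
props = XLA.device_properties(XLA.default_device(client))
cubin_chip = "sm_$(props.major)$(props.minor)"  # formerly Reactant.Compiler.cubinChip[]
warp_size = props.warp_size                     # formerly Reactant.Compiler.cuWarpSize[]
```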