From 3e21cd9e98ae3fff645a7e65b9e286ebda04cd27 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 16 Oct 2025 14:30:15 -0500
Subject: [PATCH 1/5] feat: julia api to access device properties [skip ci]

---
 src/Compiler.jl        | 41 +++++++++++++----------
 src/xla/Device.jl      | 76 ++++++++++++++++++++++++++++++++++++++++++
 src/xla/IFRT/Device.jl |  8 +++++
 src/xla/PJRT/Device.jl |  8 +++++
 src/xla/Stats.jl       |  2 +-
 src/xla/XLA.jl         |  9 -----
 6 files changed, 117 insertions(+), 27 deletions(-)

diff --git a/src/Compiler.jl b/src/Compiler.jl
index a6be948b1c..5823c9a07b 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1395,7 +1395,8 @@ function __get_compile_options_and_kwargs(;
 end
 
 function compile_mlir(f, args; client=nothing, kwargs...)
-    backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
+    client = client !== nothing ? client : XLA.default_backend()
+    backend = XLA.platform_name(client)
 
     if backend == "CUDA"
         backend = "GPU"
@@ -1414,6 +1415,7 @@ function compile_mlir(f, args; client=nothing, kwargs...)
         compile_options;
         backend,
         runtime=XLA.runtime(client),
+        client,
         kwargs...,
     )
 
@@ -1430,11 +1432,9 @@ end
 
 const PartitionKA = Ref{Bool}(true)
 
-const cubinChip = Ref{String}("sm_60")
-const cubinFormat = Ref{String}("bin")
 const cuindexBitWidth = Ref{Int}(32)
+const cubinFormat = Ref{String}("bin")
 const cuOptLevel = Ref{Int}(2)
-const cuWarpSize = Ref{Int}(32)
 
 # Wgatever the relevant highest version from our LLVM is within NVPTX.td
 # Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
@@ -1580,8 +1580,11 @@ function compile_mlir!(
     backend="gpu",
     runtime::Union{Val{:PJRT},Val{:IFRT}},
     legalize_stablehlo_to_mhlo::Bool=false,
+    client=nothing,
     kwargs...,
 )
+    client = client !== nothing ? client : XLA.default_backend()
+
     # Explicitly don't use block! to avoid creating a closure, which creates
     # both compile-time and relocatability issues
 
@@ -1655,25 +1658,27 @@ function compile_mlir!(
         else
            jit = "lower-jit{openmp=$(OpenMP[]) backend=cpu},symbol-dce"
         end
-    elseif DEBUG_KERNEL[]
-        curesulthandler = dlsym(
-            Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
-        )
-        @assert curesulthandler !== nothing
-        curesulthandler = Base.reinterpret(UInt, curesulthandler)
+    else
         kern = if is_raising
             "lower-kernel{backend=cpu},symbol-dce,canonicalize"
         else
             "lower-kernel,canonicalize"
         end
-        jit = "lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
-    else
-        kern = if is_raising
-            "lower-kernel{backend=cpu},symbol-dce,canonicalize"
+
+        device_properties = XLA.device_properties(XLA.default_device(client))
+        cubinChip = "sm_$(device_properties.major)$(device_properties.minor)"
+
+        if DEBUG_KERNEL[]
+            curesulthandler = dlsym(
+                Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
+            )
+            @assert curesulthandler !== nothing
+            curesulthandler = Base.reinterpret(UInt, curesulthandler)
+            extra_lowerjit_options = "debug=true cuResultHandlerPtr=$curesulthandler "
         else
-            "lower-kernel,canonicalize"
+            extra_lowerjit_options = ""
         end
-        jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
+        jit = "lower-jit{$(extra_lowerjit_options)cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
     end
 
     recognize_comms = true
@@ -3477,7 +3482,8 @@ function compile_xla(
     context_gc_vector[ctx] = Vector{Union{TracedRArray,TracedRNumber}}(undef, 0)
     @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
 
-    backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
+    client = client !== nothing ? client : XLA.default_backend()
+    backend = XLA.platform_name(client)
 
     if backend == "CUDA"
         backend = "GPU"
@@ -3498,6 +3504,7 @@ function compile_xla(
         compile_options;
         backend,
         runtime=XLA.runtime(client),
+        client,
         kwargs...,
     )
 
diff --git a/src/xla/Device.jl b/src/xla/Device.jl
index 19e9ef737f..5e28cc3ce3 100644
--- a/src/xla/Device.jl
+++ b/src/xla/Device.jl
@@ -11,6 +11,7 @@ function device_kind end
 function default_memory end
 function memories end
 function is_addressable end
+function get_local_hardware_id end
 
 """
     device_ordinal(device::Device)
@@ -29,3 +30,78 @@ end
 function is_addressable(device::AbstractDevice)
     return device ∈ addressable_devices(client(device))
 end
+
+# Keep in sync with API.cpp
+struct DeviceProperties
+    total_global_mem::Csize_t
+    shared_mem_per_block::Csize_t
+    regs_per_block::Cint
+    warp_size::Cint
+    max_threads_per_block::Cint
+    max_threads_dim::NTuple{3,Cint}
+    max_grid_size::NTuple{3,Cint}
+    clock_rate::Cint
+    total_const_mem::Csize_t
+    major::Cint
+    minor::Cint
+    multi_processor_count::Cint
+    can_map_host_memory::Cint
+    compute_mode::Cint
+    l2_cache_size::Cint
+    max_threads_per_multiprocessor::Cint
+end
+
+const DEVICE_PROPERTIES_CACHE = Dict{Tuple{Int,String},DeviceProperties}()
+
+"""
+    device_properties(device::AbstractDevice)
+
+Get a struct containing device properties. Which exact fields are populated depends on the
+underlying device implementation.
+"""
+function device_properties(device::AbstractDevice)
+    pname = platform_name(client(device))
+    local_hardware_id = get_local_hardware_id(device)
+
+    if haskey(DEVICE_PROPERTIES_CACHE, (local_hardware_id, pname))
+        return DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)]
+    end
+
+    jldevprops = Ref{DeviceProperties}()
+    if pname == "cuda"
+        GC.@preserve jldevprops begin
+            @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetProperties(
+                jldevprops::Ptr{Cvoid}, local_hardware_id::Cint
+            )::Cvoid
+        end
+    else
+        @warn "`device_properties` not implemented for platform: $(pname)" maxlog = 1
+    end
+    DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)] = jldevprops[]
+    return jldevprops[]
+end
+
+function Base.show(io::IO, ::MIME"text/plain", props::DeviceProperties)
+    return print(
+        io,
+        """
+        DeviceProperties
+        ----------------
+        Total Global Mem: $(_format_bytes(props.total_global_mem))
+        Shared Mem Per Block: $(_format_bytes(props.shared_mem_per_block))
+        Regs Per Block: $(props.regs_per_block)
+        Warp Size: $(props.warp_size)
+        Max Threads Per Block: $(props.max_threads_per_block)
+        Max Threads Dim: $(props.max_threads_dim)
+        Max Grid Size: $(props.max_grid_size)
+        Clock Rate: $(props.clock_rate)
+        Total Const Mem: $(_format_bytes(props.total_const_mem))
+        Version: $(VersionNumber(props.major, props.minor))
+        Multi Processor Count: $(props.multi_processor_count)
+        Can Map Host Memory: $(props.can_map_host_memory)
+        Compute Mode: $(props.compute_mode)
+        L2 Cache Size: $(props.l2_cache_size)
+        Max Threads Per Multiprocessor: $(props.max_threads_per_multiprocessor)
+        """,
+    )
+end
diff --git a/src/xla/IFRT/Device.jl b/src/xla/IFRT/Device.jl
index 7d269e166c..dba2c98cd5 100644
--- a/src/xla/IFRT/Device.jl
+++ b/src/xla/IFRT/Device.jl
@@ -31,6 +31,14 @@ function XLA.get_local_device_id(::Device)
     return error("Not implemented for ifrt devices")
 end
 
+function XLA.get_local_hardware_id(::Device)
+    GC.@preserve device begin
+        return @ccall MLIR.API.mlir_c.ifrt_DeviceGetLocalHardwareId(
+            device.device::Ptr{Cvoid}
+        )::Cint
+    end
+end
+
 function XLA.default_memory(device::Device)
     GC.@preserve device begin
         return Memory(
diff --git a/src/xla/PJRT/Device.jl b/src/xla/PJRT/Device.jl
index 2a29c6279b..4a4dd178e7 100644
--- a/src/xla/PJRT/Device.jl
+++ b/src/xla/PJRT/Device.jl
@@ -33,6 +33,14 @@ function XLA.get_local_device_id(device::Device)
     end
 end
 
+function XLA.get_local_hardware_id(device::Device)
+    GC.@preserve device begin
+        return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalHardwareId(
+            device.device::Ptr{Cvoid}
+        )::Cint
+    end
+end
+
 function XLA.is_addressable(device::Device)
     GC.@preserve device begin
         return @ccall MLIR.API.mlir_c.pjrt_device_is_addressable(
diff --git a/src/xla/Stats.jl b/src/xla/Stats.jl
index bc66cc348a..59f62609c2 100644
--- a/src/xla/Stats.jl
+++ b/src/xla/Stats.jl
@@ -13,7 +13,7 @@ struct JLAllocatorStats
     peak_pool_bytes::Int64
 end
 
-_format_bytes(x) = Base.format_bytes(x)
+_format_bytes(x) = x < 0 ? nothing : Base.format_bytes(x)
 _format_bytes(x::Nothing) = x
 
 """
diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl
index 1a7ffc17f2..f14139b890 100644
--- a/src/xla/XLA.jl
+++ b/src/xla/XLA.jl
@@ -234,15 +234,6 @@ for runtime in (:PJRT, :IFRT)
                 )
                 state.clients["cuda"] = gpu
                 state.default_client = gpu
-
-                # set values for cuda. This is being done here since we need cuda
-                # to be initialized before we can use it. initializing the devices
-                # implicitly initializes cuda.
-                cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
-                cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
-                Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)"
-
-                Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32
             catch e
                 println(stdout, e)
             end

From 74fa66d2a7fa79a7cdd3365cebad1d75c7a01a52 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 16 Oct 2025 16:53:08 -0400
Subject: [PATCH 2/5] fix: apply suggestion from @avik-pal [skip ci]

---
 src/xla/IFRT/Device.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/xla/IFRT/Device.jl b/src/xla/IFRT/Device.jl
index dba2c98cd5..672900454a 100644
--- a/src/xla/IFRT/Device.jl
+++ b/src/xla/IFRT/Device.jl
@@ -31,7 +31,7 @@ function XLA.get_local_device_id(::Device)
     return error("Not implemented for ifrt devices")
 end
 
-function XLA.get_local_hardware_id(::Device)
+function XLA.get_local_hardware_id(device::Device)
     GC.@preserve device begin
         return @ccall MLIR.API.mlir_c.ifrt_DeviceGetLocalHardwareId(
             device.device::Ptr{Cvoid}

From fc22d8ce1c87532455d3e0be2e72912cf3bd3369 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 17 Oct 2025 07:58:36 -0500
Subject: [PATCH 3/5] chore: bump reactant_jll

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index fbde73622a..f08ffd3b90 100644
--- a/Project.toml
+++ b/Project.toml
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.16"
-Reactant_jll = "0.0.251"
+Reactant_jll = "0.0.252"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"

From 377dd10215982297b23c331a0743477afef5c161 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 17 Oct 2025 19:23:19 -0500
Subject: [PATCH 4/5] fix: remove deleted fields [skip ci]

---
 Project.toml      | 2 +-
 src/xla/Device.jl | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/Project.toml b/Project.toml
index f08ffd3b90..fbde73622a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.16"
-Reactant_jll = "0.0.252"
+Reactant_jll = "0.0.251"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"
diff --git a/src/xla/Device.jl b/src/xla/Device.jl
index 5e28cc3ce3..fd76bb6e3e 100644
--- a/src/xla/Device.jl
+++ b/src/xla/Device.jl
@@ -40,13 +40,11 @@ struct DeviceProperties
     max_threads_per_block::Cint
     max_threads_dim::NTuple{3,Cint}
     max_grid_size::NTuple{3,Cint}
-    clock_rate::Cint
     total_const_mem::Csize_t
     major::Cint
     minor::Cint
     multi_processor_count::Cint
    can_map_host_memory::Cint
-    compute_mode::Cint
     l2_cache_size::Cint
     max_threads_per_multiprocessor::Cint
 end
@@ -94,12 +92,10 @@ function Base.show(io::IO, ::MIME"text/plain", props::DeviceProperties)
         Max Threads Per Block: $(props.max_threads_per_block)
         Max Threads Dim: $(props.max_threads_dim)
         Max Grid Size: $(props.max_grid_size)
-        Clock Rate: $(props.clock_rate)
         Total Const Mem: $(_format_bytes(props.total_const_mem))
         Version: $(VersionNumber(props.major, props.minor))
         Multi Processor Count: $(props.multi_processor_count)
         Can Map Host Memory: $(props.can_map_host_memory)
-        Compute Mode: $(props.compute_mode)
         L2 Cache Size: $(props.l2_cache_size)
         Max Threads Per Multiprocessor: $(props.max_threads_per_multiprocessor)
         """,

From d5372aa88820b4a14344aff04f8467ceaae173a4 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 19 Oct 2025 10:49:01 -0400
Subject: [PATCH 5/5] chore: bump jll

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index fbde73622a..f08ffd3b90 100644
--- a/Project.toml
+++ b/Project.toml
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.16"
-Reactant_jll = "0.0.251"
+Reactant_jll = "0.0.252"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"
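A minimal usage sketch of the API introduced by this series, assuming the
submodule is reachable as `Reactant.XLA` and that a CUDA client is the default
backend; the functions and struct fields themselves are taken from the diffs
above:

    using Reactant
    const XLA = Reactant.XLA  # assumption: path to the XLA submodule

    # `device_properties` caches results per (local hardware id, platform name),
    # so repeated queries are cheap. Non-CUDA platforms warn once (maxlog = 1)
    # and the returned fields are unspecified.
    client = XLA.default_backend()
    props = XLA.device_properties(XLA.default_device(client))

    # `compile_mlir!` now derives the PTX target per device the same way,
    # replacing the old global `cubinChip[]`/`cuWarpSize[]` Refs:
    cubin_chip = "sm_$(props.major)$(props.minor)"
    @show cubin_chip props.warp_size props.multi_processor_count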