diff --git a/Project.toml b/Project.toml
index fbde73622a..f08ffd3b90 100644
--- a/Project.toml
+++ b/Project.toml
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.16"
-Reactant_jll = "0.0.251"
+Reactant_jll = "0.0.252"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"
diff --git a/src/Compiler.jl b/src/Compiler.jl
index a6be948b1c..5823c9a07b 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1395,7 +1395,8 @@ function __get_compile_options_and_kwargs(;
 end
 
 function compile_mlir(f, args; client=nothing, kwargs...)
-    backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
+    client = client !== nothing ? client : XLA.default_backend()
+    backend = XLA.platform_name(client)
 
     if backend == "CUDA"
         backend = "GPU"
@@ -1414,6 +1415,7 @@ function compile_mlir(f, args; client=nothing, kwargs...)
             compile_options;
             backend,
             runtime=XLA.runtime(client),
+            client,
             kwargs...,
         )
 
@@ -1430,11 +1432,9 @@ end
 
 const PartitionKA = Ref{Bool}(true)
 
-const cubinChip = Ref{String}("sm_60")
-const cubinFormat = Ref{String}("bin")
 const cuindexBitWidth = Ref{Int}(32)
+const cubinFormat = Ref{String}("bin")
 const cuOptLevel = Ref{Int}(2)
-const cuWarpSize = Ref{Int}(32)
 
 # Wgatever the relevant highest version from our LLVM is within NVPTX.td
 # Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
@@ -1580,8 +1580,11 @@ function compile_mlir!(
     backend="gpu",
     runtime::Union{Val{:PJRT},Val{:IFRT}},
     legalize_stablehlo_to_mhlo::Bool=false,
+    client=nothing,
     kwargs...,
 )
+    client = client !== nothing ? client : XLA.default_backend()
+
     # Explicitly don't use block! to avoid creating a closure, which creates
     # both compile-time and relocatability issues
 
@@ -1655,25 +1658,27 @@ function compile_mlir!(
         else
             jit = "lower-jit{openmp=$(OpenMP[]) backend=cpu},symbol-dce"
         end
-    elseif DEBUG_KERNEL[]
-        curesulthandler = dlsym(
-            Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
-        )
-        @assert curesulthandler !== nothing
-        curesulthandler = Base.reinterpret(UInt, curesulthandler)
+    else
         kern = if is_raising
             "lower-kernel{backend=cpu},symbol-dce,canonicalize"
         else
             "lower-kernel,canonicalize"
         end
-        jit = "lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
-    else
-        kern = if is_raising
-            "lower-kernel{backend=cpu},symbol-dce,canonicalize"
+
+        device_properties = XLA.device_properties(XLA.default_device(client))
+        cubinChip = "sm_$(device_properties.major)$(device_properties.minor)"
+
+        if DEBUG_KERNEL[]
+            curesulthandler = dlsym(
+                Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
+            )
+            @assert curesulthandler !== nothing
+            curesulthandler = Base.reinterpret(UInt, curesulthandler)
+            extra_lowerjit_options = "debug=true cuResultHandlerPtr=$curesulthandler "
         else
-            "lower-kernel,canonicalize"
+            extra_lowerjit_options = ""
         end
-        jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
+        jit = "lower-jit{$(extra_lowerjit_options)cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
     end
 
     recognize_comms = true
@@ -3477,7 +3482,8 @@ function compile_xla(
     context_gc_vector[ctx] = Vector{Union{TracedRArray,TracedRNumber}}(undef, 0)
     @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
 
-    backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
+    client = client !== nothing ? client : XLA.default_backend()
+    backend = XLA.platform_name(client)
 
     if backend == "CUDA"
         backend = "GPU"
@@ -3498,6 +3504,7 @@ function compile_xla(
         compile_options;
         backend,
         runtime=XLA.runtime(client),
+        client,
         kwargs...,
     )
 
diff --git a/src/xla/Device.jl b/src/xla/Device.jl
index 19e9ef737f..fd76bb6e3e 100644
--- a/src/xla/Device.jl
+++ b/src/xla/Device.jl
@@ -11,6 +11,7 @@ function device_kind end
 function default_memory end
 function memories end
 function is_addressable end
+function get_local_hardware_id end
 
 """
     device_ordinal(device::Device)
@@ -29,3 +30,74 @@ end
 function is_addressable(device::AbstractDevice)
     return device ∈ addressable_devices(client(device))
 end
+
+# Keep in sync with API.cpp
+struct DeviceProperties
+    total_global_mem::Csize_t
+    shared_mem_per_block::Csize_t
+    regs_per_block::Cint
+    warp_size::Cint
+    max_threads_per_block::Cint
+    max_threads_dim::NTuple{3,Cint}
+    max_grid_size::NTuple{3,Cint}
+    total_const_mem::Csize_t
+    major::Cint
+    minor::Cint
+    multi_processor_count::Cint
+    can_map_host_memory::Cint
+    l2_cache_size::Cint
+    max_threads_per_multiprocessor::Cint
+end
+
+const DEVICE_PROPERTIES_CACHE = Dict{Tuple{Int,String},DeviceProperties}()
+
+"""
+    device_properties(device::AbstractDevice)
+
+Get a struct containing device properties. Which exact fields are populated depends on the
+underlying device implementation.
+"""
+function device_properties(device::AbstractDevice)
+    pname = platform_name(client(device))
+    local_hardware_id = get_local_hardware_id(device)
+
+    if haskey(DEVICE_PROPERTIES_CACHE, (local_hardware_id, pname))
+        return DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)]
+    end
+
+    jldevprops = Ref{DeviceProperties}()
+    if pname == "cuda"
+        GC.@preserve jldevprops begin
+            @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetProperties(
+                jldevprops::Ptr{Cvoid}, local_hardware_id::Cint
+            )::Cvoid
+        end
+    else
+        @warn "`device_properties` not implemented for platform: $(pname)" maxlog = 1
+    end
+    DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)] = jldevprops[]
+    return jldevprops[]
+end
+
+function Base.show(io::IO, ::MIME"text/plain", props::DeviceProperties)
+    return print(
+        io,
+        """
+        DeviceProperties
+        ----------------
+        Total Global Mem: $(_format_bytes(props.total_global_mem))
+        Shared Mem Per Block: $(_format_bytes(props.shared_mem_per_block))
+        Regs Per Block: $(props.regs_per_block)
+        Warp Size: $(props.warp_size)
+        Max Threads Per Block: $(props.max_threads_per_block)
+        Max Threads Dim: $(props.max_threads_dim)
+        Max Grid Size: $(props.max_grid_size)
+        Total Const Mem: $(_format_bytes(props.total_const_mem))
+        Version: $(VersionNumber(props.major, props.minor))
+        Multi Processor Count: $(props.multi_processor_count)
+        Can Map Host Memory: $(props.can_map_host_memory)
+        L2 Cache Size: $(props.l2_cache_size)
+        Max Threads Per Multiprocessor: $(props.max_threads_per_multiprocessor)
+        """,
+    )
+end
diff --git a/src/xla/IFRT/Device.jl b/src/xla/IFRT/Device.jl
index 7d269e166c..672900454a 100644
--- a/src/xla/IFRT/Device.jl
+++ b/src/xla/IFRT/Device.jl
@@ -31,6 +31,14 @@ function XLA.get_local_device_id(::Device)
     return error("Not implemented for ifrt devices")
 end
 
+function XLA.get_local_hardware_id(device::Device)
+    GC.@preserve device begin
+        return @ccall MLIR.API.mlir_c.ifrt_DeviceGetLocalHardwareId(
+            device.device::Ptr{Cvoid}
+        )::Cint
+    end
+end
+
 function XLA.default_memory(device::Device)
     GC.@preserve device begin
         return Memory(
diff --git a/src/xla/PJRT/Device.jl b/src/xla/PJRT/Device.jl
index 2a29c6279b..4a4dd178e7 100644
--- a/src/xla/PJRT/Device.jl
+++ b/src/xla/PJRT/Device.jl
@@ -33,6 +33,14 @@ function XLA.get_local_device_id(device::Device)
     end
 end
 
+function XLA.get_local_hardware_id(device::Device)
+    GC.@preserve device begin
+        return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalHardwareId(
+            device.device::Ptr{Cvoid}
+        )::Cint
+    end
+end
+
 function XLA.is_addressable(device::Device)
     GC.@preserve device begin
         return @ccall MLIR.API.mlir_c.pjrt_device_is_addressable(
diff --git a/src/xla/Stats.jl b/src/xla/Stats.jl
index bc66cc348a..59f62609c2 100644
--- a/src/xla/Stats.jl
+++ b/src/xla/Stats.jl
@@ -13,7 +13,7 @@ struct JLAllocatorStats
     peak_pool_bytes::Int64
 end
 
-_format_bytes(x) = Base.format_bytes(x)
+_format_bytes(x) = x < 0 ? nothing : Base.format_bytes(x)
 _format_bytes(x::Nothing) = x
 
 """
diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl
index 1a7ffc17f2..f14139b890 100644
--- a/src/xla/XLA.jl
+++ b/src/xla/XLA.jl
@@ -234,15 +234,6 @@ for runtime in (:PJRT, :IFRT)
                 )
                 state.clients["cuda"] = gpu
                 state.default_client = gpu
-
-                # set values for cuda. This is being done here since we need cuda
-                # to be initialized before we can use it. initializing the devices
-                # implicitly initializes cuda.
-                cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
-                cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
-                Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)"
-
-                Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32
             catch e
                 println(stdout, e)
            end
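
Reviewer note, not part of the patch: below is a minimal usage sketch of the device-properties API added in src/xla/Device.jl, and of how compile_mlir! now derives the cubin chip from it instead of the removed `cubinChip[]`/`cuWarpSize[]` globals. It assumes a Reactant session whose default client is CUDA-backed; on other platforms `device_properties` only warns and returns unpopulated fields. The chip string in the final comment is illustrative, not actual output.

using Reactant
using Reactant: XLA

client = XLA.default_backend()          # client that compile_mlir falls back to when none is passed
device = XLA.default_device(client)     # default device of that client

props = XLA.device_properties(device)   # cached per (local hardware id, platform name)
display(props)                          # pretty-printed via the Base.show method added above

# Same derivation compile_mlir! now performs when building the lower-jit pass
# options, e.g. "sm_80" for a compute capability 8.0 device (illustrative).
chip = "sm_$(props.major)$(props.minor)"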