Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 83 additions & 22 deletions ext/AtomixCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,96 @@ end
ptr = Atomix.pointer(ref)
expected = convert(eltype(ref), expected)
desired = convert(eltype(ref), desired)
begin
old = CUDA.atomic_cas!(ptr, expected, desired)
end
old = _cuda_atomic_cas!(ptr, expected, desired)
return (; old = old, success = old === expected)
end

# Native CUDA CAS for supported types
@inline function _cuda_atomic_cas!(ptr::Core.LLVMPtr{T,A}, cmp::T, new::T) where {T,A}
CUDA.atomic_cas!(ptr, cmp, new)
end

# Complex CAS - using separate CAS on real and imaginary components
# Note: This is NOT fully atomic (components updated separately)
# but works for both ComplexF32 and ComplexF64
Comment on lines +37 to +38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oof this is a no-go in my opinion. You will thus easily get torn writes.

I think this would need to use 128-byte atomics

@inline function _cuda_atomic_cas!(ptr::Core.LLVMPtr{Complex{T},A}, cmp::Complex{T}, new::Complex{T}) where {T<:Union{Float32,Float64},A}
# Get pointers to real and imaginary components
ptr_re = Base.bitcast(Core.LLVMPtr{T,A}, ptr)
ptr_im = Base.bitcast(Core.LLVMPtr{T,A}, ptr + sizeof(T))

# CAS on real part
old_re = CUDA.atomic_cas!(ptr_re, cmp.re, new.re)
# CAS on imaginary part
old_im = CUDA.atomic_cas!(ptr_im, cmp.im, new.im)

# Return just the old value for consistency with non-Complex CUDA CAS
# Note: The caller checks success by comparing old === expected
# This works because if both components match, the Complex values will be equal
return Complex{T}(old_re, old_im)
end

@inline function Atomix.modify!(ref::CuIndexableRef, op::OP, x, order) where {OP}
x = convert(eltype(ref), x)
ptr = Atomix.pointer(ref)
begin
old = if op === (+)
CUDA.atomic_add!(ptr, x)
elseif op === (-)
CUDA.atomic_sub!(ptr, x)
elseif op === (&)
CUDA.atomic_and!(ptr, x)
elseif op === (|)
CUDA.atomic_or!(ptr, x)
elseif op === xor
CUDA.atomic_xor!(ptr, x)
elseif op === min
CUDA.atomic_min!(ptr, x)
elseif op === max
CUDA.atomic_max!(ptr, x)
else
error("not implemented")
end
end
old = _cuda_atomic_modify!(ptr, op, x)
return old => op(old, x)
end

# Native CUDA atomic operations for supported types
@inline function _cuda_atomic_modify!(ptr::Core.LLVMPtr{T,A}, op::OP, x::T) where {T,A,OP}
if op === (+)
CUDA.atomic_add!(ptr, x)
elseif op === (-)
CUDA.atomic_sub!(ptr, x)
elseif op === (&)
CUDA.atomic_and!(ptr, x)
elseif op === (|)
CUDA.atomic_or!(ptr, x)
elseif op === xor
CUDA.atomic_xor!(ptr, x)
elseif op === min
CUDA.atomic_min!(ptr, x)
elseif op === max
CUDA.atomic_max!(ptr, x)
else
error("not implemented")
end
end

# Complex atomic operations - separate atomics on real and imaginary parts
# This works for operations that decompose component-wise (+, -, right)
# Note: This provides per-component atomicity, not full Complex atomicity
# (other threads may observe intermediate states, but final result is correct)
@inline function _cuda_atomic_modify!(ptr::Core.LLVMPtr{Complex{T},A}, op::OP, x::Complex{T}) where {T<:Union{Float32,Float64},A,OP}
Comment on lines +83 to +87
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, you are not gurantueed that a user is only using one kind of atomic operation on a memory location.

(e.g. someone doing a mul for good measure).

# Get pointers to real and imaginary components
ptr_re = Base.bitcast(Core.LLVMPtr{T,A}, ptr)
ptr_im = Base.bitcast(Core.LLVMPtr{T,A}, ptr + sizeof(T))

if op === (+)
old_re = CUDA.atomic_add!(ptr_re, x.re)
old_im = CUDA.atomic_add!(ptr_im, x.im)
return Complex{T}(old_re, old_im)
elseif op === (-)
old_re = CUDA.atomic_sub!(ptr_re, x.re)
old_im = CUDA.atomic_sub!(ptr_im, x.im)
return Complex{T}(old_re, old_im)
else
# For other operations (like right for swap), use CAS loop
# Read the old value component by component (not atomic together)
old_re = CUDA.atomic_add!(ptr_re, zero(T)) # atomic read
old_im = CUDA.atomic_add!(ptr_im, zero(T)) # atomic read
old = Complex{T}(old_re, old_im)

# Compute new value
new = op(old, x)

# Try to swap using CAS (will only succeed if value hasn't changed)
# This is a simplified version - a full CAS loop would be more robust
_cuda_atomic_cas!(ptr, old, new)

# Return the old value we read
return old
end
end

end # module AtomixCUDAExt
102 changes: 80 additions & 22 deletions ext/AtomixMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,33 @@ end
ptr = Atomix.pointer(ref)
expected = convert(eltype(ref), expected)
desired = convert(eltype(ref), desired)
begin
old = Metal.atomic_compare_exchange_weak_explicit(ptr, expected, desired)
end
old = _metal_atomic_cas!(ptr, expected, desired)
return (; old = old, success = old === expected)
end

# Native Metal CAS for supported types
@inline function _metal_atomic_cas!(ptr::Core.LLVMPtr{T,A}, cmp::T, new::T) where {T,A}
Metal.atomic_compare_exchange_weak_explicit(ptr, cmp, new)
end

# Complex CAS - using separate CAS on real and imaginary components
# Note: This is NOT fully atomic (components updated separately)
# but works for both ComplexF32 and ComplexF64
@inline function _metal_atomic_cas!(ptr::Core.LLVMPtr{Complex{T},A}, cmp::Complex{T}, new::Complex{T}) where {T<:Union{Float32,Float64},A}
# Get pointers to real and imaginary components
ptr_re = Base.bitcast(Core.LLVMPtr{T,A}, ptr)
ptr_im = Base.bitcast(Core.LLVMPtr{T,A}, ptr + sizeof(T))

# CAS on real part
old_re = Metal.atomic_compare_exchange_weak_explicit(ptr_re, cmp.re, new.re)
# CAS on imaginary part
old_im = Metal.atomic_compare_exchange_weak_explicit(ptr_im, cmp.im, new.im)

# Return just the old value for consistency
# The caller checks success by comparing old === expected
return Complex{T}(old_re, old_im)
end


# CAS is needed for FP ops on ThreadGroup memory
@inline function Atomix.modify!(ref::IndexableRef{<:MtlDeviceArray{<:AbstractFloat, <:Any, Metal.AS.ThreadGroup}} , op::OP, x, order) where {OP}
Expand All @@ -42,26 +63,63 @@ end
@inline function Atomix.modify!(ref::MtlIndexableRef, op::OP, x, order) where {OP}
x = convert(eltype(ref), x)
ptr = Atomix.pointer(ref)
begin
old = if op === (+)
Metal.atomic_fetch_add_explicit(ptr, x)
elseif op === (-)
Metal.atomic_fetch_sub_explicit(ptr, x)
elseif op === (&)
Metal.atomic_fetch_and_explicit(ptr, x)
elseif op === (|)
Metal.atomic_fetch_or_explicit(ptr, x)
elseif op === xor
Metal.atomic_fetch_xor_explicit(ptr, x)
elseif op === min
Metal.atomic_fetch_min_explicit(ptr, x)
elseif op === max
Metal.atomic_fetch_max_explicit(ptr, x)
else
error("not implemented")
end
end
old = _metal_atomic_modify!(ptr, op, x)
return old => op(old, x)
end

# Native Metal atomic operations for supported types
@inline function _metal_atomic_modify!(ptr::Core.LLVMPtr{T,A}, op::OP, x::T) where {T,A,OP}
if op === (+)
Metal.atomic_fetch_add_explicit(ptr, x)
elseif op === (-)
Metal.atomic_fetch_sub_explicit(ptr, x)
elseif op === (&)
Metal.atomic_fetch_and_explicit(ptr, x)
elseif op === (|)
Metal.atomic_fetch_or_explicit(ptr, x)
elseif op === xor
Metal.atomic_fetch_xor_explicit(ptr, x)
elseif op === min
Metal.atomic_fetch_min_explicit(ptr, x)
elseif op === max
Metal.atomic_fetch_max_explicit(ptr, x)
else
error("not implemented")
end
end

# Complex atomic operations - separate atomics on real and imaginary parts
# This works for operations that decompose component-wise (+, -, right)
# Note: This provides per-component atomicity, not full Complex atomicity
@inline function _metal_atomic_modify!(ptr::Core.LLVMPtr{Complex{T},A}, op::OP, x::Complex{T}) where {T<:Union{Float32,Float64},A,OP}
# Get pointers to real and imaginary components
ptr_re = Base.bitcast(Core.LLVMPtr{T,A}, ptr)
ptr_im = Base.bitcast(Core.LLVMPtr{T,A}, ptr + sizeof(T))

if op === (+)
old_re = Metal.atomic_fetch_add_explicit(ptr_re, x.re)
old_im = Metal.atomic_fetch_add_explicit(ptr_im, x.im)
return Complex{T}(old_re, old_im)
elseif op === (-)
old_re = Metal.atomic_fetch_sub_explicit(ptr_re, x.re)
old_im = Metal.atomic_fetch_sub_explicit(ptr_im, x.im)
return Complex{T}(old_re, old_im)
else
# For other operations (like right for swap), use CAS loop
# Read the old value component by component (not atomic together)
old_re = Metal.atomic_fetch_add_explicit(ptr_re, zero(T)) # atomic read
old_im = Metal.atomic_fetch_add_explicit(ptr_im, zero(T)) # atomic read
old = Complex{T}(old_re, old_im)

# Compute new value
new = op(old, x)

# Try to swap using CAS
_metal_atomic_cas!(ptr, old, new)

# Return the old value we read
return old
end
end

end # module AtomixMetalExt
100 changes: 78 additions & 22 deletions ext/AtomixOpenCLExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,92 @@ end
ptr = Atomix.pointer(ref)
expected = convert(eltype(ref), expected)
desired = convert(eltype(ref), desired)
begin
old = SPIRVIntrinsics.atomic_cmpxchg!(ptr, expected, desired)
end
old = _opencl_atomic_cas!(ptr, expected, desired)
return (; old = old, success = old === expected)
end

# Native OpenCL CAS for supported types
@inline function _opencl_atomic_cas!(ptr::Core.LLVMPtr{T,A}, cmp::T, new::T) where {T,A}
SPIRVIntrinsics.atomic_cmpxchg!(ptr, cmp, new)
end

# Complex CAS - using separate CAS on real and imaginary components
# Note: This is NOT fully atomic (components updated separately)
# but works for both ComplexF32 and ComplexF64
@inline function _opencl_atomic_cas!(ptr::Core.LLVMPtr{Complex{T},A}, cmp::Complex{T}, new::Complex{T}) where {T<:Union{Float32,Float64},A}
# Get pointers to real and imaginary components
ptr_re = Base.bitcast(Core.LLVMPtr{T,A}, ptr)
ptr_im = Base.bitcast(Core.LLVMPtr{T,A}, ptr + sizeof(T))

# CAS on real part
old_re = SPIRVIntrinsics.atomic_cmpxchg!(ptr_re, cmp.re, new.re)
# CAS on imaginary part
old_im = SPIRVIntrinsics.atomic_cmpxchg!(ptr_im, cmp.im, new.im)

# Return just the old value for consistency
return Complex{T}(old_re, old_im)
end

@inline function Atomix.modify!(ref::CLIndexableRef, op::OP, x, order) where {OP}
x = convert(eltype(ref), x)
ptr = Atomix.pointer(ref)
begin
old = if op === (+)
SPIRVIntrinsics.atomic_add!(ptr, x)
elseif op === (-)
SPIRVIntrinsics.atomic_sub!(ptr, x)
elseif op === (&)
SPIRVIntrinsics.atomic_and!(ptr, x)
elseif op === (|)
SPIRVIntrinsics.atomic_or!(ptr, x)
elseif op === xor
SPIRVIntrinsics.atomic_xor!(ptr, x)
elseif op === min
SPIRVIntrinsics.atomic_min!(ptr, x)
elseif op === max
SPIRVIntrinsics.atomic_max!(ptr, x)
else
error("not implemented")
end
end
old = _opencl_atomic_modify!(ptr, op, x)
return old => op(old, x)
end

# Native OpenCL atomic operations for supported types
@inline function _opencl_atomic_modify!(ptr::Core.LLVMPtr{T,A}, op::OP, x::T) where {T,A,OP}
if op === (+)
SPIRVIntrinsics.atomic_add!(ptr, x)
elseif op === (-)
SPIRVIntrinsics.atomic_sub!(ptr, x)
elseif op === (&)
SPIRVIntrinsics.atomic_and!(ptr, x)
elseif op === (|)
SPIRVIntrinsics.atomic_or!(ptr, x)
elseif op === xor
SPIRVIntrinsics.atomic_xor!(ptr, x)
elseif op === min
SPIRVIntrinsics.atomic_min!(ptr, x)
elseif op === max
SPIRVIntrinsics.atomic_max!(ptr, x)
else
error("not implemented")
end
end

# Complex atomic operations - separate atomics on real and imaginary parts
# This works for operations that decompose component-wise (+, -, right)
# Note: This provides per-component atomicity, not full Complex atomicity
@inline function _opencl_atomic_modify!(ptr::Core.LLVMPtr{Complex{T},A}, op::OP, x::Complex{T}) where {T<:Union{Float32,Float64},A,OP}
# Get pointers to real and imaginary components
ptr_re = Base.bitcast(Core.LLVMPtr{T,A}, ptr)
ptr_im = Base.bitcast(Core.LLVMPtr{T,A}, ptr + sizeof(T))

if op === (+)
old_re = SPIRVIntrinsics.atomic_add!(ptr_re, x.re)
old_im = SPIRVIntrinsics.atomic_add!(ptr_im, x.im)
return Complex{T}(old_re, old_im)
elseif op === (-)
old_re = SPIRVIntrinsics.atomic_sub!(ptr_re, x.re)
old_im = SPIRVIntrinsics.atomic_sub!(ptr_im, x.im)
return Complex{T}(old_re, old_im)
else
# For other operations (like right for swap), use CAS loop
old_re = SPIRVIntrinsics.atomic_add!(ptr_re, zero(T)) # atomic read
old_im = SPIRVIntrinsics.atomic_add!(ptr_im, zero(T)) # atomic read
old = Complex{T}(old_re, old_im)

# Compute new value
new = op(old, x)

# Try to swap using CAS
_opencl_atomic_cas!(ptr, old, new)

# Return the old value we read
return old
end
end

end # module AtomixOpenCLExt

Loading