Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/intrinsics/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SPIRVIntrinsics"
uuid = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
authors = ["Tim Besard <[email protected]>"]
version = "0.5.2"
version = "0.5.3"

[deps]
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expand Down
22 changes: 14 additions & 8 deletions lib/intrinsics/src/math.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Math Functions

# TODO: vector types
const generic_types = [Float32,Float64]
const generic_types = [Float16, Float32, Float64]
const generic_types_float = [Float32]
const generic_types_double = [Float64]

Expand Down Expand Up @@ -33,7 +33,7 @@ for gentype in generic_types

@device_override Base.cos(x::$gentype) = @builtin_ccall("cos", $gentype, ($gentype,), x)
@device_override Base.cosh(x::$gentype) = @builtin_ccall("cosh", $gentype, ($gentype,), x)
@device_function cospi(x::$gentype) = @builtin_ccall("cospi", $gentype, ($gentype,), x)
@device_override Base.cospi(x::$gentype) = @builtin_ccall("cospi", $gentype, ($gentype,), x)

@device_override SpecialFunctions.erfc(x::$gentype) = @builtin_ccall("erfc", $gentype, ($gentype,), x)
@device_override SpecialFunctions.erf(x::$gentype) = @builtin_ccall("erf", $gentype, ($gentype,), x)
Expand All @@ -59,7 +59,10 @@ for gentype in generic_types
#@device_override Base.mod(x::$gentype, y::$gentype) = @builtin_ccall("fmod", $gentype, ($gentype, $gentype), x, y)
# fract(x::$gentype, $gentype *iptr) = @builtin_ccall("fract", $gentype, ($gentype, $gentype *), x, iptr)

@device_override Base.hypot(x::$gentype, y::$gentype) = @builtin_ccall("hypot", $gentype, ($gentype, $gentype), x, y)
# TODO: remove once https://github.com/pocl/pocl/issues/2034 is addressed
if $gentype != Float16
@device_override Base.hypot(x::$gentype, y::$gentype) = @builtin_ccall("hypot", $gentype, ($gentype, $gentype), x, y)
end

@device_override SpecialFunctions.loggamma(x::$gentype) = @builtin_ccall("lgamma", $gentype, ($gentype,), x)

Expand All @@ -81,8 +84,6 @@ for gentype in generic_types
@device_override Base.:(^)(x::$gentype, y::$gentype) = @builtin_ccall("pow", $gentype, ($gentype, $gentype), x, y)
@device_function powr(x::$gentype, y::$gentype) = @builtin_ccall("powr", $gentype, ($gentype, $gentype), x, y)

@device_override Base.rem(x::$gentype, y::$gentype) = @builtin_ccall("remainder", $gentype, ($gentype, $gentype), x, y)

@device_function rint(x::$gentype) = @builtin_ccall("rint", $gentype, ($gentype,), x)

@device_override Base.round(x::$gentype) = @builtin_ccall("round", $gentype, ($gentype,), x)
Expand All @@ -100,13 +101,13 @@ for gentype in generic_types
return sinval, cosval[]
end
@device_override Base.sinh(x::$gentype) = @builtin_ccall("sinh", $gentype, ($gentype,), x)
@device_function sinpi(x::$gentype) = @builtin_ccall("sinpi", $gentype, ($gentype,), x)
@device_override Base.sinpi(x::$gentype) = @builtin_ccall("sinpi", $gentype, ($gentype,), x)

@device_override Base.sqrt(x::$gentype) = @builtin_ccall("sqrt", $gentype, ($gentype,), x)

@device_override Base.tan(x::$gentype) = @builtin_ccall("tan", $gentype, ($gentype,), x)
@device_override Base.tanh(x::$gentype) = @builtin_ccall("tanh", $gentype, ($gentype,), x)
@device_function tanpi(x::$gentype) = @builtin_ccall("tanpi", $gentype, ($gentype,), x)
@device_override Base.tanpi(x::$gentype) = @builtin_ccall("tanpi", $gentype, ($gentype,), x)

@device_override SpecialFunctions.gamma(x::$gentype) = @builtin_ccall("tgamma", $gentype, ($gentype,), x)

Expand Down Expand Up @@ -151,11 +152,13 @@ end
# frexp(x::Float64{n}, Int32{n} *exp) = @builtin_ccall("frexp", Float64{n}, (Float64{n}, Int32{n} *), x, exp)
# frexp(x::Float64, Int32 *exp) = @builtin_ccall("frexp", Float64, (Float64, Int32 *), x, exp)

@device_function ilogb(x::Float16) = @builtin_ccall("ilogb", Int32, (Float16,), x)
# ilogb(x::Float32{n}) = @builtin_ccall("ilogb", Int32{n}, (Float32{n},), x)
@device_function ilogb(x::Float32) = @builtin_ccall("ilogb", Int32, (Float32,), x)
# ilogb(x::Float64{n}) = @builtin_ccall("ilogb", Int32{n}, (Float64{n},), x)
@device_function ilogb(x::Float64) = @builtin_ccall("ilogb", Int32, (Float64,), x)

@device_override Base.ldexp(x::Float16, k::Int32) = @builtin_ccall("ldexp", Float16, (Float16, Int32), x, k)
# ldexp(x::Float32{n}, k::Int32{n}) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32{n}), x, k)
# ldexp(x::Float32{n}, k::Int32) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32), x, k)
@device_override Base.ldexp(x::Float32, k::Int32) = @builtin_ccall("ldexp", Float32, (Float32, Int32), x, k)
Expand All @@ -168,11 +171,13 @@ end
# lgamma_r(x::Float64{n}, Int32{n} *signp) = @builtin_ccall("lgamma_r", Float64{n}, (Float64{n}, Int32{n} *), x, signp)
# Float64 lgamma_r(x::Float64, Int32 *signp) = @builtin_ccall("lgamma_r", Float64, (Float64, Int32 *), x, signp)

@device_function nan(nancode::UInt16) = @builtin_ccall("nan", Float16, (UInt16,), nancode)
# nan(nancode::uintn) = @builtin_ccall("nan", Float32{n}, (uintn,), nancode)
@device_function nan(nancode::UInt32) = @builtin_ccall("nan", Float32, (UInt32,), nancode)
# nan(nancode::UInt64{n}) = @builtin_ccall("nan", Float64{n}, (UInt64{n},), nancode)
@device_function nan(nancode::UInt64) = @builtin_ccall("nan", Float64, (UInt64,), nancode)

@device_override Base.:(^)(x::Float16, y::Int32) = @builtin_ccall("pown", Float16, (Float16, Int32), x, y)
# pown(x::Float32{n}, y::Int32{n}) = @builtin_ccall("pown", Float32{n}, (Float32{n}, Int32{n}), x, y)
@device_override Base.:(^)(x::Float32, y::Int32) = @builtin_ccall("pown", Float32, (Float32, Int32), x, y)
# pown(x::Float64{n}, y::Int32{n}) = @builtin_ccall("pown", Float64{n}, (Float64{n}, Int32{n}), x, y)
Expand All @@ -183,10 +188,11 @@ end
# remquo(x::Float64{n}, y::Float64{n}, Int32{n} *quo) = @builtin_ccall("remquo", Float64{n}, (Float64{n}, Float64{n}, Int32{n} *), x, y, quo)
# remquo(x::Float64, y::Float64, Int32 *quo) = @builtin_ccall("remquo", Float64, (Float64, Float64, Int32 *), x, y, quo)

@device_function rootn(x::Float16, y::Int32) = @builtin_ccall("rootn", Float16, (Float16, Int32), x, y)
# rootn(x::Float32{n}, y::Int32{n}) = @builtin_ccall("rootn", Float32{n}, (Float32{n}, Int32{n}), x, y)
@device_function rootn(x::Float32, y::Int32) = @builtin_ccall("rootn", Float32, (Float32, Int32), x, y)
# rootn(x::Float64{n}, y::Int32{n}) = @builtin_ccall("rootn", Float64{n}, (Float64{n}, Int32{n}), x, y)
# rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64{n}, (Float64, Int32), x, y)
@device_function rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64, (Float64, Int32), x, y)


# TODO: half and native
Expand Down
2 changes: 2 additions & 0 deletions lib/intrinsics/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ macro builtin_ccall(name, ret, argtypes, args...)
"c"
elseif T == UInt8
"h"
elseif T == Float16
"Dh"
elseif T == Float32
"f"
elseif T == Float64
Expand Down
115 changes: 115 additions & 0 deletions test/intrinsics.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
function call_on_device(f, args...)
function kernel(res, f, args...)
res[] = f(args...)
return
end
T = OpenCL.code_typed(() -> f(args...), ())[][2]
res = CLArray{T, 0}(undef)
@opencl kernel(res, f, args...)
return OpenCL.@allowscalar res[]
end

const float_types = filter(x -> x <: Base.IEEEFloat, GPUArraysTestSuite.supported_eltypes(CLArray))
const ispocl = cl.platform().name == "Portable Computing Language"

@testset "intrinsics" begin

@testset "barrier" begin
Expand Down Expand Up @@ -49,4 +63,105 @@ cl.memory_backend() isa cl.SVMBackend && @on_device atomic_work_item_fence(OpenC

end

@testset "math" begin

@testset "unary - $T" for T in float_types
@testset "$f" for f in [
acos, acosh,
asin, asinh,
atan, atanh,
cbrt,
ceil,
cos, cosh, cospi,
exp, exp2, exp10, expm1,
abs,
floor,
log, log2, log10, log1p,
round,
sin, sinh, sinpi,
sqrt,
tan, tanh, tanpi,
trunc,
]
x = rand(T)
if f == acosh
x += 1
end
broken = ispocl && T == Float16 && f in [acosh, asinh, atanh, cbrt, cospi, expm1, log1p, sinpi, tanpi]
@test call_on_device(f, x) ≈ f(x) broken = broken
end
end

@testset "binary - $T" for T in float_types
@testset "$f" for f in [
atan,
copysign,
max,
min,
hypot,
(^),
]
x = rand(T)
y = rand(T)
broken = ispocl && T == Float16 && f == atan
@test call_on_device(f, x, y) ≈ f(x, y) broken = broken
end
end

@testset "ternary - $T" for T in float_types
@testset "$f" for f in [
fma,
]
x = rand(T)
y = rand(T)
z = rand(T)
@test call_on_device(f, x, y, z) ≈ f(x, y, z)
end
end

@testset "OpenCL-specific unary - $T" for T in float_types
@testset "$f" for f in [
OpenCL.acospi,
OpenCL.asinpi,
OpenCL.atanpi,
OpenCL.logb,
OpenCL.rint,
OpenCL.rsqrt,
]
x = rand(T)
broken = ispocl && T == Float16 && !(f in [OpenCL.rint, OpenCL.rsqrt])
@test call_on_device(f, x) isa Real broken = broken # Just check it doesn't error
end
broken = ispocl && T == Float16
@test call_on_device(OpenCL.ilogb, T(8.0)) isa Int32 broken = broken
@test call_on_device(OpenCL.nan, Base.uinttype(T)(0)) isa T
end

@testset "OpenCL-specific binary - $T" for T in float_types
@testset "$f" for f in [
OpenCL.atanpi,
OpenCL.dim,
OpenCL.maxmag,
OpenCL.minmag,
OpenCL.nextafter,
OpenCL.powr,
]
x = rand(T)
y = rand(T)
broken = ispocl && T == Float16 && !(f in [OpenCL.maxmag, OpenCL.minmag])
@test call_on_device(f, x, y) isa Real broken = broken # Just check it doesn't error
end
broken = ispocl && T == Float16
@test call_on_device(OpenCL.rootn, T(8.0), Int32(3)) ≈ T(2.0) broken = broken
end

@testset "OpenCL-specific ternary - $T" for T in float_types
x = rand(T)
y = rand(T)
z = rand(T)
@test call_on_device(OpenCL.mad, x, y, z) ≈ x * y + z
end

end

end
Loading