diff --git a/lib/intrinsics/Project.toml b/lib/intrinsics/Project.toml index 76b3100d..93bea810 100644 --- a/lib/intrinsics/Project.toml +++ b/lib/intrinsics/Project.toml @@ -1,7 +1,7 @@ name = "SPIRVIntrinsics" uuid = "71d1d633-e7e8-4a92-83a1-de8814b09ba8" authors = ["Tim Besard "] -version = "0.5.2" +version = "0.5.3" [deps] ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" diff --git a/lib/intrinsics/src/math.jl b/lib/intrinsics/src/math.jl index d51d603b..39117bb1 100644 --- a/lib/intrinsics/src/math.jl +++ b/lib/intrinsics/src/math.jl @@ -1,7 +1,7 @@ # Math Functions # TODO: vector types -const generic_types = [Float32,Float64] +const generic_types = [Float16, Float32, Float64] const generic_types_float = [Float32] const generic_types_double = [Float64] @@ -33,7 +33,7 @@ for gentype in generic_types @device_override Base.cos(x::$gentype) = @builtin_ccall("cos", $gentype, ($gentype,), x) @device_override Base.cosh(x::$gentype) = @builtin_ccall("cosh", $gentype, ($gentype,), x) -@device_function cospi(x::$gentype) = @builtin_ccall("cospi", $gentype, ($gentype,), x) +@device_override Base.cospi(x::$gentype) = @builtin_ccall("cospi", $gentype, ($gentype,), x) @device_override SpecialFunctions.erfc(x::$gentype) = @builtin_ccall("erfc", $gentype, ($gentype,), x) @device_override SpecialFunctions.erf(x::$gentype) = @builtin_ccall("erf", $gentype, ($gentype,), x) @@ -59,7 +59,10 @@ for gentype in generic_types #@device_override Base.mod(x::$gentype, y::$gentype) = @builtin_ccall("fmod", $gentype, ($gentype, $gentype), x, y) # fract(x::$gentype, $gentype *iptr) = @builtin_ccall("fract", $gentype, ($gentype, $gentype *), x, iptr) -@device_override Base.hypot(x::$gentype, y::$gentype) = @builtin_ccall("hypot", $gentype, ($gentype, $gentype), x, y) +# TODO: remove once https://github.com/pocl/pocl/issues/2034 is addressed +if $gentype != Float16 + @device_override Base.hypot(x::$gentype, y::$gentype) = @builtin_ccall("hypot", $gentype, ($gentype, $gentype), x, y) +end @device_override SpecialFunctions.loggamma(x::$gentype) = @builtin_ccall("lgamma", $gentype, ($gentype,), x) @@ -81,8 +84,6 @@ for gentype in generic_types @device_override Base.:(^)(x::$gentype, y::$gentype) = @builtin_ccall("pow", $gentype, ($gentype, $gentype), x, y) @device_function powr(x::$gentype, y::$gentype) = @builtin_ccall("powr", $gentype, ($gentype, $gentype), x, y) -@device_override Base.rem(x::$gentype, y::$gentype) = @builtin_ccall("remainder", $gentype, ($gentype, $gentype), x, y) - @device_function rint(x::$gentype) = @builtin_ccall("rint", $gentype, ($gentype,), x) @device_override Base.round(x::$gentype) = @builtin_ccall("round", $gentype, ($gentype,), x) @@ -100,13 +101,13 @@ for gentype in generic_types return sinval, cosval[] end @device_override Base.sinh(x::$gentype) = @builtin_ccall("sinh", $gentype, ($gentype,), x) -@device_function sinpi(x::$gentype) = @builtin_ccall("sinpi", $gentype, ($gentype,), x) +@device_override Base.sinpi(x::$gentype) = @builtin_ccall("sinpi", $gentype, ($gentype,), x) @device_override Base.sqrt(x::$gentype) = @builtin_ccall("sqrt", $gentype, ($gentype,), x) @device_override Base.tan(x::$gentype) = @builtin_ccall("tan", $gentype, ($gentype,), x) @device_override Base.tanh(x::$gentype) = @builtin_ccall("tanh", $gentype, ($gentype,), x) -@device_function tanpi(x::$gentype) = @builtin_ccall("tanpi", $gentype, ($gentype,), x) +@device_override Base.tanpi(x::$gentype) = @builtin_ccall("tanpi", $gentype, ($gentype,), x) @device_override SpecialFunctions.gamma(x::$gentype) = @builtin_ccall("tgamma", $gentype, ($gentype,), x) @@ -151,11 +152,13 @@ end # frexp(x::Float64{n}, Int32{n} *exp) = @builtin_ccall("frexp", Float64{n}, (Float64{n}, Int32{n} *), x, exp) # frexp(x::Float64, Int32 *exp) = @builtin_ccall("frexp", Float64, (Float64, Int32 *), x, exp) +@device_function ilogb(x::Float16) = @builtin_ccall("ilogb", Int32, (Float16,), x) # ilogb(x::Float32{n}) = @builtin_ccall("ilogb", Int32{n}, (Float32{n},), x) @device_function ilogb(x::Float32) = @builtin_ccall("ilogb", Int32, (Float32,), x) # ilogb(x::Float64{n}) = @builtin_ccall("ilogb", Int32{n}, (Float64{n},), x) @device_function ilogb(x::Float64) = @builtin_ccall("ilogb", Int32, (Float64,), x) +@device_override Base.ldexp(x::Float16, k::Int32) = @builtin_ccall("ldexp", Float16, (Float16, Int32), x, k) # ldexp(x::Float32{n}, k::Int32{n}) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32{n}), x, k) # ldexp(x::Float32{n}, k::Int32) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32), x, k) @device_override Base.ldexp(x::Float32, k::Int32) = @builtin_ccall("ldexp", Float32, (Float32, Int32), x, k) @@ -168,11 +171,13 @@ end # lgamma_r(x::Float64{n}, Int32{n} *signp) = @builtin_ccall("lgamma_r", Float64{n}, (Float64{n}, Int32{n} *), x, signp) # Float64 lgamma_r(x::Float64, Int32 *signp) = @builtin_ccall("lgamma_r", Float64, (Float64, Int32 *), x, signp) +@device_function nan(nancode::UInt16) = @builtin_ccall("nan", Float16, (UInt16,), nancode) # nan(nancode::uintn) = @builtin_ccall("nan", Float32{n}, (uintn,), nancode) @device_function nan(nancode::UInt32) = @builtin_ccall("nan", Float32, (UInt32,), nancode) # nan(nancode::UInt64{n}) = @builtin_ccall("nan", Float64{n}, (UInt64{n},), nancode) @device_function nan(nancode::UInt64) = @builtin_ccall("nan", Float64, (UInt64,), nancode) +@device_override Base.:(^)(x::Float16, y::Int32) = @builtin_ccall("pown", Float16, (Float16, Int32), x, y) # pown(x::Float32{n}, y::Int32{n}) = @builtin_ccall("pown", Float32{n}, (Float32{n}, Int32{n}), x, y) @device_override Base.:(^)(x::Float32, y::Int32) = @builtin_ccall("pown", Float32, (Float32, Int32), x, y) # pown(x::Float64{n}, y::Int32{n}) = @builtin_ccall("pown", Float64{n}, (Float64{n}, Int32{n}), x, y) @@ -183,10 +188,11 @@ end # remquo(x::Float64{n}, y::Float64{n}, Int32{n} *quo) = @builtin_ccall("remquo", Float64{n}, (Float64{n}, Float64{n}, Int32{n} *), x, y, quo) # remquo(x::Float64, y::Float64, Int32 *quo) = @builtin_ccall("remquo", Float64, (Float64, Float64, Int32 *), x, y, quo) +@device_function rootn(x::Float16, y::Int32) = @builtin_ccall("rootn", Float16, (Float16, Int32), x, y) # rootn(x::Float32{n}, y::Int32{n}) = @builtin_ccall("rootn", Float32{n}, (Float32{n}, Int32{n}), x, y) @device_function rootn(x::Float32, y::Int32) = @builtin_ccall("rootn", Float32, (Float32, Int32), x, y) # rootn(x::Float64{n}, y::Int32{n}) = @builtin_ccall("rootn", Float64{n}, (Float64{n}, Int32{n}), x, y) -# rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64{n}, (Float64, Int32), x, y) +@device_function rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64, (Float64, Int32), x, y) # TODO: half and native diff --git a/lib/intrinsics/src/utils.jl b/lib/intrinsics/src/utils.jl index e1a5a939..2c12db8a 100644 --- a/lib/intrinsics/src/utils.jl +++ b/lib/intrinsics/src/utils.jl @@ -26,6 +26,8 @@ macro builtin_ccall(name, ret, argtypes, args...) "c" elseif T == UInt8 "h" + elseif T == Float16 + "Dh" elseif T == Float32 "f" elseif T == Float64 diff --git a/test/intrinsics.jl b/test/intrinsics.jl index bad6838f..ad970c05 100644 --- a/test/intrinsics.jl +++ b/test/intrinsics.jl @@ -1,3 +1,17 @@ +function call_on_device(f, args...) + function kernel(res, f, args...) + res[] = f(args...) + return + end + T = OpenCL.code_typed(() -> f(args...), ())[][2] + res = CLArray{T, 0}(undef) + @opencl kernel(res, f, args...) + return OpenCL.@allowscalar res[] +end + +const float_types = filter(x -> x <: Base.IEEEFloat, GPUArraysTestSuite.supported_eltypes(CLArray)) +const ispocl = cl.platform().name == "Portable Computing Language" + @testset "intrinsics" begin @testset "barrier" begin @@ -49,4 +63,105 @@ cl.memory_backend() isa cl.SVMBackend && @on_device atomic_work_item_fence(OpenC end +@testset "math" begin + +@testset "unary - $T" for T in float_types + @testset "$f" for f in [ + acos, acosh, + asin, asinh, + atan, atanh, + cbrt, + ceil, + cos, cosh, cospi, + exp, exp2, exp10, expm1, + abs, + floor, + log, log2, log10, log1p, + round, + sin, sinh, sinpi, + sqrt, + tan, tanh, tanpi, + trunc, + ] + x = rand(T) + if f == acosh + x += 1 + end + broken = ispocl && T == Float16 && f in [acosh, asinh, atanh, cbrt, cospi, expm1, log1p, sinpi, tanpi] + @test call_on_device(f, x) ≈ f(x) broken = broken + end +end + +@testset "binary - $T" for T in float_types + @testset "$f" for f in [ + atan, + copysign, + max, + min, + hypot, + (^), + ] + x = rand(T) + y = rand(T) + broken = ispocl && T == Float16 && f == atan + @test call_on_device(f, x, y) ≈ f(x, y) broken = broken + end +end + +@testset "ternary - $T" for T in float_types + @testset "$f" for f in [ + fma, + ] + x = rand(T) + y = rand(T) + z = rand(T) + @test call_on_device(f, x, y, z) ≈ f(x, y, z) + end +end + +@testset "OpenCL-specific unary - $T" for T in float_types + @testset "$f" for f in [ + OpenCL.acospi, + OpenCL.asinpi, + OpenCL.atanpi, + OpenCL.logb, + OpenCL.rint, + OpenCL.rsqrt, + ] + x = rand(T) + broken = ispocl && T == Float16 && !(f in [OpenCL.rint, OpenCL.rsqrt]) + @test call_on_device(f, x) isa Real broken = broken # Just check it doesn't error + end + broken = ispocl && T == Float16 + @test call_on_device(OpenCL.ilogb, T(8.0)) isa Int32 broken = broken + @test call_on_device(OpenCL.nan, Base.uinttype(T)(0)) isa T +end + +@testset "OpenCL-specific binary - $T" for T in float_types + @testset "$f" for f in [ + OpenCL.atanpi, + OpenCL.dim, + OpenCL.maxmag, + OpenCL.minmag, + OpenCL.nextafter, + OpenCL.powr, + ] + x = rand(T) + y = rand(T) + broken = ispocl && T == Float16 && !(f in [OpenCL.maxmag, OpenCL.minmag]) + @test call_on_device(f, x, y) isa Real broken = broken # Just check it doesn't error + end + broken = ispocl && T == Float16 + @test call_on_device(OpenCL.rootn, T(8.0), Int32(3)) ≈ T(2.0) broken = broken +end + +@testset "OpenCL-specific ternary - $T" for T in float_types + x = rand(T) + y = rand(T) + z = rand(T) + @test call_on_device(OpenCL.mad, x, y, z) ≈ x * y + z +end + +end + end