diff --git a/Project.toml b/Project.toml index ead5f2e65..ca6971199 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.7.0" +version = "1.8.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 7cf9f1d98..a6d95c80b 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -184,6 +184,8 @@ let @scalar_rule max(x, y) @setup(gt = x > y) (gt, !gt) @scalar_rule min(x, y) @setup(gt = x > y) (!gt, gt) + @scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent()) + # Unary functions @scalar_rule +x true @scalar_rule -x -1 diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index a24b3bc22..e916189dd 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -192,6 +192,16 @@ const FASTABLE_AST = quote end end + @testset "copysign" begin + # don't go too close to zero as the numerics may jump over it yielding wrong results + @testset "at $y" for y in (-1.1, 0.1, 100.0) + @testset "at $x" for x in (-1.1, -0.1, 33.0) + test_frule(copysign, y, x) + test_rrule(copysign, y, x) + end + end + end + @testset "sign" begin @testset "real" begin @testset "at $x" for x in (-1.1, -1.1, 0.5, 100.0)