From d9f65ed8cbba4ca1775c26e7339357b913dfbdd4 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 12 Dec 2021 23:42:41 +0100 Subject: [PATCH 01/16] Add missing rules for SpecialFunctions 0.10 --- src/rules.jl | 78 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 61 insertions(+), 17 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index ebc1ac7..9f7ba13 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -117,6 +117,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule SpecialFunctions.erfinv(x) = :( (sqrt(π) / 2) * exp(SpecialFunctions.erfinv($x)^2) ) @define_diffrule SpecialFunctions.erfc(x) = :( -(2 / sqrt(π)) * exp(-$x * $x) ) +@define_diffrule SpecialFunctions.logerfc(x) = + :( -(2 * exp(- $x^2 - SpecialFunctions.logerfc($x))) / sqrt(π) ) @define_diffrule SpecialFunctions.erfcinv(x) = :( -(sqrt(π) / 2) * exp(SpecialFunctions.erfcinv($x)^2) ) @define_diffrule SpecialFunctions.erfi(x) = :( (2 / sqrt(π)) * exp($x * $x) ) @@ -124,6 +126,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) :( (2 * $x * SpecialFunctions.erfcx($x)) - (2 / sqrt(π)) ) @define_diffrule SpecialFunctions.logerfcx(x) = :( 2 * ($x - inv(SpecialFunctions.erfcx($x) * sqrt(π))) ) + @define_diffrule SpecialFunctions.dawson(x) = :( 1 - (2 * $x * SpecialFunctions.dawson($x)) ) @define_diffrule SpecialFunctions.digamma(x) = @@ -132,14 +135,35 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) :( inv(SpecialFunctions.trigamma(SpecialFunctions.invdigamma($x))) ) @define_diffrule SpecialFunctions.trigamma(x) = :( SpecialFunctions.polygamma(2, $x) ) + +# derivatives for `airybix` and `airybiprimex` are only correct for real inputs +# `airyaix` and `airyaiprimex` are only defined for positive real inputs +# `airybix` and `airybiprimex` are unscaled for negative real inputs @define_diffrule SpecialFunctions.airyai(x) = :( SpecialFunctions.airyaiprime($x) ) @define_diffrule SpecialFunctions.airyaiprime(x) = :( $x * SpecialFunctions.airyai($x) ) +@define_diffrule SpecialFunctions.airyaix(x) = + :( SpecialFunctions.airyaiprimex($x) + sqrt($x) * SpecialFunctions.airyaix($x) ) +@define_diffrule SpecialFunctions.airyaiprimex(x) = + :( $x * SpecialFunctions.airyaix($x) + sqrt($x) * SpecialFunctions.airyaiprimex($x) ) @define_diffrule SpecialFunctions.airybi(x) = :( SpecialFunctions.airybiprime($x) ) @define_diffrule SpecialFunctions.airybiprime(x) = :( $x * SpecialFunctions.airybi($x) ) +@define_diffrule SpecialFunctions.airybix(x) = + :( if $x > zero($x) + SpecialFunctions.airybiprimex($x) - sqrt($x) * SpecialFunctions.airybix($x) + else + SpecialFunctions.airybiprimex($x) + end ) +@define_diffrule SpecialFunctions.airybiprimex(x) = + :( if $x > zero($x) + $x * SpecialFunctions.airybix($x) - sqrt($x) * SpecialFunctions.airybiprimex($x) + else + $x * SpecialFunctions.airybix($x) + end ) + @define_diffrule SpecialFunctions.besselj0(x) = :( -SpecialFunctions.besselj1($x) ) @define_diffrule SpecialFunctions.besselj1(x) = @@ -149,49 +173,69 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule SpecialFunctions.bessely1(x) = :( (SpecialFunctions.bessely0($x) - SpecialFunctions.bessely(2, $x)) / 2 ) +@define_diffrule SpecialFunctions.sinint(x) = :( sinc($x / π) ) +@define_diffrule SpecialFunctions.cosint(x) = :( cos($x) / $x ) + +@define_diffrule SpecialFunctions.ellipk(m) = + :( (SpecialFunctions.ellipe($m) / (1 - $m) - SpecialFunctions.ellipk($m)) / (2 * $m) ) +@define_diffrule SpecialFunctions.ellipe(m) = + :( (SpecialFunctions.ellipe($m) - SpecialFunctions.ellipk($m)) / (2 * $m) ) + # TODO: # # eta # zeta -# airyaix -# airyaiprimex -# airybix -# airybiprimex # binary # #--------# +# derivatives with respect to the order `ν` exist but are not implemented +# (analogously to the ChainRules definitions in SpecialFunctions) + +# derivatives for `besselix`, `besseljx` and `besselyx` are only correct for real inputs +# see https://github.com/JuliaMath/SpecialFunctions.jl/blob/master/src/chainrules.jl +# for forward-mode and reverse-mode derivatives for complex inputs + @define_diffrule SpecialFunctions.besselj(ν, x) = :NaN, :( (SpecialFunctions.besselj($ν - 1, $x) - SpecialFunctions.besselj($ν + 1, $x)) / 2 ) +@define_diffrule SpecialFunctions.besseljx(ν, x) = + :NaN, :( (SpecialFunctions.besseljx($ν - 1, $x) - SpecialFunctions.besseljx($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besseli(ν, x) = :NaN, :( (SpecialFunctions.besseli($ν - 1, $x) + SpecialFunctions.besseli($ν + 1, $x)) / 2 ) +@define_diffrule SpecialFunctions.besselix(ν, x) = + :NaN, :( (SpecialFunctions.besselix($ν - 1, $x) + SpecialFunctions.besselix($ν + 1, $x)) / 2 - sign($x) * SpecialFunctions.besselix($ν, $x) ) @define_diffrule SpecialFunctions.bessely(ν, x) = :NaN, :( (SpecialFunctions.bessely($ν - 1, $x) - SpecialFunctions.bessely($ν + 1, $x)) / 2 ) +@define_diffrule SpecialFunctions.besselyx(ν, x) = + :NaN, :( (SpecialFunctions.besselyx($ν - 1, $x) - SpecialFunctions.besselyx($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselk(ν, x) = :NaN, :( -(SpecialFunctions.besselk($ν - 1, $x) + SpecialFunctions.besselk($ν + 1, $x)) / 2 ) +@define_diffrule SpecialFunctions.besselkx(ν, x) = + :NaN, :( -(SpecialFunctions.besselkx($ν - 1, $x) + SpecialFunctions.besselkx($ν + 1, $x)) / 2 + SpecialFunctions.besselkx($ν, $x) ) +@define_diffrule SpecialFunctions.besselh(ν, x) = + :NaN, :( (SpecialFunctions.besselh($ν - 1, $x) - SpecialFunctions.besselh($ν + 1, $x)) / 2 ) +@define_diffrule SpecialFunctions.besselhx(ν, x) = + :NaN, :( (SpecialFunctions.besselhx($ν - 1, $x) - SpecialFunctions.besselhx($ν + 1, $x)) / 2 - im * SpecialFunctions.besselhx($ν, $x) ) @define_diffrule SpecialFunctions.hankelh1(ν, x) = :NaN, :( (SpecialFunctions.hankelh1($ν - 1, $x) - SpecialFunctions.hankelh1($ν + 1, $x)) / 2 ) +@define_diffrule SpecialFunctions.hankelh1x(ν, x) = + :NaN, :( (SpecialFunctions.hankelh1x($ν - 1, $x) - SpecialFunctions.hankelh1x($ν + 1, $x)) / 2 - im * SpecialFunctions.hankelh1x($ν, $x) ) @define_diffrule SpecialFunctions.hankelh2(ν, x) = :NaN, :( (SpecialFunctions.hankelh2($ν - 1, $x) - SpecialFunctions.hankelh2($ν + 1, $x)) / 2 ) +@define_diffrule SpecialFunctions.hankelh2x(ν, x) = + :NaN, :( (SpecialFunctions.hankelh2x($ν - 1, $x) - SpecialFunctions.hankelh2x($ν + 1, $x)) / 2 + im * SpecialFunctions.hankelh2x($ν, $x) ) + @define_diffrule SpecialFunctions.polygamma(m, x) = :NaN, :( SpecialFunctions.polygamma($m + 1, $x) ) + @define_diffrule SpecialFunctions.beta(a, b) = :( SpecialFunctions.beta($a, $b)*(SpecialFunctions.digamma($a) - SpecialFunctions.digamma($a + $b)) ), :( SpecialFunctions.beta($a, $b)*(SpecialFunctions.digamma($b) - SpecialFunctions.digamma($a + $b)) ) -@define_diffrule SpecialFunctions.logbeta(a, b) = +@define_diffrule SpecialFunctions.logbeta(a, b) = :( SpecialFunctions.digamma($a) - SpecialFunctions.digamma($a + $b) ), :( SpecialFunctions.digamma($b) - SpecialFunctions.digamma($a + $b) ) -# TODO: -# -# zeta -# besseljx -# besselyx -# besselix -# besselkx -# besselh -# besselhx -# hankelh1x -# hankelh2 -# hankelh2x +# derivative wrt to `s` is not implemented +@define_diffrule SpecialFunctions.zeta(s, z) = + :NaN, :( - $s * SpecialFunctions.zeta($s + 1, $z) ) # ternary # #---------# From 74e23f44e2c7cf3d3522b6f811c66d7eeea2fb67 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 12 Dec 2021 23:43:13 +0100 Subject: [PATCH 02/16] Update tests --- Project.toml | 3 ++- test/runtests.jl | 31 ++++++++++++++++++++----------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 769ec33..994508e 100644 --- a/Project.toml +++ b/Project.toml @@ -15,8 +15,9 @@ SpecialFunctions = "0.10, 1.0, 2" julia = "1.3" [extras] +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Random"] +test = ["FiniteDifferences", "Test", "Random"] diff --git a/test/runtests.jl b/test/runtests.jl index 072c962..129e388 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,14 +1,12 @@ using DiffRules +using FiniteDifferences using Test import SpecialFunctions, NaNMath, LogExpFunctions import Random Random.seed!(1) -function finitediff(f, x) - ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x)) - return (f(x + ϵ) - f(x - ϵ)) / (ϵ + ϵ) -end +const finitediff = central_fdm(5, 1) @testset "DiffRules" begin @testset "check rules" begin @@ -32,7 +30,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) @eval begin let goo = rand() + $modifier - @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) + @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-9 atol=1e-9 # test for 2pi functions if "mod2pi" == string($M.$f) goo = 4pi + $modifier @@ -51,11 +49,11 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) foo, bar = rand(1:10), rand() end dx, dy = $(derivs[1]), $(derivs[2]) - if !(isnan(dx)) - @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) + if !isnan(dx) + @test dx ≈ finitediff(z -> $M.$f(z, bar), float(foo)) rtol=1e-9 atol=1e-9 end - if !(isnan(dy)) - @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) + if !isnan(dy) + @test dy ≈ finitediff(z -> $M.$f(foo, z), bar) rtol=1e-9 atol=1e-9 end end end @@ -89,7 +87,7 @@ for xtype in [:Float64, :BigFloat, :Int64] x = $xtype(rand(1 : 10)) y = $mode dx, dy = $(derivs[1]), $(derivs[2]) - @test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05) + @test dx ≈ finitediff(z -> rem2pi(z, y), float(x)) rtol=1e-9 atol=1e-9 @test isnan(dy) end end @@ -105,13 +103,24 @@ for xtype in [:Float64, :BigFloat] x = rand($xtype) y = $ytype(rand(1 : 10)) dx, dy = $(derivs[1]), $(derivs[2]) - @test isapprox(dx, finitediff(z -> ldexp(z, y), x), rtol=0.05) + @test dx ≈ finitediff(z -> ldexp(z, y), x) rtol=1e-9 atol=1e-9 @test isnan(dy) end end end end +# Check negative branch for `airybix` and `airybiprimex` +for f in (:airybix, :airybiprimex) + deriv = DiffRules.diffrule(:SpecialFunctions, f, :goo) + @eval begin + let + goo = -rand() + @test $deriv ≈ finitediff(SpecialFunctions.$f, goo) rtol=1e-9 atol=1e-9 + end + end +end + end @testset "diffrules" begin From 793648080aa08d08a2b4c04c4e0f8b038393e134 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 12 Dec 2021 23:43:32 +0100 Subject: [PATCH 03/16] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 994508e..0c23504 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DiffRules" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.9.0" +version = "1.10.0" [deps] LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" From 5fbfcc528ace4c56fa4a4c6293a3a11f3ec829d9 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 13 Dec 2021 00:03:36 +0100 Subject: [PATCH 04/16] Update runtests.jl --- test/runtests.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 129e388..4966f0e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,12 +18,16 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) if arity == 1 @test DiffRules.hasdiffrule(M, f, 1) deriv = DiffRules.diffrule(M, f, :goo) - modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth) + modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh) 1.0 + elseif f === :acoth + 1.1 # values too close to 1 are problematic for finite differencing elseif f === :log1mexp -1.0 elseif f === :log2mexp -0.5 + elseif f in (:airyaix, :airyaiprimex) + 0.1 # values too close to 0 are problematic for finite differencing else 0.0 end From 382d359573776595f775397c4ea9ef0507b4afd4 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 13 Dec 2021 22:13:56 +0100 Subject: [PATCH 05/16] Add rule for `erf(x, y)` --- src/rules.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/rules.jl b/src/rules.jl index 9f7ba13..5ddc1e0 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -189,6 +189,9 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # binary # #--------# +@define_diffrule SpecialFunctions.erf(x, y) = + :( -2 / sqrt(π) * exp(-x^2) ), :( 2 / sqrt(π) * exp(-y^2) ) + # derivatives with respect to the order `ν` exist but are not implemented # (analogously to the ChainRules definitions in SpecialFunctions) From 3204111608581bbc0ec85f03cec4ce3fecf5030e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 13 Dec 2021 22:31:48 +0100 Subject: [PATCH 06/16] Fix typos --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 5ddc1e0..9fe00ac 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -190,7 +190,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) #--------# @define_diffrule SpecialFunctions.erf(x, y) = - :( -2 / sqrt(π) * exp(-x^2) ), :( 2 / sqrt(π) * exp(-y^2) ) + :( -2 / sqrt(π) * exp(-$x^2) ), :( 2 / sqrt(π) * exp(-$y^2) ) # derivatives with respect to the order `ν` exist but are not implemented # (analogously to the ChainRules definitions in SpecialFunctions) From 485a60055dc5b9b79918db5584678cbc3d328dc6 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 13 Dec 2021 22:48:08 +0100 Subject: [PATCH 07/16] Try to use `forward_fdm` --- test/runtests.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4966f0e..8225de6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,9 @@ import SpecialFunctions, NaNMath, LogExpFunctions import Random Random.seed!(1) -const finitediff = central_fdm(5, 1) +# less accurate than `central_fdm` but avoids singularities for +# e.g. `acoth`, `log`, `airyaix`, `airyaiprimex +const finitediff = forward_fdm(5, 1) @testset "DiffRules" begin @testset "check rules" begin @@ -18,16 +20,12 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) if arity == 1 @test DiffRules.hasdiffrule(M, f, 1) deriv = DiffRules.diffrule(M, f, :goo) - modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh) + modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth) 1.0 - elseif f === :acoth - 1.1 # values too close to 1 are problematic for finite differencing elseif f === :log1mexp -1.0 elseif f === :log2mexp -0.5 - elseif f in (:airyaix, :airyaiprimex) - 0.1 # values too close to 0 are problematic for finite differencing else 0.0 end From 9af877156c8f6eaf900852f020026ed70aed01af Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 13 Dec 2021 23:05:05 +0100 Subject: [PATCH 08/16] Use forward_fdm only for log, acoth, airyaix and airyaiprimex --- test/runtests.jl | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 8225de6..2046e74 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,9 +6,10 @@ import SpecialFunctions, NaNMath, LogExpFunctions import Random Random.seed!(1) -# less accurate than `central_fdm` but avoids singularities for +# `forward_fdm` is less accurate than `central_fdm` but avoids singularities for # e.g. `acoth`, `log`, `airyaix`, `airyaiprimex -const finitediff = forward_fdm(5, 1) +const finitediff = central_fdm(5, 1) +const finitediff_forward = forward_fdm(5, 1) @testset "DiffRules" begin @testset "check rules" begin @@ -32,7 +33,13 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) @eval begin let goo = rand() + $modifier - @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-9 atol=1e-9 + fd_deriv = if f in (:acoth, :log, :airyaix, :airyaiprimex) + # avoid singularities + finitediff_forward($M.$f, goo) + else + finitediff($M.$f, goo) + end + @test $deriv ≈ fd_deriv rtol=1e-9 atol=1e-9 # test for 2pi functions if "mod2pi" == string($M.$f) goo = 4pi + $modifier From 9002e416bef2b6d23d5c73b3ca96ba11a66ef48f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 13 Dec 2021 23:07:44 +0100 Subject: [PATCH 09/16] Update runtests.jl --- test/runtests.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 2046e74..7dd9c20 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -30,16 +30,16 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) else 0.0 end + fd = if f in (:acoth, :log, :airyaix, :airyaiprimex) + # avoid singularities + finitediff_forward + else + finitediff + end @eval begin let goo = rand() + $modifier - fd_deriv = if f in (:acoth, :log, :airyaix, :airyaiprimex) - # avoid singularities - finitediff_forward($M.$f, goo) - else - finitediff($M.$f, goo) - end - @test $deriv ≈ fd_deriv rtol=1e-9 atol=1e-9 + @test $deriv ≈ $fd($M.$f, goo) rtol=1e-9 atol=1e-9 # test for 2pi functions if "mod2pi" == string($M.$f) goo = 4pi + $modifier From bd8b0bbf2e1936bd7707d401ff0c1ab8f284ee2e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 13 Dec 2021 23:12:44 +0100 Subject: [PATCH 10/16] Fix binary tests --- test/runtests.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7dd9c20..fea8b7c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,19 +50,25 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) elseif arity == 2 @test DiffRules.hasdiffrule(M, f, 2) derivs = DiffRules.diffrule(M, f, :foo, :bar) + fd = if f === :log + # avoid singularities + finitediff_forward + else + finitediff + end @eval begin let if "mod" == string($M.$f) foo, bar = rand() + 13, rand() + 5 # make sure x/y is not integer else - foo, bar = rand(1:10), rand() + foo, bar = rand(1.0:10.0), rand() end dx, dy = $(derivs[1]), $(derivs[2]) if !isnan(dx) - @test dx ≈ finitediff(z -> $M.$f(z, bar), float(foo)) rtol=1e-9 atol=1e-9 + @test dx ≈ $fd(z -> $M.$f(z, bar), foo) rtol=1e-9 atol=1e-9 end if !isnan(dy) - @test dy ≈ finitediff(z -> $M.$f(foo, z), bar) rtol=1e-9 atol=1e-9 + @test dy ≈ $fd(z -> $M.$f(foo, z), bar) rtol=1e-9 atol=1e-9 end end end From fe26ccfb3786e89cd667288c6d1cccd9f594f1a0 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 13 Dec 2021 23:22:04 +0100 Subject: [PATCH 11/16] Update runtests.jl --- test/runtests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index fea8b7c..4e75e77 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -58,14 +58,14 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) end @eval begin let - if "mod" == string($M.$f) - foo, bar = rand() + 13, rand() + 5 # make sure x/y is not integer + foo, bar = if "mod" == string($M.$f) + rand() + 13, rand() + 5 # make sure x/y is not integer else - foo, bar = rand(1.0:10.0), rand() + rand(1:10), rand() end dx, dy = $(derivs[1]), $(derivs[2]) if !isnan(dx) - @test dx ≈ $fd(z -> $M.$f(z, bar), foo) rtol=1e-9 atol=1e-9 + @test dx ≈ $fd(z -> $M.$f(z, bar), float(foo)) rtol=1e-9 atol=1e-9 end if !isnan(dy) @test dy ≈ $fd(z -> $M.$f(foo, z), bar) rtol=1e-9 atol=1e-9 From 9c9c54560042e27514a127d7b2772815014a26f8 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 14 Dec 2021 00:15:16 +0100 Subject: [PATCH 12/16] Update tests --- test/runtests.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4e75e77..d36e3d7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,7 +41,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) goo = rand() + $modifier @test $deriv ≈ $fd($M.$f, goo) rtol=1e-9 atol=1e-9 # test for 2pi functions - if "mod2pi" == string($M.$f) + if $(f === :mod2pi) goo = 4pi + $modifier @test NaN === $deriv end @@ -50,7 +50,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) elseif arity == 2 @test DiffRules.hasdiffrule(M, f, 2) derivs = DiffRules.diffrule(M, f, :foo, :bar) - fd = if f === :log + fd = if f in (:log, :^) # avoid singularities finitediff_forward else @@ -58,14 +58,16 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) end @eval begin let - foo, bar = if "mod" == string($M.$f) + foo, bar = if $(f === :mod) rand() + 13, rand() + 5 # make sure x/y is not integer + elseif $(f === :polygamma) + rand(1:10), rand() # only supports integers as first arguments else - rand(1:10), rand() + rand(), rand() end dx, dy = $(derivs[1]), $(derivs[2]) if !isnan(dx) - @test dx ≈ $fd(z -> $M.$f(z, bar), float(foo)) rtol=1e-9 atol=1e-9 + @test dx ≈ $fd(z -> $M.$f(z, bar), foo) rtol=1e-9 atol=1e-9 end if !isnan(dy) @test dy ≈ $fd(z -> $M.$f(foo, z), bar) rtol=1e-9 atol=1e-9 From 9a2acd322a5a1106b36f17064dd15a68b8b6bd8c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 14 Dec 2021 22:44:22 +0100 Subject: [PATCH 13/16] Fix tests --- test/runtests.jl | 55 +++++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index d36e3d7..3edb1cf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,10 +6,7 @@ import SpecialFunctions, NaNMath, LogExpFunctions import Random Random.seed!(1) -# `forward_fdm` is less accurate than `central_fdm` but avoids singularities for -# e.g. `acoth`, `log`, `airyaix`, `airyaiprimex const finitediff = central_fdm(5, 1) -const finitediff_forward = forward_fdm(5, 1) @testset "DiffRules" begin @testset "check rules" begin @@ -21,28 +18,25 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) if arity == 1 @test DiffRules.hasdiffrule(M, f, 1) deriv = DiffRules.diffrule(M, f, :goo) - modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth) - 1.0 - elseif f === :log1mexp - -1.0 - elseif f === :log2mexp - -0.5 - else - 0.0 - end - fd = if f in (:acoth, :log, :airyaix, :airyaiprimex) - # avoid singularities - finitediff_forward - else - finitediff - end @eval begin let - goo = rand() + $modifier - @test $deriv ≈ $fd($M.$f, goo) rtol=1e-9 atol=1e-9 + goo = if $(f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) + # avoid singularities with finite differencing + rand() + 1.5 + elseif $(f in (:log, :airyaix, :airyaiprimex)) + # avoid singularities with finite differencing + rand() + 0.5 + elseif $(f === :log1mexp) + rand() - 1.0 + elseif $(f in (:log2mexp, :erfinv)) + rand() - 0.5 + else + rand() + end + @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-9 atol=1e-9 # test for 2pi functions if $(f === :mod2pi) - goo = 4pi + $modifier + goo = 4 * pi @test NaN === $deriv end end @@ -50,27 +44,30 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) elseif arity == 2 @test DiffRules.hasdiffrule(M, f, 2) derivs = DiffRules.diffrule(M, f, :foo, :bar) - fd = if f in (:log, :^) - # avoid singularities - finitediff_forward - else - finitediff - end @eval begin let foo, bar = if $(f === :mod) rand() + 13, rand() + 5 # make sure x/y is not integer elseif $(f === :polygamma) rand(1:10), rand() # only supports integers as first arguments + elseif $(f in (:bessely, :besselyx)) + # avoid singularities with finite differencing + rand(), rand() + 0.5 + elseif $(f === :log) + # avoid singularities with finite differencing + rand() + 1.5, rand() + elseif $(f === :^) + # avoid singularities with finite differencing + rand() + 0.5, rand() else rand(), rand() end dx, dy = $(derivs[1]), $(derivs[2]) if !isnan(dx) - @test dx ≈ $fd(z -> $M.$f(z, bar), foo) rtol=1e-9 atol=1e-9 + @test dx ≈ finitediff(z -> $M.$f(z, bar), foo) rtol=1e-9 atol=1e-9 end if !isnan(dy) - @test dy ≈ $fd(z -> $M.$f(foo, z), bar) rtol=1e-9 atol=1e-9 + @test dy ≈ finitediff(z -> $M.$f(foo, z), bar) rtol=1e-9 atol=1e-9 end end end From 39e70630bfbaa4db233120fa43e2bf14ed123257 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 20 Feb 2022 22:49:30 +0100 Subject: [PATCH 14/16] Update runtests.jl --- test/runtests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 50d509e..79134fa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,4 @@ using DiffRules -using FiniteDifferences using Test using FiniteDifferences From 7cef376939362089a641a702a49c8809c6f44b81 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 5 May 2022 09:44:42 +0200 Subject: [PATCH 15/16] Add rule for `logabsgamma` --- src/rules.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/rules.jl b/src/rules.jl index e7b2b5c..fc57a88 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -65,6 +65,8 @@ :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = :( SpecialFunctions.digamma($x) ) +@define_diffrule SpecialFunctions.logabsgamma(x) = + :( SpecialFunctions.digamma($x) ), :false @define_diffrule Base.abs(x) = :( $(_abs_deriv)($x) ) From 5264ac7e97ed977a9ae929d3f2bc96eafc10752b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 6 May 2022 01:09:23 +0200 Subject: [PATCH 16/16] Revert "Add rule for `logabsgamma`" This reverts commit 7cef376939362089a641a702a49c8809c6f44b81. --- src/rules.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index fc57a88..e7b2b5c 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -65,8 +65,6 @@ :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = :( SpecialFunctions.digamma($x) ) -@define_diffrule SpecialFunctions.logabsgamma(x) = - :( SpecialFunctions.digamma($x) ), :false @define_diffrule Base.abs(x) = :( $(_abs_deriv)($x) )