@@ -2,7 +2,9 @@ istraining() = false
 
 ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
 
-_isactive(m) = isnothing(m.active) ? istraining() : m.active
+_isactive(m) = isnothing(m.active) ? istraining() : Bool(m.active)
+
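+# The active-flag check has no meaningful derivative, so hide it from AD entirely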
+ChainRulesCore.@non_differentiable _isactive(::Any)
 
 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
@@ -31,26 +33,50 @@ automatically managed using the [`Dropout`](@ref) layer instead of the
 
 The [`Dropout`](@ref) layer is what you should use in most scenarios.
 """
-function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y
-end
+dropout(rng, x, p; dims=:, active::Bool=true) = _dropout(rng, x, p, dims, active)
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
 
-dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-dropout_mask(rng, x::CuArray, p; kwargs...) =
+# Internal function without kwargs to keep the Zygote-generated code type stable
+function _dropout(rng, x, p, dims, active)
+  mask = active ? dropout_mask(rng, x, p, dims) : nothing
+  return _apply_mask(x, mask)
+end
+
+function ChainRulesCore.rrule(::typeof(_dropout), rng, x, p, dims, active)
+  mask = active ? dropout_mask(rng, x, p, dims) : nothing
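+  # Infer the mask's array type even when the inactive branch returned `nothing`,
+  # so the pullback below can be concretely typed in both cases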
+  MT = Core.Compiler.return_type(dropout_mask, Tuple{typeof(rng),typeof(x),typeof(p),typeof(dims)})
+  project_x = ProjectTo(x)
+  return _apply_mask(x, mask), DropoutPullback{MT,typeof(project_x)}(mask, project_x)
+end
+
+# Also needed for type stability. Otherwise inference lifts the Union into a
+# Union{Pullback{Nothing}, Pullback{AbstractArray}}
+struct DropoutPullback{M<:AbstractArray,P<:ProjectTo{AbstractArray}}
+  mask::Union{Nothing,M}
+  project::P
+end
+
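+# The pullback of x .* mask is dy .* mask; every other input gets NoTangent()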
+function (pb::DropoutPullback)(dy)
+  dx = pb.project(_apply_mask(dy, pb.mask))
+  return (NoTangent(), NoTangent(), dx, NoTangent(), NoTangent(), NoTangent())
+end
+
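+# Applying a `nothing` mask is the identity, so the inactive path is free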
+_apply_mask(x, ::Nothing) = x
+_apply_mask(x, mask) = x .* mask
+
+dropout_mask(rng::CUDA.RNG, x::CuArray, p, dims) = _dropout_mask(rng, x, p, dims)
+dropout_mask(rng, x::CuArray, p, dims) =
   throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))
-dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-function _dropout_mask(rng, x, p; dims=:)
+dropout_mask(rng, x, p, dims) = _dropout_mask(rng, x, p, dims)
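+# Inverted dropout: entries survive with probability 1 - p and are scaled by 1/(1 - p)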
+function _dropout_mask(rng, x, p, dims)
   realfptype = float(real(eltype(x)))
   y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
   y .= _dropout_kernel.(y, p, 1 - p)
   return y
 end
 
 # TODO move this to NNlib
-ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
+ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any, ::Any)
 
 """
     Dropout(p; dims=:, rng = rng_from_array())
@@ -82,10 +108,7 @@
 @functor Dropout
 trainable(a::Dropout) = (;)
 
-function (a::Dropout)(x)
-  _isactive(a) || return x
-  return dropout(a.rng, x, a.p; dims=a.dims, active=true)
-end
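+# One positional, type-stable call covers both the active and inactive cases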
+(a::Dropout)(x) = _dropout(a.rng, x, a.p, a.dims, _isactive(a))
 
 testmode!(m::Dropout, mode=true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
@@ -172,7 +195,7 @@ LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]
 
 @functor LayerNorm
 
-(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
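+# (assuming _normalize(x, dims, ϵ) is the positional counterpart of normalise(x; dims, ϵ))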
+(a::LayerNorm)(x) = a.diag(_normalize(x, 1:length(a.size), a.ϵ))
 
 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm(", join(l.size, ", "))