Description
At the moment we only define `ProjectTo` for differential types (with `Ref` being the first exception).
Consider a generic `rrule` for `*(a::Number, b::Number)` that uses `ProjectTo` to ensure that the tangents lie in the right subspace, e.g. something like
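```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(*), a::Number, b::Number)
    project_a, project_b = ProjectTo(a), ProjectTo(b)
    function times_pullback(dy)
        da = dy * conj(b)  # conj is a no-op for real inputs
        db = conj(a) * dy
        return NoTangent(), project_a(da), project_b(db)
    end
    return a * b, times_pullback
end
```

which looks perfectly reasonable.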
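To make "the right subspace" concrete, here is a self-contained sketch that does not depend on ChainRulesCore. `project_to` and `times_rrule_sketch` are hypothetical stand-ins, with `project_to` covering only `Real` and `Complex` primals. With a real `a` and a complex `b`, the raw tangent for `a` comes out complex, and projecting onto the real primal keeps only its real part.

```julia
# Hypothetical stand-in for ProjectTo(x): returns a function mapping a raw
# tangent onto the tangent space of x. Not ChainRulesCore's implementation.
project_to(::Real) = (dx -> real(dx))        # Real primal: drop imaginary part
project_to(::Complex) = (dx -> complex(dx))  # Complex primal: keep complex

function times_rrule_sketch(a::Number, b::Number)
    pa, pb = project_to(a), project_to(b)
    function times_pullback(dy)
        da = dy * conj(b)  # complex whenever b is complex
        db = conj(a) * dy
        return pa(da), pb(db)
    end
    return a * b, times_pullback
end

y, back = times_rrule_sketch(2.0, 3.0 + 4.0im)
da, db = back(1.0)  # da is projected to a Float64, db to a ComplexF64
```

Without the projection, `da` would be `3.0 - 4.0im`, a tangent that does not belong to the real primal `a`.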
However, if we create a type like
```julia
struct PositiveReal <: Number
    val::Float64
    PositiveReal(x) = x > 0 ? new(x) : error("must be larger than 0")
end
```

which is not its own differential type (its natural differential is a `Float64`), we run into trouble.
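Why `Float64` rather than `PositiveReal` is the natural differential can be checked directly: a tangent must be able to be zero or negative (the value may decrease), and neither fits inside a `PositiveReal`. A minimal self-contained check, repeating the struct so it runs on its own:

```julia
struct PositiveReal <: Number
    val::Float64
    PositiveReal(x) = x > 0 ? new(x) : error("must be larger than 0")
end

x = PositiveReal(2.0)
ok_tangent = -0.5  # a perfectly valid Float64 tangent for x

# Neither a negative nor a zero tangent can be stored as a PositiveReal.
bad = try
    PositiveReal(-0.5)
catch err
    err
end
zero_bad = try
    PositiveReal(0.0)
catch err
    err
end
```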
The problem is that we only promise that `ProjectTo` projects onto valid differential types, so we can't just define
```julia
function ChainRulesCore.ProjectTo(x::PositiveReal)
    return ProjectTo(x.val)
end
```

since `PositiveReal` is not a valid differential type (it has no zero). For a similar reason we do not define `ProjectTo(::Tuple)`, which would solve issues like #440.
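For discussion, a self-contained sketch of what the loosened rule could look like. `project_to` is again a hypothetical stand-in for `ProjectTo`, not ChainRulesCore's code: the projector for a `PositiveReal` delegates to the one for its `Float64` field, and a tuple projector applies each element's projector. For simplicity it returns a plain tuple; a real implementation would presumably wrap the result in a structural tangent type instead.

```julia
struct PositiveReal <: Number
    val::Float64
    PositiveReal(x) = x > 0 ? new(x) : error("must be larger than 0")
end

# Hypothetical stand-in for ProjectTo on real primals: keep the real part.
project_to(::Real) = (dx -> real(dx))

# Loosened rule: PositiveReal is not a differential type, but its projector
# maps tangents onto its natural differential type, Float64.
project_to(x::PositiveReal) = project_to(x.val)

# Same idea for tuples: project each element with the projector of the
# matching primal element (returns a plain tuple for simplicity).
function project_to(xs::Tuple)
    projectors = map(project_to, xs)
    return (dxs -> map((p, dx) -> p(dx), projectors, dxs))
end

p = project_to(PositiveReal(2.0))
tp = project_to((1.0, PositiveReal(3.0)))
```

Here `p(-0.5)` returns `-0.5`, a negative tangent that a `PositiveReal` could not hold, and `tp((2.0 + 1.0im, -1.0))` returns `(2.0, -1.0)`.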
The question is: should we loosen the requirement that `ProjectTo` only projects onto differential types? Keeping it restricts the use of `ProjectTo` to functions whose arguments are their own differentials. What bad things happen if we drop this requirement?