Skip to content

Conversation

@gdalle
Copy link
Contributor

@gdalle gdalle commented Oct 8, 2025

@gdalle gdalle marked this pull request as draft October 8, 2025 11:48
@github-actions
Copy link
Contributor

github-actions bot commented Oct 8, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl
index 5994721a..734c3849 100644
--- a/lib/EnzymeCore/src/EnzymeCore.jl
+++ b/lib/EnzymeCore/src/EnzymeCore.jl
@@ -866,7 +866,7 @@ LargestChunk
 """
 function pick_chunksize end
 
-pick_chunksize(::SmallestChunk, a_or_n::Union{Integer,AbstractArray}) = Val(1)
+pick_chunksize(::SmallestChunk, a_or_n::Union{Integer, AbstractArray}) = Val(1)
 
 pick_chunksize(::LargestChunk, n::Integer) = Val(n)
 pick_chunksize(::LargestChunk, a::AbstractArray) = Val(length(a))  # allows inference on static arrays
@@ -874,13 +874,13 @@ pick_chunksize(::LargestChunk, a::AbstractArray) = Val(length(a))  # allows infe
 pick_chunksize(::AutoChunk, n::Integer) = Val(min(DEFAULT_CHUNK_SIZE, n))  # TODO: improve
 pick_chunksize(s::AutoChunk, a::AbstractArray) = pick_chunksize(s, length(a))
 
-function pick_chunksize(s::FixedChunk{C}, a_or_n::Union{Integer,AbstractArray}) where {C}
+function pick_chunksize(s::FixedChunk{C}, a_or_n::Union{Integer, AbstractArray}) where {C}
     check_length(s, a_or_n)
     return Val{C}()
 end
 
 function check_length(::FixedChunk{C}, n::Integer) where {C}
-    if n < C
+    return if n < C
         error("Chunk size $C is larger than length $n")
     end
 end
diff --git a/src/sugar.jl b/src/sugar.jl
index 6ab645d4..fb151c91 100644
--- a/src/sugar.jl
+++ b/src/sugar.jl
@@ -418,7 +418,7 @@ end
 const ExtendedChunkStrategy = Union{ChunkStrategy, Nothing, Val}
 
 # eats and returns a type because generated functions work on argument types
-get_strategy(chunk::Type{CS}) where {CS<:ChunkStrategy} = chunk
+get_strategy(chunk::Type{CS}) where {CS <: ChunkStrategy} = chunk
 
 function get_strategy(::Type{Nothing})
     Base.depwarn(
@@ -457,7 +457,7 @@ end
 @inline tupleconcat(x, y) = (x..., y...)
 @inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...)
 
-@generated function create_shadows(chunk::ExtendedChunkStrategy, x::X, vargs::Vararg{Any,N}) where {X, N}
+@generated function create_shadows(chunk::ExtendedChunkStrategy, x::X, vargs::Vararg{Any, N}) where {X, N}
     chunk_strategy = get_strategy(chunk)
     args =  Union{Symbol,Expr}[:x]
     tys =  Type[X]
@@ -614,9 +614,9 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
     f::F,
     x::ty_0,
     args::Vararg{Any,N};
-    chunk::ExtendedChunkStrategy = SmallestChunk(),
+        chunk::ExtendedChunkStrategy = SmallestChunk(),
     shadows::ST = create_shadows(chunk, x, args...),
-) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,StrongZero,ST, ty_0, N}
+    ) where {F, ReturnPrimal, ABI, ErrIfFuncWritten, RuntimeActivity, StrongZero, ST, ty_0, N}
 
     chunk_strategy = get_strategy(chunk)
 
@@ -827,10 +827,10 @@ end
     mode::ReverseMode{ReturnPrimal},
     RT::RType,
     n_outs::OutType,
-    chunk::ExtendedChunkStrategy,
+        chunk::ExtendedChunkStrategy,
     f::F,
     xs::Vararg{Any, Nargs}
-) where {ReturnPrimal,RType, F,Nargs,OutType}
+    ) where {ReturnPrimal, RType, F, Nargs, OutType}
     chunk_strategy = get_strategy(chunk)
     fty = if f <: Enzyme.Annotation
         f.parameters[1]
@@ -1255,8 +1255,8 @@ this function will retun an AbstractArray of shape `size(output)` of values of t
     f::F,
     xs::Vararg{Any, Nargs};
     n_outs::OutType = nothing,
-    chunk::ExtendedChunkStrategy = SmallestChunk(),
-) where {F,Nargs, OutType}
+        chunk::ExtendedChunkStrategy = SmallestChunk(),
+    ) where {F, Nargs, OutType}
 
     fty = if f <: Enzyme.Annotation
         f.parameters[1]
diff --git a/test/sugar.jl b/test/sugar.jl
index 7d465984..ab265851 100644
--- a/test/sugar.jl
+++ b/test/sugar.jl
@@ -672,12 +672,12 @@ fchunk2(x) = map(sin, x) + map(cos, reverse(x))
 
 @testset "Chunking strategies" begin
     @testset "ChunkedOneHot" begin
-        @test Enzyme.chunkedonehot(ones(3), Enzyme.SmallestChunk()) isa Tuple{NTuple{1},NTuple{1},NTuple{1}}
+        @test Enzyme.chunkedonehot(ones(3), Enzyme.SmallestChunk()) isa Tuple{NTuple{1}, NTuple{1}, NTuple{1}}
         @test Enzyme.chunkedonehot(ones(30), Enzyme.LargestChunk()) isa Tuple{NTuple{30}}
         @test Enzyme.chunkedonehot(ones(10), Enzyme.LargestChunk()) isa Tuple{NTuple{10}}
         @test Enzyme.chunkedonehot(ones(30), Enzyme.LargestChunk()) isa Tuple{NTuple{30}}
-        @test Enzyme.chunkedonehot(ones(10), Enzyme.FixedChunk{4}()) isa Tuple{NTuple{4},NTuple{4},NTuple{2}}
-        @test Enzyme.chunkedonehot(ones(10), Enzyme.FixedChunk{5}()) isa Tuple{NTuple{5},NTuple{5}}
+        @test Enzyme.chunkedonehot(ones(10), Enzyme.FixedChunk{4}()) isa Tuple{NTuple{4}, NTuple{4}, NTuple{2}}
+        @test Enzyme.chunkedonehot(ones(10), Enzyme.FixedChunk{5}()) isa Tuple{NTuple{5}, NTuple{5}}
         @test Enzyme.chunkedonehot(ones(10), Enzyme.AutoChunk()) isa Tuple{NTuple{10}}
         @test Enzyme.chunkedonehot(ones(30), Enzyme.AutoChunk()) isa Tuple{NTuple{16}, NTuple{14}}
         @test Enzyme.chunkedonehot(ones(30), Enzyme.AutoChunk()) isa Tuple{NTuple{16}, NTuple{14}}

@gdalle gdalle marked this pull request as ready for review October 29, 2025 09:00
@gdalle gdalle requested a review from wsmoses October 29, 2025 09:00
@gdalle
Copy link
Contributor Author

gdalle commented Oct 29, 2025

@wsmoses I think this is a good first step, and I'm not able to do the reverse-mode Jacobian fix on my own. Can someone help? Also, should we make create_shadows public and documented?

@gdalle
Copy link
Contributor Author

gdalle commented Oct 30, 2025

@wsmoses this is good to go, I fixed the reverse mode too

"""
struct AutoChunk <: ChunkStrategy end

const DEFAULT_CHUNK_SIZE = 16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for type reasons, should we make this Val(16), otherwise it wont necessarily be a constant within pick_chunksize?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so because

  1. The constant is correctly inferred when returned as a Val in a function
julia> const C = 3
3

julia> f() = Val(C)
f (generic function with 1 method)

julia> @code_warntype f()
MethodInstance for f()
  from f() @ Main REPL[2]:1
Arguments
  #self#::Core.Const(Main.f)
Body::Val{3}
1%1 = Main.Val::Core.Const(Val)
│   %2 = (%1)(Main.C)::Core.Const(Val{3}())
└──      return %2
  1. It doesn't matter anyway because pick_chunksize(::AutoChunk, a) will be type-unstable since it depends on the array size:
julia> g(n) = Val(min(n, C))
g (generic function with 1 method)

julia> @code_warntype g(4)
MethodInstance for g(::Int64)
  from g(n) @ Main REPL[4]:1
Arguments
  #self#::Core.Const(Main.g)
  n::Int64
Body::Val
1%1 = Main.Val::Core.Const(Val)
│   %2 = Main.min::Core.Const(min)
│   %3 = (%2)(n, Main.C)::Int64%4 = (%1)(%3)::Val
└──      return %4

src/sugar.jl Outdated

"""
gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing)
gradient(::ForwardMode, f, x, args...; chunk=nothing, shadows=create_shadows(chunk, x, args...))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should change chunk=nothing, to the relevant correct explicit default.

we should also not support val/nothing inside of here and isntead add a deprecated method (or perhaps first check in the expr) if its one of the legacy methods and mark as deprecated

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in the latest commit.

An issue with the current code is that the deprecation warning will only be visible at the first function call, since that is the only time where the generating function is actually generated:

julia> jacobian(Forward, copy, ones(2); chunk=nothing)
┌ Warning: The `chunk=nothing` configuration will be deprecated in a future release. Please use `chunk=SmallestChunk()` instead.
│   caller = #s719#135 at sugar.jl:461 [inlined]
└ @ Core ~/Documents/GitHub/Julia/Enzyme.jl/src/sugar.jl:461
┌ Warning: The `chunk=nothing` configuration will be deprecated in a future release. Please use `chunk=SmallestChunk()` instead.
│   caller = #s717#137 at sugar.jl:621 [inlined]
└ @ Core ~/Documents/GitHub/Julia/Enzyme.jl/src/sugar.jl:621
([1.0 0.0; 0.0 1.0],)

julia> jacobian(Forward, copy, ones(2); chunk=nothing)
([1.0 0.0; 0.0 1.0],)

Not sure whether that's an issue or not

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we move it to a permanent warning at every call, we should probably add tests for it too

@gdalle
Copy link
Contributor Author

gdalle commented Oct 30, 2025

Do you want this to error?

jacobian(Forward, copy, ones(2); chunk=FixedSize{3}())

Currently Enzyme doesn't mind when the chunk size is larger than the input, I'm not sure what the expected behavior there is (and it's hard to deduce from the code). For comparison:

julia> cfg = ForwardDiff.JacobianConfig(copy, ones(2), ForwardDiff.Chunk{3}());

julia> ForwardDiff.jacobian(copy, ones(2), cfg, Val(true))
ERROR: ArgumentError: chunk size cannot be greater than ForwardDiff.structural_length(x) (3 > 2)

@codecov
Copy link

codecov bot commented Oct 30, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 53.47%. Comparing base (107b327) to head (dc5ff05).
⚠️ Report is 9 commits behind head on main.

❗ There is a different number of reports uploaded between BASE (107b327) and HEAD (dc5ff05). Click for more details.

HEAD has 29 uploads less than BASE
Flag BASE (107b327) HEAD (dc5ff05)
34 5
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2659       +/-   ##
===========================================
- Coverage   72.61%   53.47%   -19.15%     
===========================================
  Files          58       12       -46     
  Lines       18746     1210    -17536     
===========================================
- Hits        13613      647    -12966     
+ Misses       5133      563     -4570     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants