Skip to content

Commit 43ae832

Browse files
Merge pull request #7 from ToucheSir/re-all-fields
Reconstruct with all fieldnames
2 parents fa1a212 + 17487ff commit 43ae832

File tree

3 files changed

+57
-15
lines changed

3 files changed

+57
-15
lines changed

README.md

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,25 @@ Foo(1.0, [1.0, 2.0, 3.0])
4949

5050
`functor` returns the parts of the object that can be inspected, as well as a `re` function that takes those values and restructures them back into an object of the original type.
5151

52-
For a discussion regarding implementing functors for which only a subset of the fields are "seen" by `functor`, see [here](https://github.com/FluxML/Functors.jl/issues/3#issuecomment-626747663).
52+
To include only certain fields, pass a tuple of field names to `@functor`:
53+
54+
```julia
55+
julia> struct Baz
56+
x
57+
y
58+
end
59+
60+
julia> @functor Baz (x,)
61+
62+
julia> model = Baz(1, 2)
63+
Baz(1, 2)
64+
65+
julia> fmap(float, model)
66+
Baz(1.0, 2)
67+
```
68+
69+
Any field not in the list will not be returned by `functor` and passed through as-is during reconstruction. This is done by invoking the default constructor, so structs that define custom inner constructors are expected to provide one that acts like the default.
70+
71+
It is also possible to implement `functor` by hand when greater flexibility is required. See [here](https://github.com/FluxML/Functors.jl/issues/3) for an example.
5372

5473
For a discussion regarding the need for a `cache` in the implementation of `fmap`, see [here](https://github.com/FluxML/Functors.jl/issues/2).

src/functor.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,14 @@ functor(::Type{<:AbstractArray}, x) = x, y -> y
88
functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x
99

1010
function makefunctor(m::Module, T, fs = fieldnames(T))
11+
yᵢ = 0
12+
escargs = map(fieldnames(T)) do f
13+
f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f)
14+
end
15+
escfs = [:($f=x.$f) for f in fs]
16+
1117
@eval m begin
12-
$Functors.functor(::Type{<:$T}, x) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...)
18+
$Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...))
1319
end
1420
end
1521

test/basics.jl

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,38 @@
11
using Functors, Test
22

3-
struct Foo
4-
x
5-
y
6-
end
3+
@testset "Nested" begin
4+
struct Foo
5+
x
6+
y
7+
end
78

8-
@functor Foo
9+
@functor Foo
910

10-
struct Bar
11-
x
12-
end
11+
struct Bar
12+
x
13+
end
1314

14-
@functor Bar
15+
@functor Bar
1516

16-
model = Bar(Foo(1, [1, 2, 3]))
17+
model = Bar(Foo(1, [1, 2, 3]))
1718

18-
model′ = fmap(float, model)
19+
model′ = fmap(float, model)
1920

20-
@test model.x.y == model′.x.y
21-
@test model′.x.y isa Vector{Float64}
21+
@test model.x.y == model′.x.y
22+
@test model′.x.y isa Vector{Float64}
23+
end
24+
25+
@testset "Property list" begin
26+
struct Baz
27+
x
28+
y
29+
z
30+
end
31+
32+
@functor Baz (y,)
33+
34+
model = Baz(1, 2, 3)
35+
model′ = fmap(x -> 2x, model)
36+
37+
@test (model′.x, model′.y, model′.z) == (1, 4, 3)
38+
end

0 commit comments

Comments
 (0)