Creates a single block of ConvNeXt.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments

  - `planes`: number of input channels.
  - `drop_path_rate`: Stochastic depth rate.
Creates the layers for a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments

  - `inchannels`: number of input channels.
  - `depths`: list with configuration for depth of each block
@@ -39,60 +39,53 @@ Creates the layers for a ConvNeXt model.
3939"""
function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
                  nclasses = 1000)
    # Use the parenthesised macro form so the message is really passed to `@assert`.
    # Without parentheses a macro call ends at the newline, so a message string placed on
    # the following line is a stand-alone no-op expression and the assertion fails with
    # no explanation.
    @assert(length(depths) == length(planes),
            "`planes` should have exactly one value for each block")

    # One downsampling layer per stage: a strided 4x4 stem first, then a
    # norm + strided 2x2 convolution between each pair of consecutive stages.
    downsample_layers = []
    stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
                 ChannelLayerNorm(planes[1]))
    push!(downsample_layers, stem)
    for m in 1:(length(depths) - 1)
        downsample_layer = Chain(ChannelLayerNorm(planes[m]),
                                 Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
        push!(downsample_layers, downsample_layer)
    end

    # One vector of `depths[i]` ConvNeXt blocks per stage. `dp_rates` is indexed by the
    # global block position (`cur + j`), so it needs `sum(depths)` entries.
    # NOTE(review): `linear_scheduler` is defined elsewhere; presumably it ramps the rate
    # linearly up to `drop_path_rate` over `depth` blocks — confirm.
    stages = []
    dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths))
    cur = 0
    for i in eachindex(depths)
        push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
        cur += depths[i]
    end

    # Interleave (downsample_i, stage_i) pairs, then flatten the pairs and the per-stage
    # block vectors into one flat list of layers.
    backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
    head = Chain(GlobalMeanPool(),
                 MLUtils.flatten,
                 LayerNorm(planes[end]),
                 Dense(planes[end], nclasses))
    return Chain(Chain(backbone), head)
end
7067
# Configurations for ConvNeXt models.
# Each entry maps a model size to a `(depths, planes)` tuple, in the positional order
# expected by `convnext(depths, planes; ...)`.
const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
                              :small => ([3, 3, 27, 3], [96, 192, 384, 768]),
                              :base => ([3, 3, 27, 3], [128, 256, 512, 1024]),
                              :large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
                              :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))
8274
# Wrapper type around the layers built by `convnext`; `layers` is left untyped.
struct ConvNeXt
    layers::Any
end
# Register the struct with `@functor` right next to its definition.
@functor ConvNeXt
8679
8780"""
    ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000)

Creates a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments

  - `mode`: The size of the model, one of `:tiny`, `:small`, `:base`, `:large` or `:xlarge`.
  - `inchannels`: The number of channels in the input.
  - `drop_path_rate`: Stochastic depth rate.
  - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
  - `nclasses`: number of output classes
@@ -101,16 +94,12 @@ See also [`Metalhead.convnext`](#).
10194"""
function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
                  nclasses = 1000)
    # NOTE(review): `_checkconfig` is defined elsewhere; presumably it rejects a `mode`
    # that is not a key of the configuration table — confirm.
    _checkconfig(mode, keys(CONVNEXT_CONFIGS))
    # Each configuration is a `(depths, planes)` tuple, splatted into the two positional
    # arguments of `convnext`.
    layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses)
    return ConvNeXt(layers)
end
110101
# Calling the model applies the wrapped layers to `x`.
(m::ConvNeXt)(x) = m.layers(x)

# `layers` is `Chain(backbone, head)` as returned by `convnext`, so index 1 is the
# feature extractor and index 2 is the classification head.
backbone(m::ConvNeXt) = m.layers[1]
classifier(m::ConvNeXt) = m.layers[2]
0 commit comments