@@ -78,7 +78,7 @@ def reset_mask(
 
 
 def _sum_aux_loss(tree):
-  return jax.tree_map(jnp.sum, tree)
+  return jax.tree.map(jnp.sum, tree)
 
 
 class FRnn(base_layer.BaseLayer):
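Note: `jax.tree_map` is a deprecated alias of `jax.tree.map` (deprecated around JAX v0.4.25 and removed in later releases), which is what motivates this mechanical rename throughout the file. A minimal standalone sketch of the equivalence; the example tree and values here are made up, not from this repo:

import jax
import jax.numpy as jnp

# Both spellings apply a function to every leaf of a pytree; only the
# namespace changed. `jax.tree.map` forwards to `jax.tree_util.tree_map`.
tree = {'a': jnp.ones((2, 3)), 'b': (jnp.zeros(4), jnp.arange(5))}
sums = jax.tree.map(jnp.sum, tree)  # {'a': 6.0, 'b': (0.0, 10)} as scalar arrays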
@@ -143,7 +143,7 @@ def __call__(
       state: Final state.
     """
     # Make a copy of the input structure to avoid side-effect.
-    inputs = jax.tree_map(lambda x: x, inputs)
+    inputs = jax.tree.map(lambda x: x, inputs)
     assert hasattr(inputs, 'act')
     assert hasattr(inputs, 'padding')
     assert isinstance(self.cell, rnn_cell.BaseRnnCell)
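The copy-on-entry idiom in this hunk is worth noting: mapping the identity function over a pytree rebuilds every container (here, the NestedMap), so later field assignments such as `inputs.reset_mask = ...` cannot leak back to the caller, while the leaf arrays themselves are shared (safe, since JAX arrays are immutable). A standalone sketch using a plain dict; the field names and values are hypothetical, not from this repo:

import jax
import jax.numpy as jnp

original = {'act': jnp.ones((2, 4, 8)), 'padding': jnp.zeros((2, 4))}
copy = jax.tree.map(lambda x: x, original)  # new dict, same leaf arrays

copy['reset_mask'] = jnp.ones_like(copy['padding'])
assert 'reset_mask' not in original  # the caller's structure is untouched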
@@ -159,7 +159,7 @@ def __call__(
       inputs.reset_mask = jnp.ones_like(inputs.padding, dtype=self.fprop_dtype)
 
     if self.reverse:
-      inputs = jax.tree_map(lambda x: jnp.flip(x, axis=[1]), inputs)
+      inputs = jax.tree.map(lambda x: jnp.flip(x, axis=[1]), inputs)
 
     if not state0:
       batch_size = inputs.padding.shape[0]
@@ -176,7 +176,7 @@ def body_fn(sub, state0, inputs):
     if self.is_initializing():
      # inputs has shape [b, t, dim] or [b, t, 1]
       # sliced_inputs has shape [b, dim] or [b, 1].
-      sliced_inputs = jax.tree_map(lambda x: x[:, 1], inputs)
+      sliced_inputs = jax.tree.map(lambda x: x[:, 1], inputs)
       _ = body_fn(self.cell, state0, sliced_inputs)
 
     # NON_TRAINABLE variables are carried over from one iteration to another.
@@ -248,7 +248,7 @@ def init_states(self, batch_size: int) -> list[NestedMap]:
   def extend_step(
       self, inputs: NestedMap, state: list[NestedMap]
   ) -> tuple[list[NestedMap], JTensor]:
-    inputs = jax.tree_map(lambda x: x, inputs)
+    inputs = jax.tree.map(lambda x: x, inputs)
     new_states = []
     for i in range(self.num_layers):
       new_state, act_i = self.frnn[i].extend_step(inputs, state[i])
@@ -275,7 +275,7 @@ def __call__(
       act: A tensor of [batch, time, dims]. The output.
       state: Final state.
     """
-    inputs = jax.tree_map(lambda x: x, inputs)
+    inputs = jax.tree.map(lambda x: x, inputs)
 
     if not state0:
       batch_size = inputs.padding.shape[0]
@@ -375,7 +375,7 @@ def __call__(
       state: Final state - a list of NestedMap of fwd and bwd states.
     """
     # This is to create a copy.
-    inputs = jax.tree_map(lambda x: x, inputs)
+    inputs = jax.tree.map(lambda x: x, inputs)
 
     if not state0:
       batch_size = inputs.padding.shape[0]
@@ -428,7 +428,7 @@ def __call__(
       state: Final state.
     """
     # Make a copy of the input structure to avoid side-effect.
-    inputs = jax.tree_map(lambda x: x, inputs)
+    inputs = jax.tree.map(lambda x: x, inputs)
     assert hasattr(inputs, 'act')
     assert hasattr(inputs, 'padding')
     assert isinstance(self.cell, rnn_cell.BaseRnnCell)
@@ -444,7 +444,7 @@ def __call__(
       inputs.reset_mask = jnp.ones_like(inputs.padding, dtype=self.fprop_dtype)
 
     if self.reverse:
-      inputs = jax.tree_map(lambda x: jnp.flip(x, axis=[1]), inputs)
+      inputs = jax.tree.map(lambda x: jnp.flip(x, axis=[1]), inputs)
 
     if not state0:
       batch_size = inputs.padding.shape[0]
@@ -466,7 +466,7 @@ def body_fn(sub, state0, inputs):
     if self.is_initializing():
       # inputs has shape [b, t, dim] or [b, t, 1]
       # sliced_inputs has shape [b, dim] or [b, 1].
-      sliced_inputs = jax.tree_map(lambda x: x[:, 1], inputs)
+      sliced_inputs = jax.tree.map(lambda x: x[:, 1], inputs)
       # `body_fn` is sufficient to trigger PARAMS initialization.
       _ = body_fn(self.cell, state0, sliced_inputs)
 
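A note on the `sliced_inputs` trick in the two `body_fn` hunks above: during initialization, a single call to the per-step function on one time slice is enough to create the cell's parameters, so each [b, t, ...] leaf is first sliced down to the per-step shape the cell expects. A standalone shape sketch; the field names and dimensions are made up:

import jax
import jax.numpy as jnp

inputs = {'act': jnp.ones((2, 7, 16)), 'padding': jnp.zeros((2, 7, 1))}
sliced = jax.tree.map(lambda x: x[:, 1], inputs)  # take one time step per leaf

assert sliced['act'].shape == (2, 16)     # [b, dim]
assert sliced['padding'].shape == (2, 1)  # [b, 1]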