
Commit 3f4cbb4

Jake VanderPlas authored and pax authors committed
Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634095268
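For reference, the substitution is mechanical. A minimal before/after sketch (the pytree below is illustrative, not taken from this commit):

  import jax
  import jax.numpy as jnp

  params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}

  # Deprecated top-level alias, slated for removal:
  #   doubled = jax.tree_map(lambda x: 2 * x, params)
  # Long-form replacement in jax.tree_util:
  doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
  # Short alias used throughout this commit (requires jax >= 0.4.25):
  doubled = jax.tree.map(lambda x: 2 * x, params)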
1 parent d903d68 commit 3f4cbb4

15 files changed: 191 additions & 141 deletions

praxis/layers/adapters_test.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def test_residual_adapter_tf_equivalent(self):
     theta = initial_vars['params'].copy()
     theta['layer_norm'] = theta['norm']
     del theta['norm']
-    theta = jax.tree_map(np.array, theta)
+    theta = jax.tree.map(np.array, theta)
     theta = py_utils.NestedMap.FromNestedDict(theta)
     theta.down_w = tf.convert_to_tensor(theta.down_w)
     theta.up_w = tf.convert_to_tensor(theta.up_w)

praxis/layers/attentions.py

Lines changed: 2 additions & 2 deletions
@@ -2247,12 +2247,12 @@ def _vmap_on_broadcast_prefixes(

     # Wraps fn with slicing on args_to_slice and broadcast_args_to_slice.
     def _sliced_fn(layer, args, args_to_slice, broadcast_args_to_slice, states):
-      sliced = jax.tree_map(
+      sliced = jax.tree.map(
           lambda x, d: self._slice_decode_chunk(x, chunk_id, d),
           args_to_slice,
           args_time_dims,
       )
-      broadcast_sliced = jax.tree_map(
+      broadcast_sliced = jax.tree.map(
           lambda x, d: self._slice_decode_chunk(x, chunk_id, d),
           broadcast_args_to_slice,
           broadcast_args_time_dims,
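Note that both calls above map over two pytrees at once: `jax.tree.map`, like `jax.tree_util.tree_map`, accepts extra trees whose structure matches the first and passes corresponding leaves to the function positionally. A toy illustration with made-up values:

  import jax

  args = {'q': 1, 'k': 3}
  dims = {'q': 0, 'k': 1}
  # The lambda receives one leaf from each tree: (1, 0) for 'q', (3, 1) for 'k'.
  out = jax.tree.map(lambda x, d: x + d, args, dims)
  assert out == {'q': 1, 'k': 4}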

praxis/layers/convolutions_test.py

Lines changed: 2 additions & 2 deletions
@@ -141,7 +141,7 @@ def test_causal_conv2d_layer(self):

     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = conv_layer.init(prng_key, inputs)
-    initial_vars = jax.tree_map(jnp.ones_like, initial_vars)
+    initial_vars = jax.tree.map(jnp.ones_like, initial_vars)

     # Test odd length sequence.
     output = conv_layer.apply(initial_vars, inputs)
@@ -156,7 +156,7 @@ def test_causal_conv2d_layer(self):

     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = conv_layer.init(prng_key, inputs)
-    initial_vars = jax.tree_map(jnp.ones_like, initial_vars)
+    initial_vars = jax.tree.map(jnp.ones_like, initial_vars)

     output = conv_layer.apply(initial_vars, inputs)
     np_output = np.array(

praxis/layers/flax_adapter_test.py

Lines changed: 3 additions & 3 deletions
@@ -149,14 +149,14 @@ def test_mix_layer(self):
     def assert_learnable(x):
       assert not x.collections

-    jax.tree_map(assert_learnable, init_var_meta['params'])
+    jax.tree.map(assert_learnable, init_var_meta['params'])

     def assert_non_learnable(x):
       assert WeightHParamsCollection.NON_TRAINABLE in x.collections
       assert WeightHParamsCollection.REQUIRES_MEAN_SYNC in x.collections

-    jax.tree_map(assert_non_learnable, init_var_meta['batch_stats'])
-    jax.tree_map(assert_non_learnable, init_var_meta['non_trainable'])
+    jax.tree.map(assert_non_learnable, init_var_meta['batch_stats'])
+    jax.tree.map(assert_non_learnable, init_var_meta['non_trainable'])
     init_vars = test_layer.init(prng_key, input_x)
     _ = test_layer.apply(init_vars, input_x, mutable=True)
     _ = test_layer.apply(

praxis/layers/frnn.py

Lines changed: 10 additions & 10 deletions
@@ -78,7 +78,7 @@ def reset_mask(


 def _sum_aux_loss(tree):
-  return jax.tree_map(jnp.sum, tree)
+  return jax.tree.map(jnp.sum, tree)


 class FRnn(base_layer.BaseLayer):
@@ -143,7 +143,7 @@ def __call__(
       state: Final state.
     """
     # Make a copy of the input structure to avoid side-effect.
-    inputs = jax.tree_map(lambda x: x, inputs)
+    inputs = jax.tree.map(lambda x: x, inputs)
     assert hasattr(inputs, 'act')
     assert hasattr(inputs, 'padding')
     assert isinstance(self.cell, rnn_cell.BaseRnnCell)
@@ -159,7 +159,7 @@ def __call__(
       inputs.reset_mask = jnp.ones_like(inputs.padding, dtype=self.fprop_dtype)

     if self.reverse:
-      inputs = jax.tree_map(lambda x: jnp.flip(x, axis=[1]), inputs)
+      inputs = jax.tree.map(lambda x: jnp.flip(x, axis=[1]), inputs)

     if not state0:
       batch_size = inputs.padding.shape[0]
@@ -176,7 +176,7 @@ def body_fn(sub, state0, inputs):
     if self.is_initializing():
       # inputs has shape [b, t, dim] or [b, t, 1]
       # sliced_inputs has shape [b, dim] or [b, 1].
-      sliced_inputs = jax.tree_map(lambda x: x[:, 1], inputs)
+      sliced_inputs = jax.tree.map(lambda x: x[:, 1], inputs)
       _ = body_fn(self.cell, state0, sliced_inputs)

     # NON_TRAINABLE variables are carried over from one iteration to another.
@@ -248,7 +248,7 @@ def init_states(self, batch_size: int) -> list[NestedMap]:
   def extend_step(
       self, inputs: NestedMap, state: list[NestedMap]
   ) -> tuple[list[NestedMap], JTensor]:
-    inputs = jax.tree_map(lambda x: x, inputs)
+    inputs = jax.tree.map(lambda x: x, inputs)
     new_states = []
     for i in range(self.num_layers):
       new_state, act_i = self.frnn[i].extend_step(inputs, state[i])
@@ -275,7 +275,7 @@ def __call__(
       act: A tensor of [batch, time, dims]. The output.
       state: Final state.
     """
-    inputs = jax.tree_map(lambda x: x, inputs)
+    inputs = jax.tree.map(lambda x: x, inputs)

     if not state0:
       batch_size = inputs.padding.shape[0]
@@ -375,7 +375,7 @@ def __call__(
      state: Final state - a list of NestedMap of fwd and bwd states.
     """
     # This is to create a copy.
-    inputs = jax.tree_map(lambda x: x, inputs)
+    inputs = jax.tree.map(lambda x: x, inputs)

     if not state0:
       batch_size = inputs.padding.shape[0]
@@ -428,7 +428,7 @@ def __call__(
       state: Final state.
     """
     # Make a copy of the input structure to avoid side-effect.
-    inputs = jax.tree_map(lambda x: x, inputs)
+    inputs = jax.tree.map(lambda x: x, inputs)
     assert hasattr(inputs, 'act')
     assert hasattr(inputs, 'padding')
     assert isinstance(self.cell, rnn_cell.BaseRnnCell)
@@ -444,7 +444,7 @@ def __call__(
      inputs.reset_mask = jnp.ones_like(inputs.padding, dtype=self.fprop_dtype)

     if self.reverse:
-      inputs = jax.tree_map(lambda x: jnp.flip(x, axis=[1]), inputs)
+      inputs = jax.tree.map(lambda x: jnp.flip(x, axis=[1]), inputs)

     if not state0:
       batch_size = inputs.padding.shape[0]
@@ -466,7 +466,7 @@ def body_fn(sub, state0, inputs):
     if self.is_initializing():
       # inputs has shape [b, t, dim] or [b, t, 1]
       # sliced_inputs has shape [b, dim] or [b, 1].
-      sliced_inputs = jax.tree_map(lambda x: x[:, 1], inputs)
+      sliced_inputs = jax.tree.map(lambda x: x[:, 1], inputs)
       # `body_fn` is sufficient to trigger PARAMS initialization.
       _ = body_fn(self.cell, state0, sliced_inputs)
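A side note on the recurring `jax.tree.map(lambda x: x, inputs)` idiom above: mapping the identity function rebuilds the pytree containers, so the caller's NestedMap is insulated from structural mutation (fields added or deleted later), while the leaf arrays are shared, which is harmless because JAX arrays are immutable. A toy sketch, with a plain dict standing in for NestedMap:

  import jax

  src = {'act': [1, 2], 'padding': [0, 0]}
  dst = jax.tree.map(lambda x: x, src)  # fresh containers, same leaves
  dst['reset_mask'] = [1, 1]            # structural change to the copy only
  assert 'reset_mask' not in src        # caller's tree is untouched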

praxis/layers/frnn_test.py

Lines changed: 2 additions & 2 deletions
@@ -206,7 +206,7 @@ def test_frnn_lstm_cell(self, jax_cell_class, output_nonlinearity):

     rnn_theta = {'params': theta['params']['cell']}
     ys = []
-    cell_state = jax.tree_map(lambda x: x, state0)
+    cell_state = jax.tree.map(lambda x: x, state0)
     for t in range(act_in.shape[1]):
       with base_layer.JaxContext.new_context():
         inputs_t = NestedMap(act=act_in[:, t], padding=padding[:, t])
@@ -414,7 +414,7 @@ def test_frnn_reset_cell_state(

     rnn_theta = {'params': theta['params']['cell']}
     ys = []
-    cell_state = jax.tree_map(lambda x: x, state0)
+    cell_state = jax.tree.map(lambda x: x, state0)
     for t in range(act_in.shape[1]):
       with base_layer.JaxContext.new_context():
         inputs_t = NestedMap(

praxis/layers/multi_query_attention.py

Lines changed: 9 additions & 5 deletions
@@ -1319,12 +1319,16 @@ def _vmap_on_broadcast_prefixes(self, fn: attentions.FnOnDecodeStateChunk,

     # Wraps fn with slicing on args_to_slice and broadcast_args_to_slice.
     def _sliced_fn(layer, args, args_to_slice, broadcast_args_to_slice, states):
-      sliced = jax.tree_map(
-          lambda x, d: self._slice_decode_chunk(x, chunk_id, d), args_to_slice,
-          args_time_dims)
-      broadcast_sliced = jax.tree_map(
+      sliced = jax.tree.map(
           lambda x, d: self._slice_decode_chunk(x, chunk_id, d),
-          broadcast_args_to_slice, broadcast_args_time_dims)
+          args_to_slice,
+          args_time_dims,
+      )
+      broadcast_sliced = jax.tree.map(
+          lambda x, d: self._slice_decode_chunk(x, chunk_id, d),
+          broadcast_args_to_slice,
+          broadcast_args_time_dims,
+      )
       return fn(layer, args, sliced, broadcast_sliced, states)

     broadcast_dim_sizes = self.get_decode_state(
