diff --git a/keras/src/backend/tensorflow/rnn.py b/keras/src/backend/tensorflow/rnn.py index 06d450a18838..e49598d020bd 100644 --- a/keras/src/backend/tensorflow/rnn.py +++ b/keras/src/backend/tensorflow/rnn.py @@ -2,6 +2,10 @@ from keras.src import tree +_activations = None + +_ops = None + def rnn( step_function, @@ -507,8 +511,7 @@ def _do_gru_arguments_support_cudnn( use_bias, reset_after, ): - from keras.src import activations - from keras.src import ops + activations, ops = _get_activations_and_ops() return ( activation in (activations.tanh, tf.tanh, ops.tanh) @@ -526,8 +529,7 @@ def _do_lstm_arguments_support_cudnn( unroll, use_bias, ): - from keras.src import activations - from keras.src import ops + activations, ops = _get_activations_and_ops() return ( activation in (activations.tanh, tf.tanh, ops.tanh) @@ -969,3 +971,12 @@ def _cudnn_lstm( outputs = tf.expand_dims(last_output, axis=0 if time_major else 1) return (last_output, outputs, [h, c]) + +def _get_activations_and_ops(): + global _activations, _ops + # Only import once per module lifetime; thread-safe since imports are idempotent in Python + if _activations is None or _ops is None: + from keras.src import activations, ops + _activations = activations + _ops = ops + return _activations, _ops