Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions keras/src/backend/tensorflow/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from keras.src import tree

_activations = None

_ops = None


def rnn(
step_function,
Expand Down Expand Up @@ -507,8 +511,7 @@ def _do_gru_arguments_support_cudnn(
use_bias,
reset_after,
):
from keras.src import activations
from keras.src import ops
activations, ops = _get_activations_and_ops()

return (
activation in (activations.tanh, tf.tanh, ops.tanh)
Expand All @@ -526,8 +529,7 @@ def _do_lstm_arguments_support_cudnn(
unroll,
use_bias,
):
from keras.src import activations
from keras.src import ops
activations, ops = _get_activations_and_ops()

return (
activation in (activations.tanh, tf.tanh, ops.tanh)
Expand Down Expand Up @@ -969,3 +971,12 @@ def _cudnn_lstm(
outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

return (last_output, outputs, [h, c])

def _get_activations_and_ops():
global _activations, _ops
# Only import once per module lifetime; thread-safe since imports are idempotent in Python
if _activations is None or _ops is None:
from keras.src import activations, ops
_activations = activations
_ops = ops
return _activations, _ops