codeflash-ai bot commented on Dec 17, 2025

📄 67% (0.67x) speedup for `cudnn_ok` in `keras/src/backend/tensorflow/rnn.py`

⏱️ Runtime: 827 microseconds → 494 microseconds (best of 5 runs)

📝 Explanation and details

The optimization replaces repeated local imports with module-level caching, delivering a **67% speedup** by eliminating expensive import overhead in hot-path functions.

**Key optimization**: The original code imports `keras.src.activations` and `keras.src.ops` on every call to `_do_gru_arguments_support_cudnn` and `_do_lstm_arguments_support_cudnn`. The optimized version introduces a cached import mechanism using module-level globals, so these modules are imported only once per process lifetime.
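For reference, here is a minimal sketch of the cached-import pattern being described. The helper and global names (`_lazy_keras_modules`, `_activations`, `_ops`) are illustrative assumptions, not the exact identifiers in the diff:

```python
# Illustrative sketch only -- helper and global names are assumptions, not the
# exact code in this PR. The keras submodules are resolved once per process and
# reused by the hot-path checks instead of being re-imported on every call.
import tensorflow as tf

_activations = None
_ops = None


def _lazy_keras_modules():
    """Import keras.src.activations / keras.src.ops once and cache them."""
    global _activations, _ops
    if _activations is None:  # only the first call pays the import cost
        from keras.src import activations, ops

        _activations, _ops = activations, ops
    return _activations, _ops


def _do_gru_arguments_support_cudnn(
    activation, recurrent_activation, unroll, use_bias, reset_after
):
    activations, ops = _lazy_keras_modules()
    return (
        activation in (activations.tanh, tf.tanh, ops.tanh)
        and recurrent_activation in (activations.sigmoid, tf.sigmoid, ops.sigmoid)
        and not unroll
        and use_bias
        and reset_after
    )
```

After the first call, `_lazy_keras_modules` reduces to a couple of global lookups, which is consistent with the ~3x faster per-hit times reported by the profiler.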

**Performance impact**: Line profiler data shows the import statements consumed **75% of execution time** in the original functions. The optimized version reduces this to around **50%** by eliminating redundant import lookups, with the cached approach showing ~3x faster per-hit times for the import operations.

**Why this matters in practice**: The function references show `cudnn_ok` is called during GRU layer initialization, specifically when determining cuDNN compatibility. Since RNN layers are frequently instantiated during model construction, and potentially during training loops, this import overhead accumulates significantly. The test results demonstrate consistent 47-73% speedups across the various parameter combinations, with particularly strong gains (61-67%) in the large-scale parametric tests that simulate real-world usage patterns.

**Thread safety**: The optimization is safe because Python's import system is thread-safe and the global caching pattern using lazy initialization (`if _activations is None`) is a standard Python idiom for module-level optimization.

This optimization is especially beneficial for workloads that create multiple RNN layers or repeatedly check cuDNN compatibility, transforming what was previously an import-bound operation into a fast cached lookup.
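As a rough, standalone illustration of the difference between the two patterns (not part of the PR), the following micro-benchmark contrasts a per-call `import` statement with a lazily cached module reference. `json` stands in for the keras submodules purely so the snippet runs without Keras installed, and absolute numbers will vary by machine:

```python
# Standalone micro-benchmark sketch (assumption: "json" is a stand-in for
# keras.src.activations / keras.src.ops purely for illustration).
import timeit


def repeated_import():
    import json  # re-executes the import statement (sys.modules lookup + binding) on every call
    return json


_cached = None


def cached_import():
    global _cached
    if _cached is None:  # lazy initialization: the import runs only once
        import json
        _cached = json
    return _cached


if __name__ == "__main__":
    print("per-call import :", timeit.timeit(repeated_import, number=200_000))
    print("cached reference:", timeit.timeit(cached_import, number=200_000))
```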

Correctness verification report:

| Test                          | Status        |
|-------------------------------|---------------|
| ⚙️ Existing Unit Tests         | 🔘 None Found |
| 🌀 Generated Regression Tests | 330 Passed    |
| ⏪ Replay Tests                | 🔘 None Found |
| 🔎 Concolic Coverage Tests    | 🔘 None Found |
| 📊 Tests Coverage             | 100.0%        |

🌀 Generated Regression Tests and Runtime
```python
import pytest
from keras.src.backend.tensorflow.rnn import cudnn_ok

# --- Minimal stubs for activations and ops to allow unit tests to run ---
class DummyFn:
    """A dummy function object to simulate activation function identity."""
    def __call__(self, x):
        return x

# Simulate keras.src.activations and keras.src.ops modules
class activations:
    tanh = DummyFn()
    sigmoid = DummyFn()

class ops:
    tanh = DummyFn()
    sigmoid = DummyFn()

# Simulate tensorflow functions
class tf:
    @staticmethod
    def tanh(x): return x
    @staticmethod
    def sigmoid(x): return x
    class config:
        @staticmethod
        def list_logical_devices(device_type):
            # This will be monkeypatched in tests to simulate GPU/CPU
            return []

@pytest.fixture
def gpu_available(monkeypatch):
    # Simulate GPU present
    monkeypatch.setattr(tf.config, "list_logical_devices", lambda device_type: ["GPU:0"])
    yield

# --- Basic Test Cases ---

def test_lstm_cudnn_ok_true(gpu_available):
    # All arguments are correct for LSTM and GPU is available
    codeflash_output = cudnn_ok(
        activation=activations.tanh,
        recurrent_activation=activations.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=None,
    ) # 7.31μs -> 6.45μs (13.3% faster)

def test_gru_cudnn_ok_true(gpu_available):
    # All arguments are correct for GRU and GPU is available
    codeflash_output = cudnn_ok(
        activation=tf.tanh,
        recurrent_activation=tf.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=True,
    ) # 5.26μs -> 3.48μs (51.0% faster)

def test_lstm_cudnn_ok_false_cpu():
    # All arguments correct, but no GPU available
    codeflash_output = cudnn_ok(
        activation=activations.tanh,
        recurrent_activation=activations.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=None,
    ) # 4.12μs -> 2.78μs (48.3% faster)

def test_gru_cudnn_ok_false_cpu():
    # All arguments correct, but no GPU available
    codeflash_output = cudnn_ok(
        activation=ops.tanh,
        recurrent_activation=ops.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=True,
    ) # 4.40μs -> 2.62μs (67.9% faster)

def test_lstm_wrong_activation(gpu_available):
    # Wrong activation function for LSTM
    def relu(x): return x
    codeflash_output = cudnn_ok(
        activation=relu,  # Not tanh
        recurrent_activation=activations.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=None,
    ) # 4.42μs -> 2.63μs (67.9% faster)

def test_gru_wrong_recurrent_activation(gpu_available):
    # Wrong recurrent activation for GRU
    def relu(x): return x
    codeflash_output = cudnn_ok(
        activation=activations.tanh,
        recurrent_activation=relu,  # Not sigmoid
        unroll=False,
        use_bias=True,
        reset_after=True,
    ) # 4.22μs -> 2.77μs (52.4% faster)

def test_lstm_reset_after_ignored(gpu_available):
    # reset_after is ignored for LSTM (should still work if None)
    codeflash_output = cudnn_ok(
        activation=activations.tanh,
        recurrent_activation=activations.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=None,
    ) # 4.18μs -> 2.51μs (66.1% faster)

def test_lstm_extra_args_ignored(gpu_available):
    # Extra arguments (should be ignored)
    codeflash_output = cudnn_ok(
        activation=tf.tanh,
        recurrent_activation=tf.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=None,
    ) # 3.87μs -> 2.63μs (46.7% faster)

def test_gru_activation_identity(gpu_available):
    # Activation is a different object with same behavior (should fail)
    class TanhLike:
        def __call__(self, x): return x
    codeflash_output = cudnn_ok(
        activation=TanhLike(),
        recurrent_activation=activations.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=True,
    ) # 4.39μs -> 2.53μs (73.8% faster)

def test_gru_recurrent_activation_identity(gpu_available):
    # recurrent_activation is a different object with same behavior (should fail)
    class SigmoidLike:
        def __call__(self, x): return x
    codeflash_output = cudnn_ok(
        activation=activations.tanh,
        recurrent_activation=SigmoidLike(),
        unroll=False,
        use_bias=True,
        reset_after=True,
    ) # 4.30μs -> 2.65μs (61.8% faster)

def test_lstm_activation_is_ops_tanh(gpu_available):
    # Accept ops.tanh as valid activation
    codeflash_output = cudnn_ok(
        activation=ops.tanh,
        recurrent_activation=ops.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=None,
    ) # 4.28μs -> 2.71μs (57.9% faster)

def test_gru_activation_is_ops_tanh(gpu_available):
    # Accept ops.tanh as valid activation for GRU
    codeflash_output = cudnn_ok(
        activation=ops.tanh,
        recurrent_activation=ops.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=True,
    ) # 3.76μs -> 2.55μs (47.4% faster)

def test_lstm_activation_is_tf_tanh(gpu_available):
    # Accept tf.tanh as valid activation
    codeflash_output = cudnn_ok(
        activation=tf.tanh,
        recurrent_activation=tf.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=None,
    ) # 4.14μs -> 2.52μs (64.1% faster)

def test_gru_activation_is_tf_tanh(gpu_available):
    # Accept tf.tanh as valid activation for GRU
    codeflash_output = cudnn_ok(
        activation=tf.tanh,
        recurrent_activation=tf.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=True,
    ) # 4.29μs -> 2.50μs (71.4% faster)

# --- Large Scale Test Cases ---

@pytest.mark.parametrize("activation", [activations.tanh, tf.tanh, ops.tanh])
@pytest.mark.parametrize("recurrent_activation", [activations.sigmoid, tf.sigmoid, ops.sigmoid])
def test_lstm_all_valid_combinations_large(activation, recurrent_activation, gpu_available):
    # Test all valid activation/recurrent_activation combinations for LSTM
    codeflash_output = cudnn_ok(
        activation=activation,
        recurrent_activation=recurrent_activation,
        unroll=False,
        use_bias=True,
        reset_after=None,
    ) # 40.1μs -> 24.7μs (62.6% faster)

@pytest.mark.parametrize("activation", [activations.tanh, tf.tanh, ops.tanh])
@pytest.mark.parametrize("recurrent_activation", [activations.sigmoid, tf.sigmoid, ops.sigmoid])
def test_gru_all_valid_combinations_large(activation, recurrent_activation, gpu_available):
    # Test all valid activation/recurrent_activation combinations for GRU
    codeflash_output = cudnn_ok(
        activation=activation,
        recurrent_activation=recurrent_activation,
        unroll=False,
        use_bias=True,
        reset_after=True,
    ) # 38.7μs -> 24.0μs (61.6% faster)

@pytest.mark.parametrize("i", range(50))
def test_lstm_many_calls_scalability(i, gpu_available):
    # Scalability: call cudnn_ok many times with valid arguments
    codeflash_output = cudnn_ok(
        activation=activations.tanh,
        recurrent_activation=activations.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=None,
    ) # 211μs -> 130μs (62.0% faster)

@pytest.mark.parametrize("i", range(50))
def test_gru_many_calls_scalability(i, gpu_available):
    # Scalability: call cudnn_ok many times with valid arguments for GRU
    codeflash_output = cudnn_ok(
        activation=activations.tanh,
        recurrent_activation=activations.sigmoid,
        unroll=False,
        use_bias=True,
        reset_after=True,
    ) # 214μs -> 132μs (61.5% faster)

def test_lstm_all_invalid_combinations_large(gpu_available):
    # Try all combinations of invalid arguments for LSTM
    activations_list = [activations.tanh, tf.tanh, ops.tanh, lambda x: x]
    recurrent_activations_list = [activations.sigmoid, tf.sigmoid, ops.sigmoid, lambda x: x]
    for activation in activations_list:
        for recurrent_activation in recurrent_activations_list:
            # Only the first 3 activations and recurrent_activations are valid
            valid = (
                activation in (activations.tanh, tf.tanh, ops.tanh)
                and recurrent_activation in (activations.sigmoid, tf.sigmoid, ops.sigmoid)
            )
            for unroll in [False, True]:
                for use_bias in [True, False]:
                    codeflash_output = cudnn_ok(
                        activation=activation,
                        recurrent_activation=recurrent_activation,
                        unroll=unroll,
                        use_bias=use_bias,
                        reset_after=None,
                    ); result = codeflash_output
                    if valid and not unroll and use_bias:
                        pass
                    else:
                        pass

def test_gru_all_invalid_combinations_large(gpu_available):
    # Try all combinations of invalid arguments for GRU
    activations_list = [activations.tanh, tf.tanh, ops.tanh, lambda x: x]
    recurrent_activations_list = [activations.sigmoid, tf.sigmoid, ops.sigmoid, lambda x: x]
    for activation in activations_list:
        for recurrent_activation in recurrent_activations_list:
            # Only the first 3 activations and recurrent_activations are valid
            valid = (
                activation in (activations.tanh, tf.tanh, ops.tanh)
                and recurrent_activation in (activations.sigmoid, tf.sigmoid, ops.sigmoid)
            )
            for unroll in [False, True]:
                for use_bias in [True, False]:
                    for reset_after in [True, False]:
                        codeflash_output = cudnn_ok(
                            activation=activation,
                            recurrent_activation=recurrent_activation,
                            unroll=unroll,
                            use_bias=use_bias,
                            reset_after=reset_after,
                        ); result = codeflash_output
                        if valid and not unroll and use_bias and reset_after:
                            pass
                        else:
                            pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
```

To edit these changes, `git checkout codeflash/optimize-cudnn_ok-mjajtt4m` and push.


codeflash-ai bot requested a review from mashraf-222 on December 17, 2025 at 21:54
codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels on Dec 17, 2025