Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Dec 17, 2025

📄 372% (3.72x) speedup for compute_expand_dims_output_shape in keras/src/ops/operation_utils.py

⏱️ Runtime : 4.48 milliseconds 947 microseconds (best of 79 runs)

📝 Explanation and details

The key optimization is converting the axis list to a set before the list comprehension that builds the new shape. This changes the axis membership test from O(n) to O(1) complexity.

What changed:

  • Added axis_set = set(axis) after canonicalizing axes
  • Changed ax in axis to ax in axis_set in the list comprehension

Why it's faster:
In the original code, for each position in range(out_ndim), Python searches through the entire axis list to check membership. With many axes or large output dimensions, this becomes expensive. Converting to a set provides constant-time lookups instead of linear searches.

Performance impact:
The optimization shows dramatic speedups for cases with many axes:

  • Large axis lists: 1102-1116% faster (test cases with 1000 axes)
  • Multiple axes: 54% faster for 10 axes insertion
  • Single axis cases show minimal change (±1-5%), as expected since the overhead is negligible

Function usage context:
Based on the function references, compute_expand_dims_output_shape is called from:

  • TensorFlow backend's expand_dims operation (hot path for tensor operations)
  • NumPy ops for output shape computation during model building
  • Test suites for validation

The optimization particularly benefits tensor operations that expand multiple dimensions simultaneously, which can occur frequently in neural network architectures that need to broadcast or reshape tensors across multiple axes.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 68 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import operator

# imports
import pytest  # used for our unit tests
from keras.src.ops.operation_utils import compute_expand_dims_output_shape

# unit tests

# --- Basic Test Cases ---

def test_expand_single_axis_start():
    # Insert dimension at the start
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), 0) # 6.38μs -> 6.31μs (0.998% faster)

def test_expand_single_axis_middle():
    # Insert dimension in the middle
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), 1) # 4.92μs -> 5.00μs (1.56% slower)

def test_expand_single_axis_end():
    # Insert dimension at the end
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), 3) # 4.72μs -> 4.76μs (0.860% slower)

def test_expand_single_axis_negative():
    # Insert dimension using negative axis
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), -1) # 4.73μs -> 4.65μs (1.83% faster)
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), -2) # 1.91μs -> 2.02μs (5.39% slower)
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), -4) # 1.43μs -> 1.75μs (18.3% slower)

def test_expand_multiple_axes_tuple():
    # Insert multiple dimensions using a tuple
    codeflash_output = compute_expand_dims_output_shape((2, 3), (0, 2)) # 5.01μs -> 4.95μs (1.23% faster)

def test_expand_multiple_axes_list():
    # Insert multiple dimensions using a list
    codeflash_output = compute_expand_dims_output_shape((2, 3), [1, 3]) # 5.13μs -> 4.94μs (3.87% faster)

def test_expand_axis_none():
    # axis=None should append a dimension at the end
    codeflash_output = compute_expand_dims_output_shape((2, 3), None) # 4.43μs -> 4.38μs (1.10% faster)

def test_expand_scalar_shape():
    # Expanding a scalar (empty shape)
    codeflash_output = compute_expand_dims_output_shape((), 0) # 4.22μs -> 4.29μs (1.82% slower)
    codeflash_output = compute_expand_dims_output_shape((), None) # 1.94μs -> 2.05μs (5.23% slower)

def test_expand_axis_as_int():
    # axis as int
    codeflash_output = compute_expand_dims_output_shape((5,), 0) # 4.25μs -> 4.26μs (0.047% slower)

def test_expand_axis_as_tuple_of_one():
    # axis as tuple of one element
    codeflash_output = compute_expand_dims_output_shape((5,), (0,)) # 4.40μs -> 4.49μs (1.89% slower)

def test_expand_axis_as_list_of_one():
    # axis as list of one element
    codeflash_output = compute_expand_dims_output_shape((5,), [1]) # 4.44μs -> 4.56μs (2.63% slower)

# --- Edge Test Cases ---

def test_expand_axis_out_of_bounds_positive():
    # Axis out of bounds (too high)
    with pytest.raises(ValueError):
        compute_expand_dims_output_shape((2, 3), 3) # 3.86μs -> 3.63μs (6.33% faster)

def test_expand_axis_out_of_bounds_negative():
    # Axis out of bounds (too negative)
    with pytest.raises(ValueError):
        compute_expand_dims_output_shape((2, 3), -4) # 3.92μs -> 3.65μs (7.31% faster)

def test_expand_axis_type_error():
    # Axis is not int/tuple/list
    with pytest.raises(ValueError):
        compute_expand_dims_output_shape((2, 3), "foo") # 2.05μs -> 1.90μs (7.95% faster)

def test_expand_axis_empty_tuple():
    # Axis as empty tuple, should return input shape
    codeflash_output = compute_expand_dims_output_shape((2, 3), ()) # 4.66μs -> 4.92μs (5.23% slower)

def test_expand_axis_empty_list():
    # Axis as empty list, should return input shape
    codeflash_output = compute_expand_dims_output_shape((2, 3), []) # 3.87μs -> 3.94μs (1.65% slower)

def test_expand_axis_all_positions():
    # Insert at all possible positions for 2D shape
    for i in range(3):
        shape = (2, 3)
        out = list(shape)
        out.insert(i, 1)
        codeflash_output = compute_expand_dims_output_shape(shape, i) # 8.41μs -> 8.68μs (3.05% slower)

def test_expand_axis_reverse_order():
    # Axes in reverse order
    codeflash_output = compute_expand_dims_output_shape((2, 3), (2, 0)) # 4.91μs -> 4.81μs (2.00% faster)

def test_expand_axis_unsorted():
    # Unsorted axes
    codeflash_output = compute_expand_dims_output_shape((2, 3), (1, 0)) # 4.98μs -> 4.72μs (5.42% faster)

def test_expand_axis_negative_and_positive():
    # Mix of negative and positive axes
    codeflash_output = compute_expand_dims_output_shape((2, 3), (0, -1)) # 4.92μs -> 4.94μs (0.425% slower)

def test_expand_axis_large_negative():
    # Large negative axis (should fail)
    with pytest.raises(ValueError):
        compute_expand_dims_output_shape((2, 3), (-5,)) # 3.94μs -> 3.73μs (5.76% faster)

def test_expand_axis_large_positive():
    # Large positive axis (should fail)
    with pytest.raises(ValueError):
        compute_expand_dims_output_shape((2, 3), (4,)) # 3.99μs -> 3.71μs (7.58% faster)

def test_expand_axis_with_non_integer_in_list():
    # List with a non-integer element
    with pytest.raises(TypeError):
        compute_expand_dims_output_shape((2, 3), [0, "a"]) # 4.29μs -> 4.11μs (4.53% faster)

def test_expand_axis_with_bool():
    # Axis as boolean should be treated as int (True==1, False==0)
    codeflash_output = compute_expand_dims_output_shape((2, 3), True) # 5.26μs -> 5.27μs (0.133% slower)
    codeflash_output = compute_expand_dims_output_shape((2, 3), False) # 2.13μs -> 2.56μs (17.0% slower)

# --- Large Scale Test Cases ---

def test_expand_large_shape_single_axis():
    # Large input shape, single axis
    shape = tuple(range(1, 1001))
    out = (1,) + shape
    codeflash_output = compute_expand_dims_output_shape(shape, 0) # 47.8μs -> 49.6μs (3.50% slower)

def test_expand_large_shape_end():
    # Large input shape, insert at the end
    shape = tuple(range(1, 1001))
    out = shape + (1,)
    codeflash_output = compute_expand_dims_output_shape(shape, 1000) # 47.8μs -> 48.7μs (1.77% slower)

def test_expand_large_shape_middle():
    # Large input shape, insert in the middle
    shape = tuple(range(1, 1001))
    out = shape[:500] + (1,) + shape[500:]
    codeflash_output = compute_expand_dims_output_shape(shape, 500) # 48.2μs -> 48.8μs (1.26% slower)

def test_expand_large_shape_multiple_axes():
    # Insert 10 singleton dimensions at various positions
    shape = tuple(range(1, 991))
    axes = list(range(0, 20, 2))  # [0,2,4,6,8,10,12,14,16,18]
    out = []
    idx = 0
    shape_iter = iter(shape)
    for i in range(len(shape) + len(axes)):
        if i in axes:
            out.append(1)
        else:
            out.append(next(shape_iter))
    codeflash_output = compute_expand_dims_output_shape(shape, axes) # 77.0μs -> 49.9μs (54.5% faster)

def test_expand_large_shape_all_axes():
    # Insert a singleton dimension at every possible position
    shape = ()
    axes = list(range(1000))
    out = tuple(1 for _ in range(1000))
    codeflash_output = compute_expand_dims_output_shape(shape, axes) # 1.90ms -> 155μs (1116% faster)
import operator

# imports
import pytest  # used for our unit tests
from keras.src.ops.operation_utils import compute_expand_dims_output_shape

# unit tests

# ------------------------- BASIC TEST CASES -------------------------

def test_basic_single_axis_start():
    # Insert axis at the start
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), 0) # 6.28μs -> 6.15μs (2.23% faster)

def test_basic_single_axis_middle():
    # Insert axis in the middle
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), 1) # 4.94μs -> 4.95μs (0.283% slower)

def test_basic_single_axis_end():
    # Insert axis at the end
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), 3) # 4.63μs -> 4.71μs (1.70% slower)

def test_basic_negative_axis():
    # Insert axis using negative index
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), -1) # 4.68μs -> 4.63μs (1.10% faster)
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), -2) # 2.07μs -> 2.06μs (0.535% faster)
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), -4) # 1.43μs -> 1.73μs (17.4% slower)

def test_basic_axis_as_tuple():
    # Insert axis using tuple
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), (1,)) # 4.63μs -> 4.67μs (0.857% slower)

def test_basic_axis_as_list():
    # Insert axis using list
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), [1]) # 4.71μs -> 4.51μs (4.45% faster)

def test_basic_multiple_axes_sorted():
    # Insert multiple axes in sorted order
    codeflash_output = compute_expand_dims_output_shape((2, 3), (0, 2)) # 5.10μs -> 4.87μs (4.81% faster)

def test_basic_multiple_axes_unsorted():
    # Insert multiple axes in unsorted order
    codeflash_output = compute_expand_dims_output_shape((2, 3), (2, 0)) # 5.13μs -> 4.88μs (5.12% faster)
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), (2, 0)) # 2.61μs -> 2.61μs (0.000% faster)

def test_basic_axis_is_none():
    # axis=None should insert at the end
    codeflash_output = compute_expand_dims_output_shape((2, 3, 4), None) # 4.45μs -> 4.43μs (0.383% faster)

def test_basic_empty_input_shape():
    # Expanding dims on a scalar
    codeflash_output = compute_expand_dims_output_shape((), 0) # 4.30μs -> 4.23μs (1.65% faster)
    codeflash_output = compute_expand_dims_output_shape((), None) # 2.03μs -> 2.02μs (0.695% faster)

def test_basic_axis_as_int():
    # Accepts integer axis
    codeflash_output = compute_expand_dims_output_shape((5, 6), 1) # 4.35μs -> 4.21μs (3.35% faster)

# ------------------------- EDGE TEST CASES -------------------------

def test_edge_axis_out_of_bounds_positive():
    # Axis too large
    with pytest.raises(ValueError):
        compute_expand_dims_output_shape((2, 3), 4) # 3.81μs -> 3.57μs (6.78% faster)

def test_edge_axis_out_of_bounds_negative():
    # Axis too negative
    with pytest.raises(ValueError):
        compute_expand_dims_output_shape((2, 3), -4) # 3.76μs -> 3.48μs (8.04% faster)

def test_edge_axis_type_invalid():
    # Axis type not int/tuple/list
    with pytest.raises(ValueError):
        compute_expand_dims_output_shape((2, 3), "1") # 1.98μs -> 1.79μs (10.3% faster)

def test_edge_axis_tuple_with_invalid_type():
    # Axis tuple contains invalid type
    with pytest.raises(TypeError):
        compute_expand_dims_output_shape((2, 3), (1, "a")) # 4.52μs -> 4.17μs (8.40% faster)

def test_edge_axis_list_with_invalid_type():
    # Axis list contains invalid type
    with pytest.raises(TypeError):
        compute_expand_dims_output_shape((2, 3), [1, None]) # 4.32μs -> 3.95μs (9.34% faster)

def test_edge_all_axes():
    # Insert at all possible positions (for 2D input, axes 0,1,2)
    codeflash_output = compute_expand_dims_output_shape((2, 3), (0, 1, 2)) # 7.20μs -> 7.07μs (1.84% faster)

def test_edge_empty_axis_tuple():
    # No axis supplied as empty tuple/list: should return input shape
    codeflash_output = compute_expand_dims_output_shape((2, 3), ()) # 3.79μs -> 3.81μs (0.629% slower)
    codeflash_output = compute_expand_dims_output_shape((2, 3), []) # 1.89μs -> 1.90μs (0.580% slower)

def test_edge_input_shape_with_zero_dim():
    # Input shape contains zero
    codeflash_output = compute_expand_dims_output_shape((0, 3), 1) # 4.77μs -> 4.71μs (1.15% faster)

def test_edge_input_shape_with_one_dim():
    # Input shape contains one
    codeflash_output = compute_expand_dims_output_shape((1, 3), 1) # 4.64μs -> 4.44μs (4.57% faster)

def test_edge_axis_as_float():
    # Passing float as axis should raise error
    with pytest.raises(ValueError):
        compute_expand_dims_output_shape((2, 3), 1.0) # 3.77μs -> 3.68μs (2.64% faster)

# ------------------------- LARGE SCALE TEST CASES -------------------------

def test_large_input_shape_single_axis():
    # Large input shape, single axis insert
    input_shape = tuple(range(1000))
    codeflash_output = compute_expand_dims_output_shape(input_shape, 500); out = codeflash_output # 50.1μs -> 50.3μs (0.320% slower)

def test_large_input_shape_multiple_axes():
    # Large input shape, multiple axes
    input_shape = tuple(range(1000))
    axes = (0, 500, 1000)
    codeflash_output = compute_expand_dims_output_shape(input_shape, axes); out = codeflash_output # 54.8μs -> 51.3μs (6.87% faster)

def test_large_axis_list():
    # Large axis list, all at start
    input_shape = (2, 3)
    axes = tuple(range(1000))
    codeflash_output = compute_expand_dims_output_shape(input_shape, axes); out = codeflash_output # 1.92ms -> 159μs (1102% faster)

def test_large_input_shape_and_axis_end():
    # Insert axis at the end of a large shape
    input_shape = tuple(range(1000))
    codeflash_output = compute_expand_dims_output_shape(input_shape, 1000); out = codeflash_output # 50.0μs -> 51.0μs (1.98% slower)

def test_large_input_shape_and_axis_start():
    # Insert axis at the start of a large shape
    input_shape = tuple(range(1000))
    codeflash_output = compute_expand_dims_output_shape(input_shape, 0); out = codeflash_output # 47.7μs -> 48.7μs (2.16% slower)

# ------------------------- ADDITIONAL FUNCTIONALITY TESTS -------------------------

def test_axis_tuple_with_negative_and_positive():
    # Mix of negative and positive axes
    input_shape = (2, 3, 4)
    # out_ndim = 3+2=5, so -1 -> 4, -4 -> 1
    codeflash_output = compute_expand_dims_output_shape(input_shape, (1, -1, -4)); out = codeflash_output # 6.20μs -> 6.04μs (2.51% faster)

def test_axis_tuple_with_all_negatives():
    # All negative axes
    input_shape = (2, 3, 4)
    codeflash_output = compute_expand_dims_output_shape(input_shape, (-1, -2, -3)); out = codeflash_output # 5.68μs -> 5.49μs (3.61% faster)

To edit these changes git checkout codeflash/optimize-compute_expand_dims_output_shape-mjafsocr and push.

Codeflash Static Badge

The key optimization is converting the `axis` list to a `set` before the list comprehension that builds the new shape. This changes the axis membership test from O(n) to O(1) complexity.

**What changed:**
- Added `axis_set = set(axis)` after canonicalizing axes
- Changed `ax in axis` to `ax in axis_set` in the list comprehension

**Why it's faster:**
In the original code, for each position in `range(out_ndim)`, Python searches through the entire `axis` list to check membership. With many axes or large output dimensions, this becomes expensive. Converting to a set provides constant-time lookups instead of linear searches.

**Performance impact:**
The optimization shows dramatic speedups for cases with many axes:
- Large axis lists: **1102-1116% faster** (test cases with 1000 axes)
- Multiple axes: **54% faster** for 10 axes insertion
- Single axis cases show minimal change (±1-5%), as expected since the overhead is negligible

**Function usage context:**
Based on the function references, `compute_expand_dims_output_shape` is called from:
- TensorFlow backend's `expand_dims` operation (hot path for tensor operations)
- NumPy ops for output shape computation during model building
- Test suites for validation

The optimization particularly benefits tensor operations that expand multiple dimensions simultaneously, which can occur frequently in neural network architectures that need to broadcast or reshape tensors across multiple axes.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 December 17, 2025 20:01
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Dec 17, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant