Conversation


@codeflash-ai codeflash-ai bot commented Dec 17, 2025

📄 11% (0.11x) speedup for `reduce_shape` in `keras/src/ops/operation_utils.py`

⏱️ Runtime: 475 microseconds → 429 microseconds (best of 171 runs)
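
(As a sanity check on those numbers: 475 μs / 429 μs ≈ 1.107, i.e. roughly 10.7% faster; the headline 11% / 0.11x appears to be that same ratio rounded, with 0.11x expressing the speedup as a fraction.)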

📝 Explanation and details

The optimization achieves a 10% speedup through several targeted improvements:

**Key Optimizations:**

1. **Direct `__index__()` call in `canonicalize_axis`**: Replaced `operator.index(axis)` with a direct `axis.__index__()` call wrapped in try-except. This eliminates function call overhead while maintaining the same integer validation behavior.

2. **Tuple multiplication for keepdims=True**: Changed `tuple([1 for _ in shape])` to `(1,) * n`, avoiding list comprehension and intermediate list creation when generating a tuple of ones.

3. **Pre-computed length**: Stored `len(shape)` as `n` to avoid repeated length calculations during axis canonicalization.

4. **Variable renaming for clarity**: Renamed `axis` to `axis_tuple` to better distinguish the canonicalized tuple from the input parameter. The first three changes are sketched together below.
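
Putting items 1-3 together, a minimal sketch of the optimized helpers might look like the following. This is a reconstruction from the description above, not the actual diff; the real signatures and error messages in `keras/src/ops/operation_utils.py` may differ:

```python
# Minimal sketch reconstructed from the optimization notes above;
# not the exact code shipped in keras/src/ops/operation_utils.py.

def canonicalize_axis(axis, num_dims):
    # Optimization 1: call __index__() directly instead of operator.index(),
    # keeping the same integer validation with less call overhead.
    try:
        axis = axis.__index__()
    except (AttributeError, TypeError):
        raise TypeError(f"Axis must be an integer, got {axis!r}")
    if axis < -num_dims or axis >= num_dims:
        raise ValueError(f"Axis {axis} is out of bounds for rank {num_dims}.")
    return axis + num_dims if axis < 0 else axis


def reduce_shape(shape, axis=None, keepdims=False):
    n = len(shape)  # Optimization 3: compute the length once.
    if axis is None:
        # Optimization 2: (1,) * n avoids building an intermediate list.
        return (1,) * n if keepdims else ()
    if isinstance(axis, int):
        axis = (axis,)
    # Optimization 4: axis_tuple names the canonicalized tuple explicitly.
    axis_tuple = tuple(canonicalize_axis(a, n) for a in axis)
    if keepdims:
        out = list(shape)
        for a in axis_tuple:
            out[a] = 1
        return tuple(out)
    reduced = set(axis_tuple)
    return tuple(d for i, d in enumerate(shape) if i not in reduced)
```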

**Performance Impact by Test Case:**

- **Large shape reductions with keepdims=True** show the biggest gains (up to 313% faster) due to the tuple multiplication optimization (isolated in the micro-benchmark below)
- **Basic keepdims=True cases** benefit significantly (27-38% faster) from the same optimization
- **Most other cases** show modest 1-10% improvements from reduced function call overhead
- **Some single-axis cases** are slightly slower (~2%) due to additional variable assignment overhead
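
The largest of those gains is easy to reproduce in isolation. The snippet below times only the tuple-of-ones construction that optimization 2 replaces; absolute numbers will vary by machine, and this is not the PR's benchmark harness:

```python
# Isolates optimization 2: building the keepdims=True tuple of ones.
import timeit

shape = tuple(range(1, 1001))  # same 1000-d shape as the large-scale tests
n = len(shape)

old = timeit.timeit(lambda: tuple([1 for _ in shape]), number=10_000)
new = timeit.timeit(lambda: (1,) * n, number=10_000)
print(f"list comprehension: {old:.4f}s   tuple multiply: {new:.4f}s")
```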

**Hot Path Relevance:**
Based on the function reference showing `reduce_shape` being called in `linalg.py` for computing tensor output specifications, this optimization is valuable for neural network operations involving linear algebra computations where shape reductions are frequent. The improvements are most pronounced for operations that maintain dimensions (keepdims=True), which are common in batch processing and broadcasting scenarios.
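
For context, a shape-inference call of this kind looks like the following (hypothetical shapes, not taken from `linalg.py` itself); keepdims=True preserves rank, so the reduced result still broadcasts against the input:

```python
from keras.src.ops.operation_utils import reduce_shape

batch = (32, 128, 128, 3)
print(reduce_shape(batch, axis=(1, 2), keepdims=True))   # (32, 1, 1, 3)
print(reduce_shape(batch, axis=(1, 2), keepdims=False))  # (32, 3)
```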

The optimizations are particularly effective for large tensors and operations that reduce all dimensions while preserving shape structure.

**Correctness verification report:**

| Test | Status |
|------|--------|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 70 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |

🌀 Generated Regression Tests and Runtime

```python
import operator

# imports
import pytest
from keras.src.ops.operation_utils import reduce_shape

# unit tests

# ------------------------
# 1. Basic Test Cases
# ------------------------

def test_reduce_shape_axis_none_keepdims_false():
    # Reducing all axes, shape should become ()
    codeflash_output = reduce_shape((2, 3, 4), axis=None, keepdims=False) # 1.54μs -> 1.35μs (13.9% faster)

def test_reduce_shape_axis_none_keepdims_true():
    # Reducing all axes with keepdims, shape should become (1, 1, 1)
    codeflash_output = reduce_shape((2, 3, 4), axis=None, keepdims=True) # 1.89μs -> 1.37μs (37.9% faster)

def test_reduce_shape_single_axis_no_keepdims():
    # Reducing axis 1, should remove the second dimension
    codeflash_output = reduce_shape((2, 3, 4), axis=1, keepdims=False) # 4.68μs -> 4.78μs (1.95% slower)

def test_reduce_shape_single_axis_keepdims():
    # Reducing axis 0 with keepdims, should set first dimension to 1
    codeflash_output = reduce_shape((2, 3, 4), axis=0, keepdims=True) # 3.35μs -> 3.23μs (3.75% faster)

def test_reduce_shape_multiple_axes_no_keepdims():
    # Reducing axes 0 and 2, should remove first and last dimensions
    codeflash_output = reduce_shape((2, 3, 4), axis=(0, 2), keepdims=False) # 5.01μs -> 4.83μs (3.81% faster)

def test_reduce_shape_multiple_axes_keepdims():
    # Reducing axes 0 and 2 with keepdims, should set first and last dimensions to 1
    codeflash_output = reduce_shape((2, 3, 4), axis=(0, 2), keepdims=True) # 3.70μs -> 3.58μs (3.27% faster)

def test_reduce_shape_negative_axis_no_keepdims():
    # Negative axis should be canonicalized (axis -1 is last dimension)
    codeflash_output = reduce_shape((2, 3, 4), axis=-1, keepdims=False) # 4.01μs -> 4.09μs (1.91% slower)

def test_reduce_shape_negative_axis_keepdims():
    # Negative axis with keepdims
    codeflash_output = reduce_shape((2, 3, 4), axis=-2, keepdims=True) # 3.32μs -> 3.21μs (3.49% faster)

def test_reduce_shape_tuple_axis():
    # Axis as tuple, mixed positive and negative
    codeflash_output = reduce_shape((2, 3, 4, 5), axis=(1, -1), keepdims=False) # 4.75μs -> 4.83μs (1.55% slower)

def test_reduce_shape_axis_zero_dim():
    # Shape with a zero dimension, should still work
    codeflash_output = reduce_shape((0, 3), axis=0, keepdims=False) # 3.80μs -> 3.87μs (1.84% slower)

# ------------------------
# 2. Edge Test Cases
# ------------------------

def test_reduce_shape_empty_shape_axis_none():
    # Reducing all axes of a scalar (shape=()), should return ()
    codeflash_output = reduce_shape((), axis=None, keepdims=False) # 1.12μs -> 1.02μs (9.97% faster)
    codeflash_output = reduce_shape((), axis=None, keepdims=True) # 965ns -> 747ns (29.2% faster)

def test_reduce_shape_empty_shape_axis_raises():
    # Reducing axis 0 on a scalar should raise
    with pytest.raises(ValueError):
        reduce_shape((), axis=0, keepdims=False) # 3.62μs -> 3.63μs (0.220% slower)

def test_reduce_shape_axis_out_of_bounds_positive():
    # Axis too large should raise
    with pytest.raises(ValueError):
        reduce_shape((2, 3), axis=2, keepdims=False) # 3.78μs -> 3.64μs (3.71% faster)

def test_reduce_shape_axis_out_of_bounds_negative():
    # Axis too negative should raise
    with pytest.raises(ValueError):
        reduce_shape((2, 3), axis=-3, keepdims=False) # 3.71μs -> 3.62μs (2.66% faster)

def test_reduce_shape_duplicate_axes_keepdims():
    # Duplicate axes with keepdims, should set both to 1
    codeflash_output = reduce_shape((2, 3, 4), axis=(1, 1), keepdims=True) # 3.88μs -> 3.80μs (2.05% faster)

def test_reduce_shape_duplicate_axes_no_keepdims():
    # Duplicate axes without keepdims, should remove axis only once
    codeflash_output = reduce_shape((2, 3, 4), axis=(1, 1), keepdims=False) # 4.68μs -> 4.50μs (4.05% faster)

def test_reduce_shape_all_axes():
    # Reduce all axes explicitly
    codeflash_output = reduce_shape((2, 3, 4), axis=(0, 1, 2), keepdims=False) # 4.91μs -> 4.79μs (2.55% faster)
    codeflash_output = reduce_shape((2, 3, 4), axis=(0, 1, 2), keepdims=True) # 2.10μs -> 2.11μs (0.284% slower)

def test_reduce_shape_axis_unsorted_tuple():
    # Unsorted axes tuple, should still work
    codeflash_output = reduce_shape((2, 3, 4, 5), axis=(3, 1), keepdims=False) # 4.18μs -> 4.11μs (1.75% faster)

def test_reduce_shape_axis_as_list():
    # Axis as a list, should work if converted to tuple
    codeflash_output = reduce_shape((2, 3, 4), axis=[0, 2], keepdims=False) # 4.41μs -> 4.27μs (3.18% faster)

def test_reduce_shape_axis_as_numpy_integer():
    # Axis as numpy integer type (simulate with int subclass)
    class MyInt(int): pass
    codeflash_output = reduce_shape((2, 3, 4), axis=MyInt(1), keepdims=False) # 4.11μs -> 3.92μs (4.64% faster)

def test_reduce_shape_axis_is_empty_tuple():
    # Axis is empty tuple: should return original shape or empty depending on keepdims
    codeflash_output = reduce_shape((2, 3, 4), axis=(), keepdims=False) # 2.99μs -> 2.74μs (8.82% faster)
    codeflash_output = reduce_shape((2, 3, 4), axis=(), keepdims=True) # 1.04μs -> 1.11μs (6.23% slower)

def test_reduce_shape_shape_with_one_dim():
    # Shape with one dimension, reduce axis 0
    codeflash_output = reduce_shape((10,), axis=0, keepdims=False) # 3.88μs -> 3.69μs (5.21% faster)
    codeflash_output = reduce_shape((10,), axis=0, keepdims=True) # 1.54μs -> 1.64μs (5.99% slower)

def test_reduce_shape_axis_is_zero():
    # Axis is 0, should remove first dimension
    codeflash_output = reduce_shape((2, 3, 4), axis=0, keepdims=False) # 3.63μs -> 3.43μs (5.81% faster)

# ------------------------
# 3. Large Scale Test Cases
# ------------------------

def test_reduce_shape_large_shape_all_axes_keepdims():
    # Large shape, reduce all axes with keepdims
    shape = tuple(range(1, 1001))  # (1, 2, ..., 1000)
    codeflash_output = reduce_shape(shape, axis=None, keepdims=True); result = codeflash_output # 17.2μs -> 5.09μs (237% faster)

def test_reduce_shape_large_shape_all_axes_no_keepdims():
    # Large shape, reduce all axes without keepdims
    shape = tuple(range(1, 1001))
    codeflash_output = reduce_shape(shape, axis=None, keepdims=False); result = codeflash_output # 3.15μs -> 3.22μs (2.05% slower)

def test_reduce_shape_large_shape_some_axes_keepdims():
    # Reduce every other axis with keepdims
    shape = tuple(range(1, 1001))
    axes = tuple(range(0, 1000, 2))  # even axes
    codeflash_output = reduce_shape(shape, axis=axes, keepdims=True); result = codeflash_output # 82.8μs -> 75.3μs (9.98% faster)
    # Even indices set to 1, odd to original
    expected = tuple(1 if i % 2 == 0 else shape[i] for i in range(1000))

def test_reduce_shape_large_shape_some_axes_no_keepdims():
    # Reduce every 10th axis without keepdims
    shape = tuple(range(1, 1001))
    axes = tuple(range(0, 1000, 10))
    codeflash_output = reduce_shape(shape, axis=axes, keepdims=False); result = codeflash_output # 27.7μs -> 26.3μs (5.11% faster)
    # Should remove axes at positions 0, 10, 20, ..., 990
    expected = [shape[i] for i in range(1000) if i % 10 != 0]

def test_reduce_shape_large_shape_negative_axes():
    # Reduce last 10 axes using negative indices
    shape = tuple(range(1, 1001))
    axes = tuple(range(-10, 0))
    codeflash_output = reduce_shape(shape, axis=axes, keepdims=False); result = codeflash_output # 9.92μs -> 9.72μs (2.05% faster)
    expected = shape[:990]

def test_reduce_shape_large_shape_duplicate_axes():
    # Duplicate axes in a large shape
    shape = tuple(range(1, 1001))
    axes = (0, 0, 1, 1, 2, 2)
    codeflash_output = reduce_shape(shape, axis=axes, keepdims=False); result = codeflash_output # 9.39μs -> 9.34μs (0.643% faster)
    # Only first three axes removed
    expected = shape[3:]

def test_reduce_shape_large_shape_axis_as_list():
    # Axis as a list in a large shape
    shape = tuple(range(1, 1001))
    axes = list(range(10))
    codeflash_output = reduce_shape(shape, axis=axes, keepdims=False); result = codeflash_output # 9.97μs -> 9.81μs (1.68% faster)
    expected = shape[10:]

def test_reduce_shape_large_shape_axis_unsorted():
    # Unsorted axes in a large shape
    shape = tuple(range(1, 1001))
    axes = (999, 0, 500)
    codeflash_output = reduce_shape(shape, axis=axes, keepdims=False); result = codeflash_output # 8.36μs -> 8.25μs (1.24% faster)
    # Remove 999, 500, 0 in that order
    expected = list(shape)
    for ax in sorted(axes, reverse=True):
        del expected[ax]
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import operator

# imports
import pytest
from keras.src.ops.operation_utils import reduce_shape

# unit tests

# ------------------------
# BASIC TEST CASES
# ------------------------

def test_reduce_shape_none_axis_keepdims_false():
    # Reduce all axes, keepdims=False should return ()
    codeflash_output = reduce_shape((2, 3, 4), axis=None, keepdims=False) # 1.12μs -> 1.07μs (4.58% faster)

def test_reduce_shape_none_axis_keepdims_true():
    # Reduce all axes, keepdims=True should return tuple of 1s
    codeflash_output = reduce_shape((2, 3, 4), axis=None, keepdims=True) # 1.82μs -> 1.44μs (27.0% faster)

def test_reduce_shape_single_axis_keepdims_false():
    # Reduce axis 1, keepdims=False should remove axis 1
    codeflash_output = reduce_shape((2, 3, 4), axis=1, keepdims=False) # 4.22μs -> 4.15μs (1.69% faster)

def test_reduce_shape_single_axis_keepdims_true():
    # Reduce axis 1, keepdims=True should set axis 1 to 1
    codeflash_output = reduce_shape((2, 3, 4), axis=1, keepdims=True) # 3.36μs -> 3.23μs (4.28% faster)

def test_reduce_shape_multiple_axes_keepdims_false():
    # Reduce axes 0 and 2, keepdims=False should remove those axes
    codeflash_output = reduce_shape((2, 3, 4), axis=(0, 2), keepdims=False) # 4.83μs -> 4.64μs (4.10% faster)

def test_reduce_shape_multiple_axes_keepdims_true():
    # Reduce axes 0 and 2, keepdims=True should set those axes to 1
    codeflash_output = reduce_shape((2, 3, 4), axis=(0, 2), keepdims=True) # 3.68μs -> 3.40μs (7.99% faster)

def test_reduce_shape_axis_negative():
    # Negative axis should be handled correctly
    codeflash_output = reduce_shape((2, 3, 4), axis=-1, keepdims=False) # 4.07μs -> 3.81μs (6.83% faster)
    codeflash_output = reduce_shape((2, 3, 4), axis=-1, keepdims=True) # 1.64μs -> 1.64μs (0.426% slower)

def test_reduce_shape_axis_tuple_with_negatives():
    # Tuple with negative axes
    codeflash_output = reduce_shape((2, 3, 4, 5), axis=(-1, 0), keepdims=False) # 4.57μs -> 4.51μs (1.35% faster)
    codeflash_output = reduce_shape((2, 3, 4, 5), axis=(-1, 0), keepdims=True) # 1.83μs -> 1.95μs (6.16% slower)

# ------------------------
# EDGE TEST CASES
# ------------------------

def test_reduce_shape_empty_shape():
    # Shape is empty tuple, should always return empty tuple or tuple of 1s
    codeflash_output = reduce_shape((), axis=None, keepdims=False) # 989ns -> 840ns (17.7% faster)
    codeflash_output = reduce_shape((), axis=None, keepdims=True) # 1.00μs -> 814ns (23.0% faster)
    # Reducing any axis on empty shape should raise
    with pytest.raises(ValueError):
        reduce_shape((), axis=0, keepdims=False) # 3.22μs -> 3.13μs (2.78% faster)

def test_reduce_shape_axis_out_of_bounds():
    # Axis out of bounds should raise ValueError
    with pytest.raises(ValueError):
        reduce_shape((2, 3, 4), axis=3, keepdims=False) # 3.69μs -> 3.53μs (4.56% faster)
    with pytest.raises(ValueError):
        reduce_shape((2, 3, 4), axis=-4, keepdims=False) # 1.97μs -> 1.95μs (1.13% faster)
    with pytest.raises(ValueError):
        reduce_shape((2, 3, 4), axis=(0, 4), keepdims=True) # 2.30μs -> 2.19μs (5.08% faster)
    with pytest.raises(ValueError):
        reduce_shape((2, 3, 4), axis=(0, -5), keepdims=True) # 1.60μs -> 1.45μs (10.5% faster)

def test_reduce_shape_axis_duplicates():
    # Duplicated axes in tuple should only remove/set once per occurrence
    codeflash_output = reduce_shape((2, 3, 4), axis=(1, 1), keepdims=False) # 4.71μs -> 4.58μs (2.84% faster)
    codeflash_output = reduce_shape((2, 3, 4), axis=(1, 1), keepdims=True) # 1.75μs -> 1.90μs (7.98% slower)

def test_reduce_shape_axis_unsorted():
    # Unsorted axes tuple should work
    codeflash_output = reduce_shape((2, 3, 4, 5), axis=(3, 1), keepdims=False) # 4.27μs -> 4.28μs (0.140% slower)
    codeflash_output = reduce_shape((2, 3, 4, 5), axis=(3, 1), keepdims=True) # 1.75μs -> 1.79μs (2.23% slower)

def test_reduce_shape_axis_type_error():
    # Axis of invalid type should raise TypeError
    with pytest.raises(TypeError):
        reduce_shape((2, 3, 4), axis="not_an_int", keepdims=False)
    with pytest.raises(TypeError):
        reduce_shape((2, 3, 4), axis=[1, 2], keepdims=False)  # list, not tuple

def test_reduce_shape_axis_empty_tuple():
    # Axis is empty tuple: should return original shape if keepdims=False, or all dims set to 1 if keepdims=True
    codeflash_output = reduce_shape((2, 3, 4), axis=(), keepdims=False) # 3.59μs -> 3.54μs (1.35% faster)
    codeflash_output = reduce_shape((2, 3, 4), axis=(), keepdims=True) # 1.05μs -> 1.15μs (8.35% slower)

def test_reduce_shape_axis_all_axes():
    # Axis is all axes: should reduce all, result is () or all 1s
    codeflash_output = reduce_shape((2, 3, 4), axis=(0, 1, 2), keepdims=False) # 5.53μs -> 5.37μs (3.00% faster)
    codeflash_output = reduce_shape((2, 3, 4), axis=(0, 1, 2), keepdims=True) # 2.10μs -> 2.16μs (2.92% slower)
    # Negative axes version
    codeflash_output = reduce_shape((2, 3, 4), axis=(-1, -2, -3), keepdims=False) # 2.03μs -> 2.11μs (3.70% slower)
    codeflash_output = reduce_shape((2, 3, 4), axis=(-1, -2, -3), keepdims=True) # 1.49μs -> 1.43μs (4.05% faster)

def test_reduce_shape_large_shape_reduce_all():
    # Reduce all axes in a large shape
    shape = tuple(range(1, 1001))  # shape of length 1000
    codeflash_output = reduce_shape(shape, axis=None, keepdims=False) # 3.69μs -> 3.69μs (0.162% slower)
    codeflash_output = reduce_shape(shape, axis=None, keepdims=True) # 17.0μs -> 4.12μs (313% faster)

def test_reduce_shape_large_shape_reduce_some():
    # Reduce every other axis, keepdims=False
    shape = tuple(range(1, 1001))
    axes = tuple(range(0, 1000, 2))  # even axes
    expected_length = 1000 - len(axes)
    codeflash_output = reduce_shape(shape, axis=axes, keepdims=False); result = codeflash_output # 96.2μs -> 89.3μs (7.72% faster)

def test_reduce_shape_large_shape_keepdims_true():
    # Reduce every 100th axis, keepdims=True
    shape = tuple(range(1, 1001))
    axes = tuple(range(0, 1000, 100))
    codeflash_output = reduce_shape(shape, axis=axes, keepdims=True); result = codeflash_output # 8.49μs -> 8.39μs (1.20% faster)
    for ax in axes:
        pass

def test_reduce_shape_large_shape_axis_negative():
    # Reduce last axis (negative index) in large shape
    shape = tuple(range(1, 1001))
    codeflash_output = reduce_shape(shape, axis=-1, keepdims=False); result = codeflash_output # 7.78μs -> 7.82μs (0.397% slower)

def test_reduce_shape_large_shape_axis_tuple_with_negatives():
    # Reduce first and last axis (0, -1)
    shape = tuple(range(1, 1001))
    codeflash_output = reduce_shape(shape, axis=(0, -1), keepdims=False); result = codeflash_output # 8.55μs -> 8.58μs (0.396% slower)

```

To edit these changes, run `git checkout codeflash/optimize-reduce_shape-mjagz36d` and push.

codeflash-ai bot requested a review from mashraf-222 on December 17, 2025 at 20:34
codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels on Dec 17, 2025