@codeflash-ai codeflash-ai bot commented Dec 17, 2025

📄 17% (0.17x) speedup for get in keras/src/dtype_policies/__init__.py

⏱️ Runtime: 9.81 milliseconds → 8.41 milliseconds (best of 23 runs)

📝 Explanation and details

The optimization achieves a roughly 17% speedup (9.81 ms to 8.41 ms) through three changes, the first two of which remove repeated overhead in hot code paths:

1. Import Hoisting in get() Function
The most significant optimization moves the import of _get_quantized_dtype_policy_by_str from inside the get() function to the module level. The profiler shows this import consumed 3.5% of total time (2.1M ns out of 60.1M ns), and it ran on every function call. Python caches imported modules in sys.modules, but a function-local import statement still executes on each call (a cache lookup plus a name binding), and that overhead adds up when get() is called frequently. Moving the import to module scope eliminates this per-call cost entirely.
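As a rough illustration of the per-call cost, the sketch below compares a function-local import with a module-level one. It uses math.sqrt as a stand-in for _get_quantized_dtype_policy_by_str, and the function names are hypothetical; absolute timings will vary by machine and Python version.

```python
import timeit
from math import sqrt as _module_level_sqrt  # resolved once, at import time


def with_local_import(x):
    # The import statement runs on every call: a sys.modules lookup
    # plus a local name binding, even though the module is cached.
    from math import sqrt
    return sqrt(x)


def with_module_import(x):
    # Only a global name lookup remains on the call path.
    return _module_level_sqrt(x)


if __name__ == "__main__":
    n = 1_000_000
    local = timeit.timeit(lambda: with_local_import(2.0), number=n)
    hoisted = timeit.timeit(lambda: with_module_import(2.0), number=n)
    print(f"local import:  {local:.3f}s")
    print(f"module import: {hoisted:.3f}s")
```

Function-local imports are sometimes kept deliberately to avoid circular imports. This report does not say whether that applied here, so the hoist is worth checking against the package's import order.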

2. Optimized String Handling in standardize_dtype()
The function now stores the result of str(dtype) in a local variable, and only on the branch that needs it, instead of converting the same object repeatedly. The original code called str(dtype) twice when checking for the "torch" or "jax.numpy" patterns. Reusing the string avoids the redundant conversions and the temporary strings they allocate.
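The pattern can be sketched as follows. This is not the Keras source: the helper names and the CountingDType class are hypothetical, written only to make the number of str() conversions observable.

```python
class CountingDType:
    """Hypothetical stand-in whose __str__ counts how often it is called."""

    def __init__(self, text):
        self.text = text
        self.str_calls = 0

    def __str__(self):
        self.str_calls += 1
        return self.text


def short_name_repeated(dtype):
    # str(dtype) is re-evaluated for each check and again for the split.
    if "torch" in str(dtype) or "jax.numpy" in str(dtype):
        return str(dtype).split(".")[-1]
    return dtype


def short_name_cached(dtype):
    # Convert once, then reuse the resulting string.
    text = str(dtype)
    if "torch" in text or "jax.numpy" in text:
        return text.split(".")[-1]
    return dtype


repeated = CountingDType("jax.numpy.float32")
cached = CountingDType("jax.numpy.float32")
print(short_name_repeated(repeated), repeated.str_calls)  # float32 3
print(short_name_cached(cached), cached.str_calls)        # float32 1
```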

3. Exception Handling Improvement
Changed the bare except: to except Exception:, which is more specific: it no longer swallows KeyboardInterrupt or SystemExit. This is a robustness fix and is not expected to change runtime measurably.
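The behavioral difference between the two forms can be seen in a small sketch. The helper names are hypothetical; the point is only that a bare except also catches BaseException subclasses such as KeyboardInterrupt, while except Exception lets them propagate.

```python
def call_with_bare_except(fn):
    try:
        fn()
    except:  # noqa: E722 - catches BaseException, including KeyboardInterrupt
        return "swallowed"
    return "ok"


def call_with_except_exception(fn):
    try:
        fn()
    except Exception:  # ordinary errors only
        return "swallowed"
    return "ok"


def raise_interrupt():
    raise KeyboardInterrupt


def raise_value_error():
    raise ValueError("bad dtype")


print(call_with_bare_except(raise_interrupt))         # swallowed
print(call_with_except_exception(raise_value_error))  # swallowed
try:
    call_with_except_exception(raise_interrupt)
except KeyboardInterrupt:
    print("KeyboardInterrupt propagated")
```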

Impact Analysis
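Based on the function references, dtype_policies.get() is called in DTypePolicyMap constructors and setters, which suggests it sits on hot paths during model initialization and dtype policy management. The generated tests show improvements in most scenarios; the exceptions are the single-call dict-config cases and one int8 quantized-string case, which land within about 3% of the original in either direction: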

  • Basic cases: 12-42% faster for simple string/None inputs
  • Error cases: 14-68% faster due to reduced overhead before exceptions
  • Large scale tests: 11-52% faster, demonstrating the cumulative benefit

The optimization is particularly effective for workloads with frequent dtype policy lookups, which are common in Keras model construction and mixed-precision training scenarios.
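The percentages quoted in this report appear to be computed as (old - new) / new. That formula is an inference from the numbers, not something the report states. A quick check against the headline runtimes and one per-test timing:

```python
def relative_speedup(old, new):
    """Speedup as a fraction of the new runtime (assumed definition)."""
    return (old - new) / new


headline = relative_speedup(9.81, 8.41)  # total runtime, in ms
per_test = relative_speedup(4.96, 3.50)  # get(None) case, in μs

print(f"{headline:.1%}")  # 16.6%
print(f"{per_test:.1%}")  # 41.7%
```

Under this definition the headline figure is 16.6%, which is consistent with both the rounded 17% in the title and the 16% quoted in the explanation.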

Correctness verification report:

Test | Status
⚙️ Existing Unit Tests | 🔘 None Found
🌀 Generated Regression Tests | 2030 Passed
⏪ Replay Tests | 🔘 None Found
🔎 Concolic Coverage Tests | 🔘 None Found
📊 Tests Coverage | 93.3%
🌀 Generated Regression Tests and Runtime
import pytest
from keras.src.dtype_policies.__init__ import get

# Simulate global_state
class GlobalState:
    _state = {}
    def get_global_attribute(self, name, default=None, set_to_default=False):
        if name not in self._state and default is not None:
            if set_to_default:
                self._state[name] = default
            return default
        return self._state.get(name, default)
    def set_global_attribute(self, name, value):
        self._state[name] = value

global_state = GlobalState()

# --- DTypePolicy and Quantized Policies ---

class DTypePolicy:
    def __init__(self, name):
        # Stub: store the name as given (standardize_dtype is not defined in this file).
        self.name = name
    def __eq__(self, other):
        return isinstance(other, DTypePolicy) and self.name == other.name
    def get_config(self):
        return {"name": self.name}

class QuantizedDTypePolicy(DTypePolicy):
    def __init__(self, mode, source_name):
        super().__init__(mode)
        self.source_name = source_name
    def __eq__(self, other):
        return (
            isinstance(other, QuantizedDTypePolicy)
            and self.name == other.name
            and self.source_name == other.source_name
        )
    def get_config(self):
        return {"mode": self.name, "source_name": self.source_name}

class QuantizedFloat8DTypePolicy(DTypePolicy):
    def __init__(self, mode, source_name):
        super().__init__(mode)
        self.source_name = source_name
    def __eq__(self, other):
        return (
            isinstance(other, QuantizedFloat8DTypePolicy)
            and self.name == other.name
            and self.source_name == other.source_name
        )
    def get_config(self):
        return {"mode": self.name, "source_name": self.source_name}

class GPTQDTypePolicy(DTypePolicy):
    def __init__(self, mode, source_name):
        super().__init__(mode)
        self.source_name = source_name
    def __eq__(self, other):
        return (
            isinstance(other, GPTQDTypePolicy)
            and self.name == other.name
            and self.source_name == other.source_name
        )
    def get_config(self):
        return {"mode": self.name, "source_name": self.source_name}

QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")

# --- End: Minimal stubs ---

# ------------------- UNIT TESTS -------------------

# Basic Test Cases

def test_get_none_returns_default_policy():
    # None should return the default dtype policy (float32)
    global_state._state.clear()
    codeflash_output = get(None); policy = codeflash_output # 4.96μs -> 3.50μs (41.7% faster)

def test_get_string_returns_dtype_policy():
    # Should create a DTypePolicy with the given name
    codeflash_output = get("float64"); policy = codeflash_output # 7.55μs -> 6.71μs (12.4% faster)

def test_get_dict_returns_deserialized_policy():
    # Should deserialize from dict config
    config = {"class_name": "DTypePolicy", "config": {"name": "float16"}}
    codeflash_output = get(config); policy = codeflash_output # 38.7μs -> 37.7μs (2.67% faster)

def test_get_quantized_string_int8():
    # Should parse quantized string for int8
    codeflash_output = get("int8_from_float32"); policy = codeflash_output # 9.96μs -> 9.14μs (9.07% faster)

def test_get_quantized_string_float8():
    # Should parse quantized string for float8
    codeflash_output = get("float8_from_bfloat16"); policy = codeflash_output # 10.7μs -> 9.47μs (12.5% faster)

def test_get_quantized_string_int4():
    # Should parse quantized string for int4
    codeflash_output = get("int4_from_float16"); policy = codeflash_output # 11.2μs -> 10.2μs (9.91% faster)

# Edge Test Cases

def test_get_invalid_string_raises():
    # Should raise for invalid dtype string
    with pytest.raises(ValueError):
        get("not_a_dtype") # 8.12μs -> 6.47μs (25.5% faster)

def test_get_invalid_quantized_string_raises():
    # Should raise for quantized string missing _from_
    with pytest.raises(ValueError):
        get("int8float32") # 4.57μs -> 3.10μs (47.1% faster)

def test_get_quantized_string_wrong_mode_raises():
    # Should raise for unsupported quantization mode
    with pytest.raises(ValueError):
        get("notquant_from_float32") # 7.67μs -> 6.21μs (23.4% faster)

def test_get_quantized_string_missing_source_raises():
    # Should raise for missing source
    with pytest.raises(ValueError):
        get("int8_from_") # 8.45μs -> 7.33μs (15.4% faster)

def test_get_unexpected_type_raises():
    # Should raise for completely unexpected type
    with pytest.raises(ValueError):
        get(12345) # 7.98μs -> 6.12μs (30.4% faster)
import pytest
from keras.src.dtype_policies.__init__ import get

# --- Minimal stub implementations for the required classes and constants ---
# These are necessary for the test environment to work, as the real Keras classes are unavailable.

class DTypePolicy:
    def __init__(self, name):
        self.name = name
    def __eq__(self, other):
        return isinstance(other, DTypePolicy) and self.name == other.name
    def get_config(self):
        return {"name": self.name}
    def __repr__(self):
        return f"DTypePolicy({self.name!r})"

class FloatDTypePolicy(DTypePolicy):
    pass

class QuantizedDTypePolicy(DTypePolicy):
    def __init__(self, mode, source_name):
        super().__init__(mode)
        self.source_name = source_name
    def __eq__(self, other):
        return (
            isinstance(other, QuantizedDTypePolicy)
            and self.name == other.name
            and self.source_name == other.source_name
        )
    def get_config(self):
        return {"mode": self.name, "source_name": self.source_name}
    def __repr__(self):
        return f"QuantizedDTypePolicy({self.name!r}, {self.source_name!r})"

class QuantizedFloat8DTypePolicy(DTypePolicy):
    def __init__(self, mode, source_name):
        super().__init__(mode)
        self.source_name = source_name
    def __eq__(self, other):
        return (
            isinstance(other, QuantizedFloat8DTypePolicy)
            and self.name == other.name
            and self.source_name == other.source_name
        )
    def get_config(self):
        return {"mode": self.name, "source_name": self.source_name}
    def __repr__(self):
        return f"QuantizedFloat8DTypePolicy({self.name!r}, {self.source_name!r})"

class GPTQDTypePolicy(DTypePolicy):
    def __init__(self, mode, source_name):
        super().__init__(mode)
        self.source_name = source_name
    def __eq__(self, other):
        return (
            isinstance(other, GPTQDTypePolicy)
            and self.name == other.name
            and self.source_name == other.source_name
        )
    def get_config(self):
        return {"mode": self.name, "source_name": self.source_name}
    def __repr__(self):
        return f"GPTQDTypePolicy({self.name!r}, {self.source_name!r})"

QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")

# --- Global state stub ---
class GlobalState:
    def __init__(self):
        self.attrs = {}
    def get_global_attribute(self, name, default=None, set_to_default=False):
        codeflash_output = self.attrs.get(name, None); val = codeflash_output
        if val is None and default is not None:
            val = default
            if set_to_default:
                self.attrs[name] = val
        return val
    def set_global_attribute(self, name, value):
        self.attrs[name] = value

global_state = GlobalState()

# --- Serialization stub ---
def deserialize_keras_object(config, module_objects=None, custom_objects=None, **kwargs):
    # Only supports the DTypePolicy classes for this test suite
    if config is None:
        return None
    if isinstance(config, dict):
        codeflash_output = config.get("class_name"); class_name = codeflash_output
        codeflash_output = config.get("config", {}); inner_config = codeflash_output
        if class_name == "DTypePolicy":
            return DTypePolicy(inner_config["name"])
        if class_name == "FloatDTypePolicy":
            return FloatDTypePolicy(inner_config["name"])
        if class_name == "QuantizedDTypePolicy":
            return QuantizedDTypePolicy(inner_config["mode"], inner_config["source_name"])
        if class_name == "QuantizedFloat8DTypePolicy":
            return QuantizedFloat8DTypePolicy(inner_config["mode"], inner_config["source_name"])
        if class_name == "GPTQDTypePolicy":
            return GPTQDTypePolicy(inner_config["mode"], inner_config["source_name"])
        raise ValueError(f"Unknown class_name: {class_name}")
    raise TypeError(f"Could not parse config: {config}")

# --- Unit tests ---

# 1. Basic Test Cases

def test_get_none_returns_global_policy():
    # Should return the global policy (float32 by default)
    global_state.attrs.clear()
    codeflash_output = get(None); result = codeflash_output # 5.17μs -> 3.64μs (42.2% faster)

def test_get_string_policy_name():
    # Should create a DTypePolicy for a valid string
    codeflash_output = get("float16"); result = codeflash_output # 8.52μs -> 6.62μs (28.8% faster)

def test_get_string_policy_name_float32():
    # Should create a DTypePolicy for another valid string
    codeflash_output = get("float32"); result = codeflash_output # 6.02μs -> 4.46μs (35.1% faster)

def test_get_dict_config_dtypepolicy():
    # Should deserialize a config dict for DTypePolicy
    config = {"class_name": "DTypePolicy", "config": {"name": "float64"}}
    codeflash_output = get(config); result = codeflash_output # 36.8μs -> 36.6μs (0.327% faster)

def test_get_dict_config_quantized_dtypepolicy():
    # Should deserialize a config dict for QuantizedDTypePolicy
    config = {"class_name": "QuantizedDTypePolicy", "config": {"mode": "int8", "source_name": "float32"}}
    codeflash_output = get(config); result = codeflash_output # 28.8μs -> 29.1μs (1.22% slower)

# 2. Edge Test Cases

def test_get_invalid_string_raises():
    # Should raise ValueError for invalid string
    with pytest.raises(ValueError):
        get("not_a_dtype") # 7.37μs -> 6.11μs (20.6% faster)

def test_get_invalid_type_raises():
    # Should raise ValueError for invalid type (int)
    with pytest.raises(ValueError):
        get(123) # 6.30μs -> 4.61μs (36.9% faster)

def test_get_quantization_mode_string_valid_int8():
    # Should parse quantized policy string for int8
    codeflash_output = get("int8_from_float32"); result = codeflash_output # 8.31μs -> 8.39μs (0.989% slower)

def test_get_quantization_mode_string_valid_float8():
    # Should parse quantized policy string for float8
    codeflash_output = get("float8_from_float32"); result = codeflash_output # 10.4μs -> 9.79μs (6.17% faster)

def test_get_quantization_mode_string_valid_int4():
    # Should parse quantized policy string for int4
    codeflash_output = get("int4_from_float32"); result = codeflash_output # 7.50μs -> 6.80μs (10.4% faster)

def test_get_quantization_mode_string_invalid_format():
    # Should raise ValueError for missing '_from_'
    with pytest.raises(ValueError):
        get("int8float32") # 5.83μs -> 3.48μs (67.6% faster)

def test_get_quantization_mode_string_invalid_mode():
    # Should raise ValueError for unsupported quantization mode
    with pytest.raises(ValueError):
        get("unknown_from_float32") # 8.95μs -> 7.87μs (13.7% faster)

def test_get_dict_missing_class_name_raises():
    # Should raise ValueError for dict missing class_name
    config = {"config": {"name": "float32"}}
    with pytest.raises(ValueError):
        get(config) # 15.6μs -> 13.6μs (14.7% faster)

def test_get_standardize_dtype_invalid_string():
    # Should raise ValueError for invalid dtype string
    with pytest.raises(ValueError):
        get("invalid_dtype") # 10.2μs -> 8.26μs (23.1% faster)

def test_get_standardize_dtype_none_returns_default():
    # Should return default floatx for None
    codeflash_output = get(None); result = codeflash_output # 4.21μs -> 2.96μs (42.4% faster)

# 3. Large Scale Test Cases

def test_get_many_string_policies():
    # Test scalability with many string policies
    names = ["float16", "float32", "float64", "bfloat16"]
    for name in names * 250:  # 1000 elements
        codeflash_output = get(name); result = codeflash_output # 1.62ms -> 1.07ms (52.0% faster)

def test_get_many_dict_configs():
    # Test scalability with many dict configs
    configs = [
        {"class_name": "DTypePolicy", "config": {"name": name}}
        for name in ["float16", "float32", "float64", "bfloat16"]
    ] * 250  # 1000 elements
    for config in configs:
        codeflash_output = get(config); result = codeflash_output # 7.90ms -> 7.09ms (11.4% faster)

def test_get_many_quantized_policies():
    # Test scalability with many quantized policies
    modes = ["int8", "float8", "int4", "gptq"]
    for mode in modes * 250:  # 1000 elements
        policy_str = f"{mode}_from_float32"
        codeflash_output = get(policy_str); result = codeflash_output
        if mode == "int8" or mode == "int4":
            pass
        elif mode == "float8":
            pass
        elif mode == "gptq":
            pass

To edit these changes, git checkout codeflash/optimize-get-mjaig9zh and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 December 17, 2025 21:15
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Dec 17, 2025