diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py index 33a2cc3c5160..e7b73a887a5c 100644 --- a/keras/src/backend/common/variables.py +++ b/keras/src/backend/common/variables.py @@ -573,15 +573,21 @@ def initialize_all_variables(): def standardize_dtype(dtype): if dtype is None: return config.floatx() + + orig_dtype = dtype dtype = dtypes.PYTHON_DTYPES_MAP.get(dtype, dtype) if hasattr(dtype, "name"): dtype = dtype.name elif hasattr(dtype, "__name__"): dtype = dtype.__name__ - elif hasattr(dtype, "__str__") and ( - "torch" in str(dtype) or "jax.numpy" in str(dtype) - ): - dtype = str(dtype).split(".")[-1] + else: + # Only call str(dtype) once if needed + dtype_str = None + if hasattr(dtype, "__str__"): + dtype_str = str(dtype) + # Only check and parse if the str contains what we expect + if "torch" in dtype_str or "jax.numpy" in dtype_str: + dtype = dtype_str.split(".")[-1] if dtype not in dtypes.ALLOWED_DTYPES: raise ValueError(f"Invalid dtype: {dtype}") diff --git a/keras/src/dtype_policies/__init__.py b/keras/src/dtype_policies/__init__.py index 6bf0eb45bbb7..161a25edd92b 100644 --- a/keras/src/dtype_policies/__init__.py +++ b/keras/src/dtype_policies/__init__.py @@ -1,7 +1,7 @@ from keras.src import backend from keras.src.api_export import keras_export from keras.src.dtype_policies import dtype_policy -from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES +from keras.src.dtype_policies.dtype_policy import _get_quantized_dtype_policy_by_str, QUANTIZATION_MODES from keras.src.dtype_policies.dtype_policy import DTypePolicy from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy @@ -85,9 +85,6 @@ def get(identifier): Returns: A Keras `DTypePolicy` instance. """ - from keras.src.dtype_policies.dtype_policy import ( - _get_quantized_dtype_policy_by_str, - ) if identifier is None: return dtype_policy.dtype_policy() @@ -102,7 +99,7 @@ def get(identifier): return DTypePolicy(identifier) try: return DTypePolicy(backend.standardize_dtype(identifier)) - except: + except Exception: raise ValueError( "Cannot interpret `dtype` argument. Expected a string " f"or an instance of DTypePolicy. Received: dtype={identifier}"