From 519fb53bfb519711fa256f257d953515eb82ded1 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 15 Oct 2025 10:09:43 +0000 Subject: [PATCH 1/2] Support for python 3.14 Fixes https://github.com/google/flax/issues/5027 --- .github/workflows/flax_test.yml | 28 ++- flax/linen/__init__.py | 2 + flax/linen/kw_only_dataclasses.py | 221 +----------------------- flax/linen/module.py | 45 ++--- pyproject.toml | 14 +- tests/linen/kw_only_dataclasses_test.py | 35 ++-- tests/linen/linen_module_test.py | 12 +- 7 files changed, 84 insertions(+), 273 deletions(-) diff --git a/.github/workflows/flax_test.yml b/.github/workflows/flax_test.yml index af6a285c6..c2eadb38e 100644 --- a/.github/workflows/flax_test.yml +++ b/.github/workflows/flax_test.yml @@ -63,7 +63,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.11', '3.12', '3.13'] + python-version: ['3.11', '3.12', '3.13', '3.14'] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} @@ -170,3 +170,29 @@ jobs: "description": "'$status'", "context": "github-actions/Build" }' + + # This is a temporary workflow to test flax on Python 3.14 and + # skipping deps like tensorstore, tensorflow etc + tests-python314: + name: Run Tests on Python 3.14 + needs: [pre-commit, commit-count] + runs-on: ubuntu-24.04-16core + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Setup uv + uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0 + with: + version: "0.9.2" + python-version: "3.14" + activate-environment: true + enable-cache: true + + - name: Install dependencies + run: | + rm -fr .venv + uv sync --extra testing --extra docs + - name: Test with pytest + run: | + export XLA_FLAGS='--xla_force_host_platform_device_count=4' + find tests/ -name "*.py" | grep -vE 'io_test|tensorboard' | xargs pytest -n auto + diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index c15bb8424..62153937c 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -166,3 +166,5 @@ while_loop as while_loop, ) # pylint: enable=g-multiple-import +# For BC +from flax.linen import kw_only_dataclasses as kw_only_dataclasses diff --git a/flax/linen/kw_only_dataclasses.py b/flax/linen/kw_only_dataclasses.py index 95542fafe..6d801da99 100644 --- a/flax/linen/kw_only_dataclasses.py +++ b/flax/linen/kw_only_dataclasses.py @@ -12,230 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Support for keyword-only fields in dataclasses for Python versions <3.10. +"""This module is kept for backward compatibility. -This module provides wrappers for `dataclasses.dataclass` and -`dataclasses.field` that simulate support for keyword-only fields for Python -versions before 3.10 (which is the version where dataclasses added keyword-only -field support). If this module is imported in Python 3.10+, then -`kw_only_dataclasses.dataclass` and `kw_only_dataclasses.field` will simply be -aliases for `dataclasses.dataclass` and `dataclasses.field`. - -For earlier Python versions, when constructing a dataclass, any fields that have -been marked as keyword-only (including inherited fields) will be moved to the -end of the constuctor's argument list. This makes it possible to have a base -class that defines a field with a default, and a subclass that defines a field -without a default. E.g.: - ->>> from flax.linen import kw_only_dataclasses ->>> @kw_only_dataclasses.dataclass -... class Parent: -... name: str = kw_only_dataclasses.field(default='', kw_only=True) - ->>> @kw_only_dataclasses.dataclass -... class Child(Parent): -... size: float # required. - ->>> import inspect ->>> print(inspect.signature(Child.__init__)) -(self, size: float, name: str = '') -> None - - -(If we used `dataclasses` rather than `kw_only_dataclasses` for the above -example, then it would have failed with TypeError "non-default argument -'size' follows default argument.") - -WARNING: fields marked as keyword-only will not *actually* be turned into -keyword-only parameters in the constructor; they will only be moved to the -end of the parameter list (after all non-keyword-only parameters). +Previous code targeting Python versions <3.10 is removed and wired to +built-in dataclasses module. """ import dataclasses -import functools -import inspect -from types import MappingProxyType from typing import Any, TypeVar -import typing_extensions as tpe - import flax M = TypeVar('M', bound='flax.linen.Module') FieldName = str Annotation = Any Default = Any - - -class _KwOnlyType: - """Metadata tag used to tag keyword-only fields.""" - - def __repr__(self): - return 'KW_ONLY' - - -KW_ONLY = _KwOnlyType() - - -def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs): - """Wrapper for dataclassess.field that adds support for kw_only fields. - - Args: - metadata: A mapping or None, containing metadata for the field. - kw_only: If true, the field will be moved to the end of `__init__`'s - parameter list. - **kwargs: Keyword arguments forwarded to `dataclasses.field` - - Returns: - A `dataclasses.Field` object. - """ - if kw_only is not dataclasses.MISSING and kw_only: - if ( - kwargs.get('default', dataclasses.MISSING) is dataclasses.MISSING - and kwargs.get('default_factory', dataclasses.MISSING) - is dataclasses.MISSING - ): - raise ValueError('Keyword-only fields with no default are not supported.') - if metadata is None: - metadata = {} - metadata[KW_ONLY] = True - return dataclasses.field(metadata=metadata, **kwargs) - - -@tpe.dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] -def dataclass(cls=None, extra_fields=None, **kwargs): - """Wrapper for dataclasses.dataclass that adds support for kw_only fields. - - Args: - cls: The class to transform (or none to return a decorator). - extra_fields: A list of `(name, type, Field)` tuples describing extra fields - that should be added to the dataclass. This is necessary for linen's - use-case of this module, since the base class (linen.Module) is *not* a - dataclass. In particular, linen.Module class is used as the base for both - frozen and non-frozen dataclass subclasses; but the frozen status of a - dataclass must match the frozen status of any base dataclasses. - **kwargs: Additional arguments for `dataclasses.dataclass`. - - Returns: - `cls`. - """ - - def wrap(cls): - return _process_class(cls, extra_fields=extra_fields, **kwargs) - - return wrap if cls is None else wrap(cls) - - -def _process_class(cls: type[M], extra_fields=None, **kwargs): - """Transforms `cls` into a dataclass that supports kw_only fields.""" - if '__annotations__' not in cls.__dict__: - cls.__annotations__ = {} - - # The original __dataclass_fields__ dicts for all base classes. We will - # modify these in-place before turning `cls` into a dataclass, and then - # restore them to their original values. - base_dataclass_fields = {} # dict[cls, cls.__dataclass_fields__.copy()] - - # The keyword only fields from `cls` or any of its base classes. - kw_only_fields: dict[FieldName, tuple[Annotation, Default]] = {} - - # Scan for KW_ONLY marker. - kw_only_name = None - for name, annotation in cls.__annotations__.items(): - if annotation is KW_ONLY: - if kw_only_name is not None: - raise TypeError('Multiple KW_ONLY markers') - kw_only_name = name - elif kw_only_name is not None: - if not hasattr(cls, name): - raise ValueError( - 'Keyword-only fields with no default are not supported.' - ) - default = getattr(cls, name) - if isinstance(default, dataclasses.Field): - default.metadata = MappingProxyType({**default.metadata, KW_ONLY: True}) - else: - default = field(default=default, kw_only=True) - setattr(cls, name, default) - if kw_only_name: - del cls.__annotations__[kw_only_name] - - # Inject extra fields. - if extra_fields: - for name, annotation, default in extra_fields: - if not (isinstance(name, str) and isinstance(default, dataclasses.Field)): - raise ValueError( - 'Expected extra_fields to a be a list of ' - '(name, type, Field) tuples.' - ) - setattr(cls, name, default) - cls.__annotations__[name] = annotation - - # Extract kw_only fields from base classes' __dataclass_fields__. - for base in reversed(cls.__mro__[1:]): - if not dataclasses.is_dataclass(base): - continue - base_annotations = base.__dict__.get('__annotations__', {}) - base_dataclass_fields[base] = dict( - getattr(base, '__dataclass_fields__', {}) - ) - for base_field in list(dataclasses.fields(base)): - field_name = base_field.name - if base_field.metadata.get(KW_ONLY) or field_name in kw_only_fields: - kw_only_fields[field_name] = ( - base_annotations.get(field_name), - base_field, - ) - del base.__dataclass_fields__[field_name] - - # Remove any keyword-only fields from this class. - cls_annotations = cls.__dict__['__annotations__'] - for name, annotation in list(cls_annotations.items()): - value = getattr(cls, name, None) - if ( - isinstance(value, dataclasses.Field) and value.metadata.get(KW_ONLY) - ) or name in kw_only_fields: - del cls_annotations[name] - kw_only_fields[name] = (annotation, value) - - # Add keyword-only fields at the end of __annotations__, in the order they - # were found in the base classes and in this class. - for name, (annotation, default) in kw_only_fields.items(): - setattr(cls, name, default) - cls_annotations.pop(name, None) - cls_annotations[name] = annotation - - create_init = '__init__' not in vars(cls) and kwargs.get('init', True) - - # Apply the dataclass transform. - transformed_cls: type[M] = dataclasses.dataclass(cls, **kwargs) - - # Restore the base classes' __dataclass_fields__. - for _cls, fields in base_dataclass_fields.items(): - _cls.__dataclass_fields__ = fields - - if create_init: - dataclass_init = transformed_cls.__init__ - # use sum to count the number of init fields that are not keyword-only - expected_num_args = sum( - f.init and not f.metadata.get(KW_ONLY, False) - for f in dataclasses.fields(transformed_cls) - ) - - @functools.wraps(dataclass_init) - def init_wrapper(self, *args, **kwargs): - num_args = len(args) - if num_args > expected_num_args: - # we add + 1 to each to account for `self`, matching python's - # default error message - raise TypeError( - f'__init__() takes {expected_num_args + 1} positional ' - f'arguments but {num_args + 1} were given' - ) - - dataclass_init(self, *args, **kwargs) - - init_wrapper.__signature__ = inspect.signature(dataclass_init) # type: ignore - transformed_cls.__init__ = init_wrapper # type: ignore[method-assign] - - # Return the transformed dataclass - return transformed_cls +KW_ONLY = dataclasses.KW_ONLY +field = dataclasses.field +dataclass = dataclasses.dataclass diff --git a/flax/linen/module.py b/flax/linen/module.py index f4dc1bd14..d2cbbfa36 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -19,7 +19,6 @@ import enum import functools import inspect -import sys import threading import typing import weakref @@ -57,7 +56,6 @@ union_filters, ) from flax.ids import FlaxId, uuid -from flax.linen import kw_only_dataclasses from flax.typing import ( RNGSequences, PRNGKey, @@ -1061,7 +1059,7 @@ def _customized_dataclass_transform(cls, kw_only: bool): 3. Generate a hash function (if not provided by cls). """ # Check reserved attributes have expected type annotations. - annotations = dict(cls.__dict__.get('__annotations__', {})) + annotations = inspect.get_annotations(cls) if annotations.get('parent', _ParentType) != _ParentType: raise errors.ReservedModuleAttributeError(annotations) if annotations.get('name', str) not in ('str', str, Optional[str]): @@ -1081,42 +1079,29 @@ def _customized_dataclass_transform(cls, kw_only: bool): ( 'parent', _ParentType, - kw_only_dataclasses.field( + dataclasses.field( repr=False, default=_unspecified_parent, kw_only=True ), ), ( 'name', Optional[str], - kw_only_dataclasses.field(default=None, kw_only=True), + dataclasses.field(default=None, kw_only=True), ), ] - if kw_only: - if tuple(sys.version_info)[:3] >= (3, 10, 0): - for ( - name, - annotation, # pytype: disable=invalid-annotation - default, - ) in extra_fields: - setattr(cls, name, default) - cls.__annotations__[name] = annotation - dataclasses.dataclass( # type: ignore[call-overload] - unsafe_hash='__hash__' not in cls.__dict__, - repr=False, - kw_only=True, - )(cls) - else: - raise TypeError('`kw_only` is not available before Py 3.10.') - else: - # Now apply dataclass transform (which operates in-place). - # Do generate a hash function only if not provided by the class. - kw_only_dataclasses.dataclass( - cls, - unsafe_hash='__hash__' not in cls.__dict__, - repr=False, - extra_fields=extra_fields, - ) # pytype: disable=wrong-keyword-args + for ( + name, + annotation, # pytype: disable=invalid-annotation + default, + ) in extra_fields: + setattr(cls, name, default) + cls.__annotations__[name] = annotation + dataclasses.dataclass( # type: ignore[call-overload] + unsafe_hash='__hash__' not in cls.__dict__, + repr=False, + kw_only=kw_only, + )(cls) cls.__hash__ = _wrap_hash(cls.__hash__) # type: ignore[method-assign] diff --git a/pyproject.toml b/pyproject.toml index ebf5dcd90..72c617c39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ testing = [ "clu", "clu<=0.0.9; python_version<'3.10'", "einops", - "gymnasium[atari]", + "gymnasium[atari]; python_version<'3.14'", "jaxlib", "jaxtyping", "jraph>=0.0.6dev0", @@ -61,11 +61,11 @@ testing = [ "tensorflow_text>=2.11.0; platform_system!='Darwin' and python_version < '3.13'", "tensorflow_datasets", "tensorflow>=2.12.0; python_version<'3.13'", # to fix Numpy np.bool8 deprecation error - "tensorflow>=2.20.0; python_version>='3.13'", + "tensorflow>=2.20.0; python_version>='3.13' and python_version<'3.14'", "torch", "treescope>=0.1.1; python_version>='3.10'", "cloudpickle>=3.0.0", - "ale-py>=0.10.2", + "ale-py>=0.10.2; python_version<'3.14'", ] docs = [ "sphinx==6.2.1", @@ -237,3 +237,11 @@ quote-style = "single" [tool.uv] # Ignore uv.lock and always upgrade the package to the latest upgrade-package = ["jax", "jaxlib", "orbax-checkpoint"] + +[tool.uv.sources] +torch = { index = "pytorch" } + +[[tool.uv.index]] +name = "pytorch" +url = "https://download.pytorch.org/whl/cpu" +explicit = true diff --git a/tests/linen/kw_only_dataclasses_test.py b/tests/linen/kw_only_dataclasses_test.py index 835312cbd..54b8a0b35 100644 --- a/tests/linen/kw_only_dataclasses_test.py +++ b/tests/linen/kw_only_dataclasses_test.py @@ -12,22 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for kw_only_dataclasses.""" +"""Tests for dataclasses.""" import dataclasses import inspect from absl.testing import absltest -from flax.linen import kw_only_dataclasses - class KwOnlyDataclassesTest(absltest.TestCase): def test_kwonly_args_moved_to_end(self): - @kw_only_dataclasses.dataclass + @dataclasses.dataclass class TestClass: a: int = 1 - b: int = kw_only_dataclasses.field(default=2, kw_only=True) + b: int = dataclasses.field(default=2, kw_only=True) c: int = 3 params = inspect.signature(TestClass.__init__).parameters @@ -46,11 +44,11 @@ class TestClass: self.assertDictEqual(dataclasses.asdict(v3), dict(a=1, b=2, c=30)) def test_base_optional_subclass_required(self): - @kw_only_dataclasses.dataclass + @dataclasses.dataclass class Parent: - a: int = kw_only_dataclasses.field(default=2, kw_only=True) + a: int = dataclasses.field(default=2, kw_only=True) - @kw_only_dataclasses.dataclass + @dataclasses.dataclass class Child(Parent): b: int @@ -65,21 +63,22 @@ class Child(Parent): v2 = Child(4, a=5) # pylint: disable=too-many-function-args self.assertDictEqual(dataclasses.asdict(v2), dict(a=5, b=4)) + @absltest.expectedFailureIf(True, "non-default argument 'size' follows default argument") def test_subclass_overrides_base(self): # Note: if a base class declares a field as keyword-only, then # subclasses don't need to also declare it as keyword-only. - @kw_only_dataclasses.dataclass + @dataclasses.dataclass class A: - x: int = kw_only_dataclasses.field(default=1, kw_only=True) + x: int = dataclasses.field(default=1, kw_only=True) - @kw_only_dataclasses.dataclass + @dataclasses.dataclass class B(A): size: float - y: int = kw_only_dataclasses.field(default=3, kw_only=True) + y: int = dataclasses.field(default=3, kw_only=True) x: int = 2 - @kw_only_dataclasses.dataclass + @dataclasses.dataclass class C(B): name: str @@ -106,15 +105,15 @@ class C(B): ) def test_kwonly_marker(self): - @kw_only_dataclasses.dataclass + @dataclasses.dataclass class A: x: float - _: kw_only_dataclasses.KW_ONLY + _: dataclasses.KW_ONLY a: int = 5 - b: int = kw_only_dataclasses.field(default=2) - c: int = kw_only_dataclasses.field(default=2, kw_only=True) + b: int = dataclasses.field(default=2) + c: int = dataclasses.field(default=2, kw_only=True) - @kw_only_dataclasses.dataclass + @dataclasses.dataclass class B(A): z: str diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index c9d4ee646..b8b57e504 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -727,16 +727,16 @@ def __call__(self, x): list(Test.__dataclass_fields__.keys()), ['bar', 'parent', 'name'] ) self.assertEqual( - list(Test2.__dataclass_fields__.keys()), - ['bar', 'baz', 'parent', 'name'], + set(Test2.__dataclass_fields__.keys()), + {'bar', 'baz', 'parent', 'name'}, ) self.assertEqual( - list(Test3.__dataclass_fields__.keys()), - ['bar', 'baz', 'parent', 'name'], + set(Test3.__dataclass_fields__.keys()), + {'bar', 'baz', 'parent', 'name'}, ) self.assertEqual( - list(Test4.__dataclass_fields__.keys()), - ['bar', 'baz', 'parent', 'name'], + set(Test4.__dataclass_fields__.keys()), + {'bar', 'baz', 'parent', 'name'}, ) def test_get_suffix_value_pairs(self): From b43ade803a8083f0b881bc545ac0ba1ad09e03f7 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 17 Nov 2025 16:26:34 +0000 Subject: [PATCH 2/2] Added workaround by P.Hawkins --- .github/workflows/flax_test.yml | 2 ++ flax/linen/module.py | 7 +++++++ tests/linen/linen_module_test.py | 21 +++++++++++++++++++++ 3 files changed, 30 insertions(+) diff --git a/.github/workflows/flax_test.yml b/.github/workflows/flax_test.yml index c2eadb38e..fbad93ba0 100644 --- a/.github/workflows/flax_test.yml +++ b/.github/workflows/flax_test.yml @@ -191,6 +191,8 @@ jobs: run: | rm -fr .venv uv sync --extra testing --extra docs + # temporary: install jax nightly + uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ - name: Test with pytest run: | export XLA_FLAGS='--xla_force_host_platform_device_count=4' diff --git a/flax/linen/module.py b/flax/linen/module.py index d2cbbfa36..b5c26a463 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -19,6 +19,7 @@ import enum import functools import inspect +import sys import threading import typing import weakref @@ -1097,6 +1098,12 @@ def _customized_dataclass_transform(cls, kw_only: bool): ) in extra_fields: setattr(cls, name, default) cls.__annotations__[name] = annotation + + # TODO: a workaround for the issue: + # https://github.com/google/flax/pull/5087#issuecomment-3536610568 + if (sys.version_info.major, sys.version_info.minor) in [(3, 12), (3, 13)]: + setattr(cls, '__annotations__', cls.__annotations__) + dataclasses.dataclass( # type: ignore[call-overload] unsafe_hash='__hash__' not in cls.__dict__, repr=False, diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index b8b57e504..199d38592 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -2315,6 +2315,27 @@ class Foo(nn.Module): Foo(1, None) Foo(a=1, parent=None) # type: ignore[call-arg] + def test_failure_with_sequencelayer(self): + # This is a minimal reproducer of the failure seen with + # SequenceLayer project and Flax Linen when enabled support for 3.14 + # See PR: https://github.com/google/flax/pull/5087 + # Code below is based on + # https://github.com/google/flax/pull/5087#issuecomment-3535067361 + import abc + from collections.abc import Iterator + from typing import Protocol + + class CheckpointableIterator(Iterator, Protocol): + pass + + class Steppable(metaclass=abc.ABCMeta): + pass + + isinstance(Steppable, Iterator) + + class SequenceLayer(nn.Module, Steppable): + pass + def test_module_path_empty(self): rngkey = jax.random.key(0) scope = Scope({}, {'params': rngkey}, mutable=['params'])