diff --git a/flax/nnx/filterlib.py b/flax/nnx/filterlib.py index f293b6cc0..824eb4315 100644 --- a/flax/nnx/filterlib.py +++ b/flax/nnx/filterlib.py @@ -89,13 +89,16 @@ def __repr__(self): @dataclasses.dataclass(frozen=True) class PathContains: - key: Key + key: Key | str + exact: bool = True def __call__(self, path: PathParts, x: tp.Any): - return self.key in path + if self.exact: + return self.key in path + return any(str(self.key) in str(part) for part in path) def __repr__(self): - return f'PathContains({self.key!r})' + return f'PathContains({self.key!r}, exact={self.exact})' class PathIn: diff --git a/tests/nnx/filters_test.py b/tests/nnx/filters_test.py index 94dcfda1e..f330c920b 100644 --- a/tests/nnx/filters_test.py +++ b/tests/nnx/filters_test.py @@ -21,15 +21,20 @@ class TestFilters(absltest.TestCase): def test_path_contains(self): class Model(nnx.Module): def __init__(self, rngs): - self.backbone = nnx.Linear(2, 3, rngs=rngs) + self.backbone1 = nnx.Linear(2, 3, rngs=rngs) + self.backbone2 = nnx.Linear(3, 3, rngs=rngs) self.head = nnx.Linear(3, 10, rngs=rngs) model = Model(nnx.Rngs(0)) head_state = nnx.state(model, nnx.PathContains('head')) + backbones_state = nnx.state(model, nnx.PathContains('backbone', exact=False)) self.assertIn('head', head_state) self.assertNotIn('backbone', head_state) + self.assertIn('backbone1', backbones_state) + self.assertIn('backbone2', backbones_state) + self.assertNotIn('head', backbones_state) if __name__ == '__main__': absltest.main()