diff --git a/fiddle/_src/absl_flags/utils.py b/fiddle/_src/absl_flags/utils.py index 1a41782b..84a547d2 100644 --- a/fiddle/_src/absl_flags/utils.py +++ b/fiddle/_src/absl_flags/utils.py @@ -108,7 +108,17 @@ def _import_dotted_name( f'attribute {failing_name!r}; available names: {available_names}' ) from None return value - except ModuleNotFoundError: + except ModuleNotFoundError as e: + # If the missing module isn't along the path we tried to import, + # it's an internal import error — surface the real error. + # e.name is None when the error has no associated module name + if e.name is None: + raise + # If e.name is a prefix of name_pieces, the path simply doesn't + # exist at this split point. Otherwise, the module + # exists but broke importing an unrelated dependency. + if (missing := e.name.split('.')) != name_pieces[: len(missing)]: + raise if i == 1: # Final iteration through the loop. raise diff --git a/fiddle/_src/absl_flags/utils_test.py b/fiddle/_src/absl_flags/utils_test.py index eaf6c519..2919dcd2 100644 --- a/fiddle/_src/absl_flags/utils_test.py +++ b/fiddle/_src/absl_flags/utils_test.py @@ -14,7 +14,9 @@ # limitations under the License. import dataclasses +import os import sys +import tempfile from typing import Any from absl.testing import absltest @@ -159,5 +161,36 @@ def test_from_fully_qualified_name(self): self.assertEqual(cfg, base_experiment()) +class ImportDottedNameTest(absltest.TestCase): + """Tests that _import_dotted_name surfaces real import errors.""" + + def test_internal_import_error_is_not_swallowed(self): + """A module that exists but has a broken import should raise.""" + tmpdir = self.enterContext(tempfile.TemporaryDirectory()) + module_path = os.path.join(tmpdir, '_broken_module.py') + with open(module_path, 'w') as f: + f.write('import _nonexistent_dependency\n') + sys.path.insert(0, tmpdir) + self.addCleanup(lambda: sys.path.remove(tmpdir)) + self.addCleanup(lambda: sys.modules.pop('_broken_module', None)) + + with self.assertRaises(ModuleNotFoundError) as ctx: + utils._import_dotted_name( + '_broken_module.some_symbol', + mode=_IRRELEVANT_MODE, + module=None, + ) + self.assertIn('_nonexistent_dependency', str(ctx.exception)) + + def test_nonexistent_module_raises_module_not_found(self): + """A module that doesn't exist should raise ModuleNotFoundError.""" + with self.assertRaises(ModuleNotFoundError): + utils._import_dotted_name( + 'completely_nonexistent_module.some_symbol', + mode=_IRRELEVANT_MODE, + module=None, + ) + + if __name__ == '__main__': absltest.main()