Patch summary (free-form preamble; text before the first `diff --git` header
is skipped by `git apply` and `patch`):
  * setup.py: add a `gcs_prefer_fsspec` extra listing `fsspec` and `gcsfs`.
  * tensorflow_datasets/__init__.py: when GCS_PREFER_FSSPEC is exactly 'true',
    set EPATH_PREFER_FSSPEC to 'true' before the absl/etils imports run.
  * tensorflow_datasets/import_test.py: add two tests that import the package
    in a subprocess (with and without GCS_PREFER_FSSPEC) and check the
    resulting EPATH_PREFER_FSSPEC value and the epath backend class name.
NOTE(review): line breaks were restored from a whitespace-collapsed copy using
the hunk line counts. Leading indentation (2-space blocks, 4-space
continuation/dict entries) and blank context lines are assumed -- verify
against the target tree, or apply with --ignore-whitespace.

diff --git a/setup.py b/setup.py
index 6f39d837428..610e2ca5ba9 100644
--- a/setup.py
+++ b/setup.py
@@ -225,6 +225,7 @@
 ] + ['datasets']
 
 EXTRAS = {
+    'gcs_prefer_fsspec': ['fsspec', 'gcsfs'],
     'matplotlib': ['matplotlib'],
     'tensorflow': ['tensorflow>=2.1'],
     'tf-nightly': ['tf-nightly'],
diff --git a/tensorflow_datasets/__init__.py b/tensorflow_datasets/__init__.py
index 9251bc58e93..293c642056e 100644
--- a/tensorflow_datasets/__init__.py
+++ b/tensorflow_datasets/__init__.py
@@ -39,6 +39,11 @@
 
 from __future__ import annotations
 
+import os
+
+if os.environ.get('GCS_PREFER_FSSPEC') == 'true':
+  os.environ['EPATH_PREFER_FSSPEC'] = 'true'
+
 from absl import logging
 from etils import epy as _epy
 
diff --git a/tensorflow_datasets/import_test.py b/tensorflow_datasets/import_test.py
index 4620a8fec0c..f72a95484b7 100644
--- a/tensorflow_datasets/import_test.py
+++ b/tensorflow_datasets/import_test.py
@@ -15,6 +15,9 @@
 
 """Test import."""
 
+import os
+import subprocess
+import sys
 import tensorflow_datasets as tfds
 
 
@@ -23,6 +26,52 @@ class ImportTest(tfds.testing.TestCase):
   def test_import(self):
     pass
 
+  def test_gcs_prefer_fsspec_true(self):
+    env = os.environ.copy()
+    env['GCS_PREFER_FSSPEC'] = 'true'
+    env.pop('EPATH_PREFER_FSSPEC', None)
+
+    code = """
+import os
+import tensorflow_datasets as tfds
+from etils import epath
+print("EPATH_PREFER_FSSPEC:", os.environ.get('EPATH_PREFER_FSSPEC'))
+p = epath.Path('gs://dummy-bucket/file.txt')
+print("BACKEND:", type(p._backend).__name__)
+"""
+    result = subprocess.run(
+        [sys.executable, '-c', code],
+        env=env,
+        capture_output=True,
+        text=True,
+        check=True,
+    )
+    self.assertIn("EPATH_PREFER_FSSPEC: true", result.stdout)
+    self.assertIn("BACKEND: _FileSystemSpecBackend", result.stdout)
+
+  def test_gcs_prefer_fsspec_false(self):
+    env = os.environ.copy()
+    env.pop('GCS_PREFER_FSSPEC', None)
+    env.pop('EPATH_PREFER_FSSPEC', None)
+
+    code = """
+import os
+import tensorflow_datasets as tfds
+from etils import epath
+print("EPATH_PREFER_FSSPEC:", os.environ.get('EPATH_PREFER_FSSPEC'))
+p = epath.Path('gs://dummy-bucket/file.txt')
+print("BACKEND:", type(p._backend).__name__)
+"""
+    result = subprocess.run(
+        [sys.executable, '-c', code],
+        env=env,
+        capture_output=True,
+        text=True,
+        check=True,
+    )
+    self.assertIn("EPATH_PREFER_FSSPEC: None", result.stdout)
+    self.assertIn("BACKEND: _TfBackend", result.stdout)
+
 
 if __name__ == '__main__':
   tfds.testing.test_main()