diff --git a/argbind/argbind.py b/argbind/argbind.py index 54e2fb2..b589e45 100644 --- a/argbind/argbind.py +++ b/argbind/argbind.py @@ -229,13 +229,18 @@ def dump_args(args, output_path): output.append(line) f.write('\n'.join(output)) -def load_args(input_path_or_stream): +def load_args(input_path_or_stream, config_directory = "."): """ Loads arguments from a given input path or file stream, if the file is already open. + + If config_directory is specified, the input path and all transitive + imports will be loaded relative to that directory. Streams and + absolute filepaths will not be affected. """ if isinstance(input_path_or_stream, (str, Path)): - with open(input_path_or_stream, 'r') as f: + input_path = Path(config_directory) / input_path_or_stream + with open(input_path, 'r') as f: data = yaml.load(f, Loader=yaml.Loader) else: data = yaml.load(input_path_or_stream, Loader=yaml.Loader) @@ -244,7 +249,7 @@ def load_args(input_path_or_stream): include_files = data.pop('$include') include_args = {} for include_file in include_files: - include_args.update(load_args(include_file)) + include_args.update(load_args(include_file, config_directory)) include_args.update(data) data = include_args diff --git a/tests/test_argbind.py b/tests/test_argbind.py index 3605840..a52b0a4 100644 --- a/tests/test_argbind.py +++ b/tests/test_argbind.py @@ -1,7 +1,30 @@ import argbind +import contextlib +import os + + +@contextlib.contextmanager +def temporary_working_directory(path): + original_working_directory = os.getcwd() + os.chdir(path) + + try: + yield + finally: + os.chdir(original_working_directory) + def test_load_args(): arg1 = argbind.load_args("examples/yaml/conf/base.yml") with open("examples/yaml/conf/base.yml") as f: arg2 = argbind.load_args(f) assert arg1 == arg2 + + +def test_config_directory(): + with temporary_working_directory("tests"): + arg1 = argbind.load_args("examples/yaml/conf/exp2.yml", config_directory="..") + with open("examples/yaml/conf/exp2.yml") as f: + arg2 = argbind.load_args(f) + + assert arg1 == arg2