From 1a2cbdd09ee2002c7184f92fa009a65992474b7f Mon Sep 17 00:00:00 2001 From: Scott Breyfogle Date: Thu, 1 Aug 2024 14:28:36 -0700 Subject: [PATCH 1/3] Add and test relative config loading --- argbind/argbind.py | 9 +++++++-- tests/test_argbind.py | 8 ++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/argbind/argbind.py b/argbind/argbind.py index 54e2fb2..40f168a 100644 --- a/argbind/argbind.py +++ b/argbind/argbind.py @@ -229,13 +229,18 @@ def dump_args(args, output_path): output.append(line) f.write('\n'.join(output)) -def load_args(input_path_or_stream): +def load_args(input_path_or_stream, config_directory = "."): """ Loads arguments from a given input path or file stream, if the file is already open. + + If config_directory is specified, the input path and all transitive + imports will be loaded relative to that directory. Streams and + absolute filepaths will not be affected. """ if isinstance(input_path_or_stream, (str, Path)): - with open(input_path_or_stream, 'r') as f: + input_path = Path(config_directory) / input_path_or_stream + with open(input_path, 'r') as f: data = yaml.load(f, Loader=yaml.Loader) else: data = yaml.load(input_path_or_stream, Loader=yaml.Loader) diff --git a/tests/test_argbind.py b/tests/test_argbind.py index 3605840..63243cd 100644 --- a/tests/test_argbind.py +++ b/tests/test_argbind.py @@ -1,7 +1,15 @@ import argbind + def test_load_args(): arg1 = argbind.load_args("examples/yaml/conf/base.yml") with open("examples/yaml/conf/base.yml") as f: arg2 = argbind.load_args(f) assert arg1 == arg2 + + +def test_config_directory(): + arg1 = argbind.load_args("conf/exp2.yml", config_directory="examples/yaml") + with open("examples/yaml/conf/exp2.yml") as f: + arg2 = argbind.load_args(f) + assert arg1 == arg2 From 079f632d07ea55fc4b89ae77c09b794400783af7 Mon Sep 17 00:00:00 2001 From: Scott Breyfogle Date: Thu, 1 Aug 2024 14:44:18 -0700 Subject: [PATCH 2/3] Make directory transitive --- argbind/argbind.py | 257 +++++++++++++++++++++++++----------------- tests/test_argbind.py | 17 ++- 2 files changed, 167 insertions(+), 107 deletions(-) diff --git a/argbind/argbind.py b/argbind/argbind.py index 40f168a..963049b 100644 --- a/argbind/argbind.py +++ b/argbind/argbind.py @@ -19,10 +19,11 @@ DEBUG = False HELP_WIDTH = 60 + @contextmanager -def scope(parsed_args, pattern=''): +def scope(parsed_args, pattern=""): """ - Context manager to put parsed arguments into + Context manager to put parsed arguments into a state. """ parsed_args = parsed_args.copy() @@ -36,11 +37,11 @@ def scope(parsed_args, pattern=''): old_pattern = PATTERN for key in parsed_args: - if '/' in key: - if key.split('/')[0] == pattern: - matched[key.split('/')[-1]] = parsed_args[key] + if "/" in key: + if key.split("/")[0] == pattern: + matched[key.split("/")[-1]] = parsed_args[key] remove_keys.append(key) - + parsed_args.update(matched) for key in remove_keys: parsed_args.pop(key) @@ -51,6 +52,7 @@ def scope(parsed_args, pattern=''): ARGS = old_args PATTERN = old_pattern + def _format_func_debug(func_name, func_kwargs, scope=None): formatted = [f"{func_name}("] if scope is not None: @@ -58,9 +60,12 @@ def _format_func_debug(func_name, func_kwargs, scope=None): for key, val in func_kwargs.items(): formatted.append(f" {key} : {type(val).__name__} = {val}") formatted.append(")") - return '\n'.join(formatted) + return "\n".join(formatted) -def bind(*args, without_prefix=False, positional=False, group: Union[list, str] = "default"): + +def bind( + *args, without_prefix=False, positional=False, group: Union[list, str] = "default" +): """Binds a functions arguments so that it looks up argument values in a dictionary scoped by ArgBind. @@ -72,7 +77,7 @@ def bind(*args, without_prefix=False, positional=False, group: Union[list, str] here (e.g. decorate is called on the first argument). Otherwise, it is treated is a decorator. without_prefix : bool, optional - Whether or not to bind without the function name as the prefix. + Whether or not to bind without the function name as the prefix. If True, the functions arguments will be available at "arg_name" rather than "func_name.arg_name", by default False positional : bool, optional @@ -91,7 +96,8 @@ def bind(*args, without_prefix=False, positional=False, group: Union[list, str] if positional and patterns: warnings.warn( f"Combining positional arguments with scoping patterns is not allowed. Removing scoping patterns {patterns}. \n" - "See https://github.com/pseeth/argbind/tree/main/examples/hello_world#argbind-with-positional-arguments") + "See https://github.com/pseeth/argbind/tree/main/examples/hello_world#argbind-with-positional-arguments" + ) patterns = [] if isinstance(group, str): @@ -101,30 +107,30 @@ def decorator(object_or_func): func = object_or_func is_class = inspect.isclass(func) if is_class: - func = getattr(func, "__init__") + func = getattr(func, "__init__") prefix = func.__qualname__ if "__init__" in prefix: prefix = prefix.split(".")[0] - + # Check if function is bound already. If it is, just re-wrap it, # instead of wrapping the function twice. if prefix in PARSE_FUNCS: func = PARSE_FUNCS[prefix][0] else: PARSE_FUNCS[prefix] = (func, patterns, without_prefix, positional, group) - + @wraps(func) def cmd_func(*args, **kwargs): parameters = list(inspect.signature(func).parameters.items()) - + cmd_kwargs = {} pos_kwargs = {parameters[i][0]: arg for i, arg in enumerate(args)} for key, val in parameters: arg_val = val.default if arg_val is not inspect.Parameter.empty or positional: - arg_name = f'{prefix}.{key}' if not without_prefix else f'{key}' + arg_name = f"{prefix}.{key}" if not without_prefix else f"{key}" if arg_name in ARGS and key not in kwargs: val = ARGS[arg_name] if key in pos_kwargs: @@ -132,7 +138,7 @@ def cmd_func(*args, **kwargs): cmd_kwargs[key] = val use_key = arg_name if PATTERN: - use_key = f'{PATTERN}/{use_key}' + use_key = f"{PATTERN}/{use_key}" USED_ARGS[use_key] = val kwargs.update(cmd_kwargs) @@ -149,15 +155,16 @@ def cmd_func(*args, **kwargs): ordered_kwargs[k] = kwargs[k] kwargs = ordered_kwargs - if 'args.debug' not in ARGS: ARGS['args.debug'] = False - if ARGS['args.debug'] or DEBUG: - if PATTERN: + if "args.debug" not in ARGS: + ARGS["args.debug"] = False + if ARGS["args.debug"] or DEBUG: + if PATTERN: scope = PATTERN else: scope = None print(_format_func_debug(prefix, kwargs, scope)) return func(*cmd_args, **kwargs) - + if is_class: setattr(object_or_func, "__init__", cmd_func) cmd_func = object_or_func @@ -169,14 +176,16 @@ def cmd_func(*args, **kwargs): else: return decorator(bound_fn_or_cls) + # Backwards compat. # For scripts written with argbind<=0.1.3. bind_to_parser = bind + class bind_module: def __init__(self, module, *scopes, filter_fn=lambda fn: True, **kwargs): """Binds every function/class in a specified module. The output - class is a bound version of the original module, with the + class is a bound version of the original module, with the attributes in the same place. Parameters @@ -186,7 +195,7 @@ class is a bound version of the original module, with the scopes : List[str] or [fn or Object] + List[str], optional List of patterns to bind the function under. filter_fn : Callable, optional - A function that takes in the function that is to be bound, and + A function that takes in the function that is to be bound, and returns a boolean as to whether or not it should be bound. Defaults to always True, no matter what the function is. kwargs : keyword arguments, optional @@ -200,6 +209,7 @@ class is a bound version of the original module, with the bound_fn = bind(fn, *scopes, **kwargs) setattr(self, fn_name, bound_fn) + def get_used_args(): """ Gets the args that have been used so far @@ -208,6 +218,7 @@ def get_used_args(): """ return USED_ARGS + def dump_args(args, output_path): """ Dumps the provided arguments to a @@ -215,21 +226,22 @@ def dump_args(args, output_path): """ path = Path(output_path) os.makedirs(path.parent, exist_ok=True) - with open(path, 'w') as f: - yaml.Dumper.ignore_aliases = lambda *args : True + with open(path, "w") as f: + yaml.Dumper.ignore_aliases = lambda *args: True x = yaml.dump(args, Dumper=yaml.Dumper) prev_line = None output = [] - for line in x.split('\n'): - cur_line = line.split('.')[0].strip() - if not cur_line.startswith('-'): + for line in x.split("\n"): + cur_line = line.split(".")[0].strip() + if not cur_line.startswith("-"): if cur_line != prev_line and prev_line: - line = f'\n{line}' - prev_line = line.split('.')[0].strip() + line = f"\n{line}" + prev_line = line.split(".")[0].strip() output.append(line) - f.write('\n'.join(output)) + f.write("\n".join(output)) -def load_args(input_path_or_stream, config_directory = "."): + +def load_args(input_path_or_stream, config_directory="."): """ Loads arguments from a given input path or file stream, if the file is already open. @@ -240,35 +252,35 @@ def load_args(input_path_or_stream, config_directory = "."): """ if isinstance(input_path_or_stream, (str, Path)): input_path = Path(config_directory) / input_path_or_stream - with open(input_path, 'r') as f: + with open(input_path, "r") as f: data = yaml.load(f, Loader=yaml.Loader) else: data = yaml.load(input_path_or_stream, Loader=yaml.Loader) - - if '$include' in data: - include_files = data.pop('$include') + + if "$include" in data: + include_files = data.pop("$include") include_args = {} for include_file in include_files: - include_args.update(load_args(include_file)) + include_args.update(load_args(include_file, config_directory)) include_args.update(data) data = include_args _vars = os.environ.copy() - if '$vars' in data: - _vars.update(data.pop('$vars')) - + if "$vars" in data: + _vars.update(data.pop("$vars")) + for key, val in data.items(): # Check if string starts with $. - if isinstance(val, str): - if val.startswith('$'): + if isinstance(val, str): + if val.startswith("$"): lookup = val[1:] if lookup in _vars: data[key] = _vars[lookup] - + elif isinstance(val, list): new_list = [] for subval in val: - if isinstance(subval, str) and subval.startswith('$'): + if isinstance(subval, str) and subval.startswith("$"): lookup = subval[1:] if lookup in _vars: new_list.append(_vars[lookup]) @@ -278,27 +290,32 @@ def load_args(input_path_or_stream, config_directory = "."): new_list.append(subval) data[key] = new_list - if 'args.debug' not in data: - data['args.debug'] = DEBUG + if "args.debug" not in data: + data["args.debug"] = DEBUG return data -class str_to_list(): + +class str_to_list: def __init__(self, _type): self._type = _type + def __call__(self, values): - _values = values.split(' ') + _values = values.split(" ") _values = [self._type(v) for v in _values] return _values -class str_to_tuple(): + +class str_to_tuple: def __init__(self, _type_list): self._type_list = _type_list + def __call__(self, values): - _values = values.split(' ') + _values = values.split(" ") _values = [self._type_list[i](v) for i, v in enumerate(_values)] return tuple(_values) -class str_to_dict(): + +class str_to_dict: def __init__(self): pass @@ -311,17 +328,18 @@ def _guess_type(self, s): return value def __call__(self, values): - values = values.split(' ') + values = values.split(" ") _values = {} for elem in values: - key, val = elem.split('=', 1) + key, val = elem.split("=", 1) key = self._guess_type(key) val = self._guess_type(val) _values[key] = val return _values + def build_parser(group: Union[list, str] = "default"): """Builds the argument parser from all of the bound functions. @@ -330,16 +348,27 @@ def build_parser(group: Union[list, str] = "default"): ArgumentParser Argument parser built by ArgBind. """ - p = argparse.ArgumentParser( - formatter_class=argparse.RawTextHelpFormatter - ) + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) - p.add_argument('--args.save', type=str, required=False, - help="Path to save all arguments used to run script to.") - p.add_argument('--args.load', type=str, required=False, - help="Path to load arguments from, stored as a .yml file.") - p.add_argument('--args.debug', type=int, required=False, default=0, - help="Print arguments as they are passed to each function.") + p.add_argument( + "--args.save", + type=str, + required=False, + help="Path to save all arguments used to run script to.", + ) + p.add_argument( + "--args.load", + type=str, + required=False, + help="Path to load arguments from, stored as a .yml file.", + ) + p.add_argument( + "--args.debug", + type=int, + required=False, + default=0, + help="Print arguments as they are passed to each function.", + ) if isinstance(group, str): group = [group] @@ -355,9 +384,7 @@ def build_parser(group: Union[list, str] = "default"): docstring = docstring_parser.parse(func.__doc__) parameter_help = docstring.params - parameter_help = { - x.arg_name: x.description for x in parameter_help - } + parameter_help = {x.arg_name: x.description for x in parameter_help} f = p.add_argument_group( title=f"Generated arguments for function {prefix}", @@ -367,22 +394,19 @@ def _get_arg_names(key, is_kwarg): arg_names = [] arg_name = key - prepend = '--' if is_kwarg else '' + prepend = "--" if is_kwarg else "" if without_prefix: - arg_name = prepend + f'PATTERN/{key}' + arg_name = prepend + f"PATTERN/{key}" else: - arg_name = prepend + f'PATTERN/{prefix}.{key}' + arg_name = prepend + f"PATTERN/{prefix}.{key}" + + arg_names.append(arg_name.replace("PATTERN/", "")) - arg_names.append(arg_name.replace('PATTERN/', '')) - if patterns is not None: for p in patterns: - arg_names.append( - arg_name.replace('PATTERN', p) - ) + arg_names.append(arg_name.replace("PATTERN", p)) return arg_names - for key, val in sig.parameters.items(): arg_val = val.default arg_type = val.annotation @@ -394,7 +418,7 @@ def _get_arg_names(key, is_kwarg): if is_kwarg or positional: arg_names = _get_arg_names(key, is_kwarg) arg_help = {} - help_text = '' + help_text = "" if key in parameter_help: help_text = textwrap.fill(parameter_help[key], width=HELP_WIDTH) arg_help[arg_names[0]] = help_text @@ -407,33 +431,51 @@ def _get_arg_names(key, is_kwarg): list_types = [List[x] for x in inner_types] if arg_type is bool: - f.add_argument(arg_name, action='store_true', - help=arg_help[arg_name]) + f.add_argument( + arg_name, action="store_true", help=arg_help[arg_name] + ) elif arg_type in list_types: _type = inner_types[list_types.index(arg_type)] - f.add_argument(arg_name, type=str_to_list(_type), - default=arg_val, help=arg_help[arg_name]) + f.add_argument( + arg_name, + type=str_to_list(_type), + default=arg_val, + help=arg_help[arg_name], + ) elif arg_type is Dict: - f.add_argument(arg_name, type=str_to_dict(), - default=arg_val, help=arg_help[arg_name]) - elif hasattr(arg_type, '__origin__'): + f.add_argument( + arg_name, + type=str_to_dict(), + default=arg_val, + help=arg_help[arg_name], + ) + elif hasattr(arg_type, "__origin__"): if arg_type.__origin__ is tuple: _type_list = arg_type.__args__ - f.add_argument(arg_name, type=str_to_tuple(_type_list), - default=arg_val, help=arg_help[arg_name]) + f.add_argument( + arg_name, + type=str_to_tuple(_type_list), + default=arg_val, + help=arg_help[arg_name], + ) else: - f.add_argument(arg_name, type=arg_type, - default=arg_val, help=arg_help[arg_name]) - + f.add_argument( + arg_name, + type=arg_type, + default=arg_val, + help=arg_help[arg_name], + ) + desc = docstring.short_description - if desc is None: desc = '' + if desc is None: + desc = "" if patterns: if not without_prefix: scope_pattern = f"--{patterns[0]}/{prefix}.{key}" else: scope_pattern = f"--{patterns[0]}/{key}" - + desc += ( f" Additional scope patterns: {', '.join(list(patterns))}. " "Use these by prefacing any of the args below with one " @@ -443,36 +485,39 @@ def _get_arg_names(key, is_kwarg): desc = textwrap.fill(desc, width=HELP_WIDTH) f.description = desc - + return p + def parse_args(p=None, group: Union[list, str] = "default"): """ Parses the command line and returns a dictionary. Builds the argument parser if p is None. """ p = build_parser(group=group) if p is None else p - used_args = [x.replace('--', '').split('=')[0] for x in sys.argv if x.startswith('--')] - used_args.extend(['args.save', 'args.load']) + used_args = [ + x.replace("--", "").split("=")[0] for x in sys.argv if x.startswith("--") + ] + used_args.extend(["args.save", "args.load"]) known, unknown = p.parse_known_args() args = vars(known) args["args.unknown"] = unknown - load_args_path = args.pop('args.load') - save_args_path = args.pop('args.save') - debug_args = args.pop('args.debug') - - pattern_keys = [key for key in args if '/' in key] - top_level_args = [key for key in args if '/' not in key] + load_args_path = args.pop("args.load") + save_args_path = args.pop("args.save") + debug_args = args.pop("args.debug") + + pattern_keys = [key for key in args if "/" in key] + top_level_args = [key for key in args if "/" not in key] for key in pattern_keys: # If the top-level arguments were altered but the ones # in patterns were not, change the scoped ones to # match the top-level (inherit arguments from top-level). - pattern, arg_name = key.split('/') + pattern, arg_name = key.split("/") if key not in used_args: args[key] = args[arg_name] - + if load_args_path: loaded_args = load_args(load_args_path) # Overwrite defaults with things in loaded arguments. @@ -481,15 +526,15 @@ def parse_args(p=None, group: Union[list, str] = "default"): if key not in used_args: args[key] = loaded_args[key] for key in pattern_keys: - pattern, arg_name = key.split('/') + pattern, arg_name = key.split("/") if key not in loaded_args and key not in used_args: if arg_name in loaded_args: args[key] = args[arg_name] - + for key in top_level_args: if key in used_args: for pattern_key in pattern_keys: - pattern, arg_name = pattern_key.split('/') + pattern, arg_name = pattern_key.split("/") if key == arg_name and pattern_key not in used_args: args[pattern_key] = args[key] @@ -497,8 +542,8 @@ def parse_args(p=None, group: Union[list, str] = "default"): dump_args(args, save_args_path) # Put them back in case the script wants to use them - args['args.load'] = load_args_path - args['args.save'] = save_args_path - args['args.debug'] = debug_args - + args["args.load"] = load_args_path + args["args.save"] = save_args_path + args["args.debug"] = debug_args + return args diff --git a/tests/test_argbind.py b/tests/test_argbind.py index 63243cd..a52b0a4 100644 --- a/tests/test_argbind.py +++ b/tests/test_argbind.py @@ -1,4 +1,17 @@ import argbind +import contextlib +import os + + +@contextlib.contextmanager +def temporary_working_directory(path): + original_working_directory = os.getcwd() + os.chdir(path) + + try: + yield + finally: + os.chdir(original_working_directory) def test_load_args(): @@ -9,7 +22,9 @@ def test_load_args(): def test_config_directory(): - arg1 = argbind.load_args("conf/exp2.yml", config_directory="examples/yaml") + with temporary_working_directory("tests"): + arg1 = argbind.load_args("examples/yaml/conf/exp2.yml", config_directory="..") with open("examples/yaml/conf/exp2.yml") as f: arg2 = argbind.load_args(f) + assert arg1 == arg2 From adf5888eca0901a37bd093003b5617d9a258044c Mon Sep 17 00:00:00 2001 From: Scott Breyfogle Date: Thu, 1 Aug 2024 14:45:10 -0700 Subject: [PATCH 3/3] Revert autoformat --- argbind/argbind.py | 255 +++++++++++++++++++-------------------------- 1 file changed, 105 insertions(+), 150 deletions(-) diff --git a/argbind/argbind.py b/argbind/argbind.py index 963049b..b589e45 100644 --- a/argbind/argbind.py +++ b/argbind/argbind.py @@ -19,11 +19,10 @@ DEBUG = False HELP_WIDTH = 60 - @contextmanager -def scope(parsed_args, pattern=""): +def scope(parsed_args, pattern=''): """ - Context manager to put parsed arguments into + Context manager to put parsed arguments into a state. """ parsed_args = parsed_args.copy() @@ -37,11 +36,11 @@ def scope(parsed_args, pattern=""): old_pattern = PATTERN for key in parsed_args: - if "/" in key: - if key.split("/")[0] == pattern: - matched[key.split("/")[-1]] = parsed_args[key] + if '/' in key: + if key.split('/')[0] == pattern: + matched[key.split('/')[-1]] = parsed_args[key] remove_keys.append(key) - + parsed_args.update(matched) for key in remove_keys: parsed_args.pop(key) @@ -52,7 +51,6 @@ def scope(parsed_args, pattern=""): ARGS = old_args PATTERN = old_pattern - def _format_func_debug(func_name, func_kwargs, scope=None): formatted = [f"{func_name}("] if scope is not None: @@ -60,12 +58,9 @@ def _format_func_debug(func_name, func_kwargs, scope=None): for key, val in func_kwargs.items(): formatted.append(f" {key} : {type(val).__name__} = {val}") formatted.append(")") - return "\n".join(formatted) + return '\n'.join(formatted) - -def bind( - *args, without_prefix=False, positional=False, group: Union[list, str] = "default" -): +def bind(*args, without_prefix=False, positional=False, group: Union[list, str] = "default"): """Binds a functions arguments so that it looks up argument values in a dictionary scoped by ArgBind. @@ -77,7 +72,7 @@ def bind( here (e.g. decorate is called on the first argument). Otherwise, it is treated is a decorator. without_prefix : bool, optional - Whether or not to bind without the function name as the prefix. + Whether or not to bind without the function name as the prefix. If True, the functions arguments will be available at "arg_name" rather than "func_name.arg_name", by default False positional : bool, optional @@ -96,8 +91,7 @@ def bind( if positional and patterns: warnings.warn( f"Combining positional arguments with scoping patterns is not allowed. Removing scoping patterns {patterns}. \n" - "See https://github.com/pseeth/argbind/tree/main/examples/hello_world#argbind-with-positional-arguments" - ) + "See https://github.com/pseeth/argbind/tree/main/examples/hello_world#argbind-with-positional-arguments") patterns = [] if isinstance(group, str): @@ -107,30 +101,30 @@ def decorator(object_or_func): func = object_or_func is_class = inspect.isclass(func) if is_class: - func = getattr(func, "__init__") + func = getattr(func, "__init__") prefix = func.__qualname__ if "__init__" in prefix: prefix = prefix.split(".")[0] - + # Check if function is bound already. If it is, just re-wrap it, # instead of wrapping the function twice. if prefix in PARSE_FUNCS: func = PARSE_FUNCS[prefix][0] else: PARSE_FUNCS[prefix] = (func, patterns, without_prefix, positional, group) - + @wraps(func) def cmd_func(*args, **kwargs): parameters = list(inspect.signature(func).parameters.items()) - + cmd_kwargs = {} pos_kwargs = {parameters[i][0]: arg for i, arg in enumerate(args)} for key, val in parameters: arg_val = val.default if arg_val is not inspect.Parameter.empty or positional: - arg_name = f"{prefix}.{key}" if not without_prefix else f"{key}" + arg_name = f'{prefix}.{key}' if not without_prefix else f'{key}' if arg_name in ARGS and key not in kwargs: val = ARGS[arg_name] if key in pos_kwargs: @@ -138,7 +132,7 @@ def cmd_func(*args, **kwargs): cmd_kwargs[key] = val use_key = arg_name if PATTERN: - use_key = f"{PATTERN}/{use_key}" + use_key = f'{PATTERN}/{use_key}' USED_ARGS[use_key] = val kwargs.update(cmd_kwargs) @@ -155,16 +149,15 @@ def cmd_func(*args, **kwargs): ordered_kwargs[k] = kwargs[k] kwargs = ordered_kwargs - if "args.debug" not in ARGS: - ARGS["args.debug"] = False - if ARGS["args.debug"] or DEBUG: - if PATTERN: + if 'args.debug' not in ARGS: ARGS['args.debug'] = False + if ARGS['args.debug'] or DEBUG: + if PATTERN: scope = PATTERN else: scope = None print(_format_func_debug(prefix, kwargs, scope)) return func(*cmd_args, **kwargs) - + if is_class: setattr(object_or_func, "__init__", cmd_func) cmd_func = object_or_func @@ -176,16 +169,14 @@ def cmd_func(*args, **kwargs): else: return decorator(bound_fn_or_cls) - # Backwards compat. # For scripts written with argbind<=0.1.3. bind_to_parser = bind - class bind_module: def __init__(self, module, *scopes, filter_fn=lambda fn: True, **kwargs): """Binds every function/class in a specified module. The output - class is a bound version of the original module, with the + class is a bound version of the original module, with the attributes in the same place. Parameters @@ -195,7 +186,7 @@ class is a bound version of the original module, with the scopes : List[str] or [fn or Object] + List[str], optional List of patterns to bind the function under. filter_fn : Callable, optional - A function that takes in the function that is to be bound, and + A function that takes in the function that is to be bound, and returns a boolean as to whether or not it should be bound. Defaults to always True, no matter what the function is. kwargs : keyword arguments, optional @@ -209,7 +200,6 @@ class is a bound version of the original module, with the bound_fn = bind(fn, *scopes, **kwargs) setattr(self, fn_name, bound_fn) - def get_used_args(): """ Gets the args that have been used so far @@ -218,7 +208,6 @@ def get_used_args(): """ return USED_ARGS - def dump_args(args, output_path): """ Dumps the provided arguments to a @@ -226,22 +215,21 @@ def dump_args(args, output_path): """ path = Path(output_path) os.makedirs(path.parent, exist_ok=True) - with open(path, "w") as f: - yaml.Dumper.ignore_aliases = lambda *args: True + with open(path, 'w') as f: + yaml.Dumper.ignore_aliases = lambda *args : True x = yaml.dump(args, Dumper=yaml.Dumper) prev_line = None output = [] - for line in x.split("\n"): - cur_line = line.split(".")[0].strip() - if not cur_line.startswith("-"): + for line in x.split('\n'): + cur_line = line.split('.')[0].strip() + if not cur_line.startswith('-'): if cur_line != prev_line and prev_line: - line = f"\n{line}" - prev_line = line.split(".")[0].strip() + line = f'\n{line}' + prev_line = line.split('.')[0].strip() output.append(line) - f.write("\n".join(output)) + f.write('\n'.join(output)) - -def load_args(input_path_or_stream, config_directory="."): +def load_args(input_path_or_stream, config_directory = "."): """ Loads arguments from a given input path or file stream, if the file is already open. @@ -252,13 +240,13 @@ def load_args(input_path_or_stream, config_directory="."): """ if isinstance(input_path_or_stream, (str, Path)): input_path = Path(config_directory) / input_path_or_stream - with open(input_path, "r") as f: + with open(input_path, 'r') as f: data = yaml.load(f, Loader=yaml.Loader) else: data = yaml.load(input_path_or_stream, Loader=yaml.Loader) - - if "$include" in data: - include_files = data.pop("$include") + + if '$include' in data: + include_files = data.pop('$include') include_args = {} for include_file in include_files: include_args.update(load_args(include_file, config_directory)) @@ -266,21 +254,21 @@ def load_args(input_path_or_stream, config_directory="."): data = include_args _vars = os.environ.copy() - if "$vars" in data: - _vars.update(data.pop("$vars")) - + if '$vars' in data: + _vars.update(data.pop('$vars')) + for key, val in data.items(): # Check if string starts with $. - if isinstance(val, str): - if val.startswith("$"): + if isinstance(val, str): + if val.startswith('$'): lookup = val[1:] if lookup in _vars: data[key] = _vars[lookup] - + elif isinstance(val, list): new_list = [] for subval in val: - if isinstance(subval, str) and subval.startswith("$"): + if isinstance(subval, str) and subval.startswith('$'): lookup = subval[1:] if lookup in _vars: new_list.append(_vars[lookup]) @@ -290,32 +278,27 @@ def load_args(input_path_or_stream, config_directory="."): new_list.append(subval) data[key] = new_list - if "args.debug" not in data: - data["args.debug"] = DEBUG + if 'args.debug' not in data: + data['args.debug'] = DEBUG return data - -class str_to_list: +class str_to_list(): def __init__(self, _type): self._type = _type - def __call__(self, values): - _values = values.split(" ") + _values = values.split(' ') _values = [self._type(v) for v in _values] return _values - -class str_to_tuple: +class str_to_tuple(): def __init__(self, _type_list): self._type_list = _type_list - def __call__(self, values): - _values = values.split(" ") + _values = values.split(' ') _values = [self._type_list[i](v) for i, v in enumerate(_values)] return tuple(_values) - -class str_to_dict: +class str_to_dict(): def __init__(self): pass @@ -328,18 +311,17 @@ def _guess_type(self, s): return value def __call__(self, values): - values = values.split(" ") + values = values.split(' ') _values = {} for elem in values: - key, val = elem.split("=", 1) + key, val = elem.split('=', 1) key = self._guess_type(key) val = self._guess_type(val) _values[key] = val return _values - def build_parser(group: Union[list, str] = "default"): """Builds the argument parser from all of the bound functions. @@ -348,28 +330,17 @@ def build_parser(group: Union[list, str] = "default"): ArgumentParser Argument parser built by ArgBind. """ - p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument( - "--args.save", - type=str, - required=False, - help="Path to save all arguments used to run script to.", - ) - p.add_argument( - "--args.load", - type=str, - required=False, - help="Path to load arguments from, stored as a .yml file.", - ) - p.add_argument( - "--args.debug", - type=int, - required=False, - default=0, - help="Print arguments as they are passed to each function.", + p = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter ) + p.add_argument('--args.save', type=str, required=False, + help="Path to save all arguments used to run script to.") + p.add_argument('--args.load', type=str, required=False, + help="Path to load arguments from, stored as a .yml file.") + p.add_argument('--args.debug', type=int, required=False, default=0, + help="Print arguments as they are passed to each function.") + if isinstance(group, str): group = [group] if "default" not in group: @@ -384,7 +355,9 @@ def build_parser(group: Union[list, str] = "default"): docstring = docstring_parser.parse(func.__doc__) parameter_help = docstring.params - parameter_help = {x.arg_name: x.description for x in parameter_help} + parameter_help = { + x.arg_name: x.description for x in parameter_help + } f = p.add_argument_group( title=f"Generated arguments for function {prefix}", @@ -394,19 +367,22 @@ def _get_arg_names(key, is_kwarg): arg_names = [] arg_name = key - prepend = "--" if is_kwarg else "" + prepend = '--' if is_kwarg else '' if without_prefix: - arg_name = prepend + f"PATTERN/{key}" + arg_name = prepend + f'PATTERN/{key}' else: - arg_name = prepend + f"PATTERN/{prefix}.{key}" - - arg_names.append(arg_name.replace("PATTERN/", "")) + arg_name = prepend + f'PATTERN/{prefix}.{key}' + arg_names.append(arg_name.replace('PATTERN/', '')) + if patterns is not None: for p in patterns: - arg_names.append(arg_name.replace("PATTERN", p)) + arg_names.append( + arg_name.replace('PATTERN', p) + ) return arg_names + for key, val in sig.parameters.items(): arg_val = val.default arg_type = val.annotation @@ -418,7 +394,7 @@ def _get_arg_names(key, is_kwarg): if is_kwarg or positional: arg_names = _get_arg_names(key, is_kwarg) arg_help = {} - help_text = "" + help_text = '' if key in parameter_help: help_text = textwrap.fill(parameter_help[key], width=HELP_WIDTH) arg_help[arg_names[0]] = help_text @@ -431,51 +407,33 @@ def _get_arg_names(key, is_kwarg): list_types = [List[x] for x in inner_types] if arg_type is bool: - f.add_argument( - arg_name, action="store_true", help=arg_help[arg_name] - ) + f.add_argument(arg_name, action='store_true', + help=arg_help[arg_name]) elif arg_type in list_types: _type = inner_types[list_types.index(arg_type)] - f.add_argument( - arg_name, - type=str_to_list(_type), - default=arg_val, - help=arg_help[arg_name], - ) + f.add_argument(arg_name, type=str_to_list(_type), + default=arg_val, help=arg_help[arg_name]) elif arg_type is Dict: - f.add_argument( - arg_name, - type=str_to_dict(), - default=arg_val, - help=arg_help[arg_name], - ) - elif hasattr(arg_type, "__origin__"): + f.add_argument(arg_name, type=str_to_dict(), + default=arg_val, help=arg_help[arg_name]) + elif hasattr(arg_type, '__origin__'): if arg_type.__origin__ is tuple: _type_list = arg_type.__args__ - f.add_argument( - arg_name, - type=str_to_tuple(_type_list), - default=arg_val, - help=arg_help[arg_name], - ) + f.add_argument(arg_name, type=str_to_tuple(_type_list), + default=arg_val, help=arg_help[arg_name]) else: - f.add_argument( - arg_name, - type=arg_type, - default=arg_val, - help=arg_help[arg_name], - ) - + f.add_argument(arg_name, type=arg_type, + default=arg_val, help=arg_help[arg_name]) + desc = docstring.short_description - if desc is None: - desc = "" + if desc is None: desc = '' if patterns: if not without_prefix: scope_pattern = f"--{patterns[0]}/{prefix}.{key}" else: scope_pattern = f"--{patterns[0]}/{key}" - + desc += ( f" Additional scope patterns: {', '.join(list(patterns))}. " "Use these by prefacing any of the args below with one " @@ -485,39 +443,36 @@ def _get_arg_names(key, is_kwarg): desc = textwrap.fill(desc, width=HELP_WIDTH) f.description = desc - + return p - def parse_args(p=None, group: Union[list, str] = "default"): """ Parses the command line and returns a dictionary. Builds the argument parser if p is None. """ p = build_parser(group=group) if p is None else p - used_args = [ - x.replace("--", "").split("=")[0] for x in sys.argv if x.startswith("--") - ] - used_args.extend(["args.save", "args.load"]) + used_args = [x.replace('--', '').split('=')[0] for x in sys.argv if x.startswith('--')] + used_args.extend(['args.save', 'args.load']) known, unknown = p.parse_known_args() args = vars(known) args["args.unknown"] = unknown - load_args_path = args.pop("args.load") - save_args_path = args.pop("args.save") - debug_args = args.pop("args.debug") - - pattern_keys = [key for key in args if "/" in key] - top_level_args = [key for key in args if "/" not in key] + load_args_path = args.pop('args.load') + save_args_path = args.pop('args.save') + debug_args = args.pop('args.debug') + + pattern_keys = [key for key in args if '/' in key] + top_level_args = [key for key in args if '/' not in key] for key in pattern_keys: # If the top-level arguments were altered but the ones # in patterns were not, change the scoped ones to # match the top-level (inherit arguments from top-level). - pattern, arg_name = key.split("/") + pattern, arg_name = key.split('/') if key not in used_args: args[key] = args[arg_name] - + if load_args_path: loaded_args = load_args(load_args_path) # Overwrite defaults with things in loaded arguments. @@ -526,15 +481,15 @@ def parse_args(p=None, group: Union[list, str] = "default"): if key not in used_args: args[key] = loaded_args[key] for key in pattern_keys: - pattern, arg_name = key.split("/") + pattern, arg_name = key.split('/') if key not in loaded_args and key not in used_args: if arg_name in loaded_args: args[key] = args[arg_name] - + for key in top_level_args: if key in used_args: for pattern_key in pattern_keys: - pattern, arg_name = pattern_key.split("/") + pattern, arg_name = pattern_key.split('/') if key == arg_name and pattern_key not in used_args: args[pattern_key] = args[key] @@ -542,8 +497,8 @@ def parse_args(p=None, group: Union[list, str] = "default"): dump_args(args, save_args_path) # Put them back in case the script wants to use them - args["args.load"] = load_args_path - args["args.save"] = save_args_path - args["args.debug"] = debug_args - + args['args.load'] = load_args_path + args['args.save'] = save_args_path + args['args.debug'] = debug_args + return args