Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
name: CI

on:
pull_request:
push:
workflow_dispatch:

permissions:
contents: read

jobs:
test:
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Checkout repository
uses: actions/checkout@v5

- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.12"

- name: Set up uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true

- name: Sync dev dependencies
run: uv sync --group dev

- name: Run tests with coverage
run: uv run --group dev pytest --cov=msup --cov-report=term-missing

build:
runs-on: ubuntu-latest
timeout-minutes: 10
needs: test
steps:
- name: Checkout repository
uses: actions/checkout@v5

- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.12"

- name: Set up uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true

- name: Build distributions
run: uv build --out-dir dist
4 changes: 2 additions & 2 deletions configs/example.json
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
{
"model_config": {
"n_layers": 10,
"n_layers": 7,
"checkpoint_path": null
},
"lr": 0.01,
"name": "example",
"lr_step_fn": "cosine_warmup_lr_step",
"lr_step_fn": "__main__.cosine_warmup_lr_step",
"num_workers": -1,
"cont": false,
"config_root_dir": "./configs"
Expand Down
6 changes: 3 additions & 3 deletions configs/identity.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
"n_layers": 10,
"checkpoint_path": null
},
"lr": 0.1,
"lr": 0.2,
"name": "identity",
"lr_step_fn": "identity_step_fn",
"num_workers": -1,
"lr_step_fn": "__main__.identity_step_fn",
"num_workers": 42,
"cont": false,
"config_root_dir": "./configs"
}
28 changes: 28 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
set shell := ["bash", "-euo", "pipefail", "-c"]

default:
@just --list

setup-dev:
uv sync --group dev

test:
uv run --group dev pytest

coverage:
uv run --group dev pytest --cov=msup --cov-report=term-missing

tag-release version:
test -n "{{version}}"
test -z "$(git status --porcelain)"
uv version "{{version}}" --frozen
git add pyproject.toml
git commit -m "Release {{version}}"
git tag -a "v{{version}}" -m "Release {{version}}"
git push origin HEAD
git push origin "v{{version}}"

publish-release:
rm -rf dist
uv build --out-dir dist
uv run --group dev python -m twine upload dist/*
4 changes: 2 additions & 2 deletions msup/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def fields_or_init_kwargs(clazz: type):

def load_callable(name: str):
idx = name.rfind('.')
assert idx != -1, "expected <module_name>.<name>"
assert idx != -1, f"expected <module_name>.<name>, got {name}"
module_name = name[0:idx]
fn_name = name[idx+1:]
mod = importlib.import_module(module_name)
Expand All @@ -89,7 +89,7 @@ def _to_dict_value(x: T, field_type: type):
return to_dict(x)
elif get_origin(field_type) is Callable2:
if callable(x):
return x.__name__
return f"{x.__module__}.{x.__name__}"
else:
assert isinstance(x, str), f"{x.__class__=}"
return x
Expand Down
144 changes: 89 additions & 55 deletions msup/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from dataclasses import dataclass, field, is_dataclass, fields, MISSING
from collections.abc import Callable as Callable2

from msup.base import has_default_value, is_optional, _from_value, to_json
from msup.base import has_default_value, is_optional, _from_value, to_json, to_kwargs
from typing import Optional, List, Dict, Union, TypeVar, get_origin, get_args, Callable, get_type_hints, Any

T = TypeVar('T')
_UNSET = object()

def cli(cmd_or_cmds: Callable[[T], Any] | dict[Callable[[T], Any], str], **argsparse_kwargs): ...
def cliarg(help: str = "", short: str | list[str] | None = None, env: str | None = None, pos: bool = False, opt: bool = True, **kwargs): ...
Expand Down Expand Up @@ -38,77 +39,107 @@ def _get_first_arg(func):
raise TypeError(f"First argument for {getattr(fn, '__name__', fn)} is not a dataclass: {dtype}")
return result

def _from_cli_args(clazz: type, args, prefix: str = ""):
assert is_dataclass(clazz), f"{cmd_type} is not a dataclass"

construct_args = {}
def _get_env_default_value(f):
env_name = f.metadata.get("env")
env_value = os.getenv(env_name) if env_name else None
if env_value:
return _from_value(env_value, f.type, str, f.name)
return _UNSET

def _get_cli_value(args, arg_name: str):
value = getattr(args, arg_name, _UNSET)
if value is _UNSET and hasattr(args, arg_name + "_pos"):
pos_value = getattr(args, arg_name + "_pos")
if pos_value is not _UNSET:
return pos_value
return value

def _from_cli_args(clazz: type, args, prefix: str = "", args_value: Any = _UNSET):
assert is_dataclass(clazz), f"{clazz} is not a dataclass"

if args_value is _UNSET:
root_args_value = getattr(args, "args", _UNSET)
args_value = _from_value(root_args_value, clazz, str, "") if prefix == "" and root_args_value is not _UNSET else None

construct_args = to_kwargs(clazz, args_value) if args_value is not None else {}
for f in fields(clazz):
arg_name = prefix + "." + f.name if prefix else f.name
value = getattr(args, arg_name, None)
if value is None and hasattr(args, arg_name + "_pos"):
value = getattr(args, arg_name + "_pos")
value = _get_cli_value(args, arg_name)
env_default = _get_env_default_value(f)
if is_dataclass(f.type):
if value is not None:
if value is not _UNSET:
if not isinstance(value, str):
error_exit(f"expected string for --{arg_name}, got {type(value)} ({value=})", 2)

sub = _from_value(
sub_args_value = _from_value(
value,
f.type,
str,
f.name,
)
# NOTE: merge additional values
for subf in fields(f.type):
subv = getattr(args, arg_name + "." + subf.name)
if subv:
v = _from_value(
subv,
subf.type,
type(subv),
field_name=f.name,
)
setattr(sub, subf.name, subv)
else:
sub = _from_cli_args(f.type, args, prefix=f.name)
sub_args_value = construct_args.get(f.name, env_default)

sub = _from_cli_args(f.type, args, prefix=arg_name, args_value=sub_args_value)
construct_args[f.name] = sub
elif get_origin(f.type) is dict or f.type is dict:
if value is None:
if has_default_value(f):
continue
if value is not _UNSET:
if not isinstance(value, str):
error_exit(f"expected string for --{arg_name}, got {type(value)} ({value=})", 2)
sub = _from_value(
value,
f.type,
str,
f.name,
)
construct_args[f.name] = sub
elif f.name in construct_args:
continue
elif env_default is not _UNSET:
construct_args[f.name] = env_default
elif has_default_value(f):
continue
else:
error_exit(f"--{arg_name} not provided (default value DNE)", 3)
if not isinstance(value, str):
error_exit(f"expected string for --{arg_name}, got {type(value)} ({value=})", 2)
sub = _from_value(
value,
f.type,
str,
f.name,
)
construct_args[f.name] = sub
elif f.type is bool:
if isinstance(value, bool):
construct_args[f.name] = value
if value is not _UNSET:
if isinstance(value, bool):
construct_args[f.name] = value
else:
if not isinstance(value, str):
error_exit(f"expected string for --{arg_name}, got {type(value)} ({value=})", 2)

if value.lower() not in ("0", "false", "1", "true"):
error_exit(f"expected one of: {0, False, 1, True} as a bool value for --{arg_name}, got: {value}")

construct_args[f.name] = value.lower() in ("1", "true")
elif f.name in construct_args:
continue
elif env_default is not _UNSET:
construct_args[f.name] = env_default
elif has_default_value(f):
continue
elif is_optional(f.type):
construct_args[f.name] = None
else:
if not isinstance(value, str):
error_exit(f"expected string for --{arg_name}, got {type(value)} ({value=})", 2)

if value.lower() not in ("0", "false", "1", "true"):
error_exit(f"expected one of: {0, False, 1, True} as a bool value for --{arg_name}, got: {value}")

construct_args[f.name] = value.lower() in ("1", "true")
error_exit(f"--{arg_name} not provided (default value DNE)", 3)
else:
if value is not None:
if value is not _UNSET:
construct_args[f.name] = _from_value(
value,
f.type,
type(value),
field_name=f.name,
)
elif f.name in construct_args:
continue
elif env_default is not _UNSET:
construct_args[f.name] = env_default
elif has_default_value(f):
continue
elif is_optional(f.type):
construct_args[f.name] = None
elif not has_default_value(f):
else:
error_exit(f"--{arg_name} not provided (default value DNE)", 3)

return clazz(**construct_args)
Expand Down Expand Up @@ -136,6 +167,7 @@ def _add_args(parser, cmd_type: type, prefix: str = "", short_prefix: str | None
nargs="?",
type=_get_cli_arg_type(cmd_type),
help=f"configuration for {cmd_type.__name__}",
default=_UNSET,
)
parser.add_argument(
"--Args",
Expand All @@ -144,19 +176,20 @@ def _add_args(parser, cmd_type: type, prefix: str = "", short_prefix: str | None
type=_get_cli_arg_type(cmd_type),
help=f"configuration for {cmd_type.__name__}",
required=False,
default=_UNSET,
)

for f in fields(cmd_type):
field_name = f.name
name = prefix + "." + field_name if prefix else field_name
req = prefix == "" and not has_default_value(f)
req = False
o_or_field_type = get_origin(f.type) or f.type
default_value = f.default if f.default is not MISSING and not force_no_default else None
default_help = f"Default: {default_value}" if default_value else ""
env_name = f.metadata.get("env")
env_value = os.getenv(env_name) if env_name else None
if env_value:
default_value = _from_value(env_value, f.type, str, field_name)
env_default_value = _get_env_default_value(f)
if env_default_value is not _UNSET:
default_value = env_default_value
env_name = f.metadata.get("env")
default_help = f"Default (using env: ${{{env_name}}}): {default_value}"

help = f.metadata.get("help") + ". " + default_help if f.metadata.get("help") else default_help
Expand Down Expand Up @@ -186,6 +219,7 @@ def _add_args(parser, cmd_type: type, prefix: str = "", short_prefix: str | None
**kwargs,
type=_get_cli_arg_type(f.type),
help=help,
default=_UNSET,
)
_add_args(
parser,
Expand All @@ -201,15 +235,15 @@ def _add_args(parser, cmd_type: type, prefix: str = "", short_prefix: str | None
**kwargs,
type=_get_cli_arg_type(f.type),
help=help,
default=default_value,
default=_UNSET,
)
elif o_or_field_type in (dict,):
parser.add_argument(
*args,
**kwargs,
type=str,
help=help,
default=default_value,
default=_UNSET,
)
elif f.type in (bool,):
if "nargs" not in kwargs:
Expand All @@ -220,23 +254,23 @@ def _add_args(parser, cmd_type: type, prefix: str = "", short_prefix: str | None
const=not default_value,
type=to_bool,
metavar="{0|1,true|false,yes|no}",
default=default_value,
default=_UNSET,
)
elif get_origin(f.type) is Callable2:
parser.add_argument(
*args,
**kwargs,
type=str,
help=help,
default=default_value,
default=_UNSET,
)
else:
parser.add_argument(
*args,
**kwargs,
type=_get_cli_arg_type(f.type),
help=help,
default=default_value,
default=_UNSET,
)

def cliarg(help: str = "", short: str | list[str] | None = None, env: str | None = None, pos: bool = False, opt: bool = True, **kwargs):
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ readme = "README.md"
license-files = ["LICENSE"]
keywords = ["serialization", "encode", "dataclass", "cli", "argparse"]

[dependency-groups]
dev = [
"pytest>=8.4.1",
"pytest-cov>=6.2.1",
"twine>=6.1.0",
]

[tool.ruff]
line-length = 120
target-version = "py310"
Expand Down
Loading