Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions conda.recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ package:
requirements:
build:
- python
- "marshmallow>=3.0.0rc4"
- "marshmallow>=4.0.0"
- "numpy>=1.13"
- "python-dateutil>=2.8.0"

run:
- python
- "marshmallow>=3.0.0rc4"
- "marshmallow>=4.0.0"
- "numpy>=1.13"
- "python-dateutil>=2.8.0"

Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: paramtools-dev
channels:
- conda-forge
dependencies:
- "marshmallow>=3.22.0"
- "marshmallow>=4.0.0"
- "numpy>=2.1.0"
- "python-dateutil>=2.8.0"
- "pytest>=6.0.0"
Expand Down
2 changes: 1 addition & 1 deletion paramtools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@


name = "paramtools"
__version__ = "0.19.0"
__version__ = "0.20.0"

__all__ = [
"SchemaFactory",
Expand Down
8 changes: 4 additions & 4 deletions paramtools/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
else:
self._stateless_label_grid[name] = []
self.label_grid = copy.deepcopy(self._stateless_label_grid)
self._validator_schema.context["spec"] = self
self._validator_schema.pt_context["spec"] = self
self._warnings = {}
self._errors = {}
self._defer_validation = False
Expand Down Expand Up @@ -364,7 +364,7 @@ def _adjust(
for param, value in parsed_params.items():
self._update_param(param, value)

self._validator_schema.context["spec"] = self
self._validator_schema.pt_context["spec"] = self

has_errors = bool(self._errors.get("messages"))
has_warnings = bool(self._warnings.get("messages"))
Expand Down Expand Up @@ -525,7 +525,7 @@ def _delete(
if self.label_to_extend is not None and extend_adj:
self.extend()

self._validator_schema.context["spec"] = self
self._validator_schema.pt_context["spec"] = self

has_errors = bool(self._errors.get("messages"))
has_warnings = bool(self._warnings.get("messages"))
Expand Down Expand Up @@ -1414,4 +1414,4 @@ def get_defaults(self):
- `params`: String if URL or file path. Dict if this is the loaded params
dict.
"""
return utils.read_json(self.defaults)
return utils.read_json(self.defaults)
68 changes: 30 additions & 38 deletions paramtools/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
validates_schema,
ValidationError as MarshmallowValidationError,
decorators,
RAISE as RAISEUNKNOWNOPTION,
)
from marshmallow.error_store import ErrorStore

Expand All @@ -28,14 +29,14 @@ class RangeSchema(Schema):
}
"""

_min = fields.Field(attribute="min", data_key="min")
_max = fields.Field(attribute="max", data_key="max")
step = fields.Field()
_min = fields.Raw(attribute="min", data_key="min")
_max = fields.Raw(attribute="max", data_key="max")
step = fields.Raw()
level = fields.String(validate=[validate.OneOf(["warn", "error"])])


class ChoiceSchema(Schema):
choices = fields.List(fields.Field)
choices = fields.List(fields.Raw)
level = fields.String(validate=[validate.OneOf(["warn", "error"])])


Expand All @@ -53,9 +54,9 @@ class ValueValidatorSchema(Schema):


class IsSchema(Schema):
equal_to = fields.Field(required=False)
greater_than = fields.Field(required=False)
less_than = fields.Field(required=False)
equal_to = fields.Raw(required=False)
greater_than = fields.Raw(required=False)
less_than = fields.Raw(required=False)

@validates_schema
def just_one(self, data, **kwargs):
Expand Down Expand Up @@ -107,15 +108,12 @@ class BaseParamSchema(Schema):
data_key="type",
)
number_dims = fields.Integer(required=False, load_default=0)
value = fields.Field(required=True) # will be specified later
value = fields.Raw(required=True) # will be specified later
validators = fields.Nested(
ValueValidatorSchema(), required=False, load_default={}
)
indexed = fields.Boolean(required=False)

class Meta:
ordered = True


class EmptySchema(Schema):
"""
Expand All @@ -126,15 +124,6 @@ class EmptySchema(Schema):
pass


class OrderedSchema(Schema):
"""
Same as `EmptySchema`, but preserves the order of its fields.
"""

class Meta:
ordered = True


class ValueObject(fields.Nested):
"""
Schema for value objects
Expand Down Expand Up @@ -182,16 +171,17 @@ class BaseValidatorSchema(Schema):
class.
"""

class Meta:
ordered = True

WRAPPER_MAP = {
"range": "_get_range_validator",
"date_range": "_get_range_validator",
"choice": "_get_choice_validator",
"when": "_get_when_validator",
}

def __init__(self, *args, **kwargs):
self.pt_context = {}
super().__init__(*args, **kwargs)

def validate_only(self, data):
"""
Bypass deserialization and just run field validators. This is taken
Expand All @@ -208,21 +198,23 @@ def validate_only(self, data):
field_errors = bool(error_store.errors)
self._invoke_schema_validators(
error_store=error_store,
pass_many=True,
pass_collection=True,
data=data,
original_data=data,
many=None,
partial=None,
field_errors=field_errors,
unknown=RAISEUNKNOWNOPTION,
)
self._invoke_schema_validators(
error_store=error_store,
pass_many=False,
pass_collection=False,
data=data,
original_data=data,
many=None,
partial=None,
field_errors=field_errors,
unknown=RAISEUNKNOWNOPTION,
)
errors = error_store.errors
if errors:
Expand Down Expand Up @@ -271,7 +263,7 @@ def validate_param(self, param_name, param_spec, raw_data):
Do range validation for a parameter.
"""
validate_schema = not getattr(
self.context["spec"], "_defer_validation", False
self.pt_context["spec"], "_defer_validation", False
)
validators = self.validators(
param_name, param_spec, raw_data, validate_schema=validate_schema
Expand All @@ -290,15 +282,15 @@ def validate_param(self, param_name, param_spec, raw_data):
return warnings, errors

def field_keyfunc(self, param_name):
data = self.context["spec"]._data[param_name]
data = self.pt_context["spec"]._data[param_name]
field = get_type(data, self.validators(param_name))
try:
return field.cmp_funcs()["key"]
except AttributeError:
return None

def field(self, param_name):
data = self.context["spec"]._data[param_name]
data = self.pt_context["spec"]._data[param_name]
return get_type(data, self.validators(param_name))

def validators(
Expand All @@ -309,7 +301,7 @@ def validators(
if raw_data is None:
raw_data = {}

param_info = self.context["spec"]._data[param_name]
param_info = self.pt_context["spec"]._data[param_name]
# sort keys to guarantee order.
validator_spec = param_info.get("validators", {})
validators = []
Expand Down Expand Up @@ -347,7 +339,7 @@ def _get_when_validator(
when_param = when_dict["param"]

if (
when_param not in self.context["spec"]._data.keys()
when_param not in self.pt_context["spec"]._data.keys()
and when_param != "default"
):
raise MarshmallowValidationError(
Expand Down Expand Up @@ -382,8 +374,8 @@ def _get_when_validator(
)
)

_type = self.context["spec"]._data[oth_param]["type"]
number_dims = self.context["spec"]._data[oth_param]["number_dims"]
_type = self.pt_context["spec"]._data[oth_param]["type"]
number_dims = self.pt_context["spec"]._data[oth_param]["number_dims"]

error_then = (
f"When {oth_param}{{when_labels}}{{ix}} is {{is_val}}, "
Expand Down Expand Up @@ -469,9 +461,9 @@ def _get_range_validator(
)

def _sort_by_label_to_extend(self, vos):
label_to_extend = self.context["spec"].label_to_extend
label_to_extend = self.pt_context["spec"].label_to_extend
if label_to_extend is not None:
label_grid = self.context["spec"]._stateless_label_grid
label_grid = self.pt_context["spec"]._stateless_label_grid
extend_vals = label_grid[label_to_extend]
return sorted(
vos,
Expand Down Expand Up @@ -533,9 +525,9 @@ def _get_related_value(
# If comparing against the "default" value then get the current
# value of the parameter being updated.
if oth_param_name == "default":
oth_param = self.context["spec"]._data[param_name]
oth_param = self.pt_context["spec"]._data[param_name]
else:
oth_param = self.context["spec"]._data[oth_param_name]
oth_param = self.pt_context["spec"]._data[oth_param_name]
vals = oth_param["value"]
labs_to_check = {k for k in param_spec if k not in ("value", "_auto")}
if labs_to_check:
Expand All @@ -560,11 +552,11 @@ def _check_ndim_restriction(
if other_param is None:
continue
if other_param == "default":
ndims = self.context["spec"]._data[param_name][
ndims = self.pt_context["spec"]._data[param_name][
"number_dims"
]
else:
ndims = self.context["spec"]._data[other_param][
ndims = self.pt_context["spec"]._data[other_param][
"number_dims"
]
if ndims > 0:
Expand Down
9 changes: 3 additions & 6 deletions paramtools/schema_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from marshmallow import fields
from marshmallow import fields, Schema

from paramtools.schema import (
OrderedSchema,
BaseValidatorSchema,
ValueObject,
get_type,
Expand Down Expand Up @@ -67,17 +66,15 @@ def schemas(self):
# if not isinstance(v["value"], list):
# v["value"] = [{"value": v["value"]}]

validator_dict[k] = type(
"ValidatorItem", (OrderedSchema,), classattrs
)
validator_dict[k] = type("ValidatorItem", (Schema,), classattrs)

classattrs = {"value": ValueObject(validator_dict[k], many=True)}
param_dict[k] = type(
"IndividualParamSchema", (self.BaseParamSchema,), classattrs
)

classattrs = {k: fields.Nested(v) for k, v in param_dict.items()}
DefaultsSchema = type("DefaultsSchema", (OrderedSchema,), classattrs)
DefaultsSchema = type("DefaultsSchema", (Schema,), classattrs)
defaults_schema = DefaultsSchema()

classattrs = {
Expand Down
9 changes: 0 additions & 9 deletions paramtools/tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def get_defaults(self):
assert params.hello_world == "hello world"
assert params.label_grid == {"somelabel": [0, 1, 2, 3, 4, 5]}


def test_schema_not_dropped(self, defaults_spec_path):
with open(defaults_spec_path, "r") as f:
defaults_ = json.loads(f.read())
Expand Down Expand Up @@ -379,14 +378,6 @@ def test_specification(self, TestParams, defaults_spec_path):

assert spec1["min_int_param"] == exp["min_int_param"]["value"]

def test_is_ordered(self, TestParams):
params = TestParams()
spec1 = params.specification()
assert isinstance(spec1, OrderedDict)

spec2 = params.specification(meta_data=True, serializable=True)
assert isinstance(spec2, OrderedDict)

def test_specification_query(self, TestParams):
params = TestParams()
spec1 = params.specification()
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setuptools.setup(
name="paramtools",
version=os.environ.get("VERSION", "0.19.0"),
version=os.environ.get("VERSION", "0.20.0"),
author="Hank Doupe",
author_email="henrymdoupe@gmail.com",
description=(
Expand All @@ -18,7 +18,7 @@
url="https://github.com/hdoupe/ParamTools",
packages=setuptools.find_packages(),
install_requires=[
"marshmallow>=3.0.0",
"marshmallow>=4.0.0",
"numpy",
"python-dateutil>=2.8.0",
"fsspec",
Expand Down
Loading