diff --git a/conda.recipe/meta.yaml b/conda.recipe/meta.yaml index a69167a..88a554f 100755 --- a/conda.recipe/meta.yaml +++ b/conda.recipe/meta.yaml @@ -5,13 +5,13 @@ package: requirements: build: - python - - "marshmallow>=3.0.0rc4" + - "marshmallow>=4.0.0" - "numpy>=1.13" - "python-dateutil>=2.8.0" run: - python - - "marshmallow>=3.0.0rc4" + - "marshmallow>=4.0.0" - "numpy>=1.13" - "python-dateutil>=2.8.0" diff --git a/environment.yml b/environment.yml index 7903a13..5c09981 100644 --- a/environment.yml +++ b/environment.yml @@ -2,7 +2,7 @@ name: paramtools-dev channels: - conda-forge dependencies: - - "marshmallow>=3.22.0" + - "marshmallow>=4.0.0" - "numpy>=2.1.0" - "python-dateutil>=2.8.0" - "pytest>=6.0.0" diff --git a/paramtools/__init__.py b/paramtools/__init__.py index 292fa9f..283acdc 100644 --- a/paramtools/__init__.py +++ b/paramtools/__init__.py @@ -53,7 +53,7 @@ name = "paramtools" -__version__ = "0.19.0" +__version__ = "0.20.0" __all__ = [ "SchemaFactory", diff --git a/paramtools/parameters.py b/paramtools/parameters.py index fe3e908..0830049 100644 --- a/paramtools/parameters.py +++ b/paramtools/parameters.py @@ -101,7 +101,7 @@ def __init__( else: self._stateless_label_grid[name] = [] self.label_grid = copy.deepcopy(self._stateless_label_grid) - self._validator_schema.context["spec"] = self + self._validator_schema.pt_context["spec"] = self self._warnings = {} self._errors = {} self._defer_validation = False @@ -364,7 +364,7 @@ def _adjust( for param, value in parsed_params.items(): self._update_param(param, value) - self._validator_schema.context["spec"] = self + self._validator_schema.pt_context["spec"] = self has_errors = bool(self._errors.get("messages")) has_warnings = bool(self._warnings.get("messages")) @@ -525,7 +525,7 @@ def _delete( if self.label_to_extend is not None and extend_adj: self.extend() - self._validator_schema.context["spec"] = self + self._validator_schema.pt_context["spec"] = self has_errors = bool(self._errors.get("messages")) has_warnings = bool(self._warnings.get("messages")) @@ -1414,4 +1414,4 @@ def get_defaults(self): - `params`: String if URL or file path. Dict if this is the loaded params dict. """ - return utils.read_json(self.defaults) \ No newline at end of file + return utils.read_json(self.defaults) diff --git a/paramtools/schema.py b/paramtools/schema.py index ebbb453..2df76ed 100644 --- a/paramtools/schema.py +++ b/paramtools/schema.py @@ -7,6 +7,7 @@ validates_schema, ValidationError as MarshmallowValidationError, decorators, + RAISE as RAISEUNKNOWNOPTION, ) from marshmallow.error_store import ErrorStore @@ -28,14 +29,14 @@ class RangeSchema(Schema): } """ - _min = fields.Field(attribute="min", data_key="min") - _max = fields.Field(attribute="max", data_key="max") - step = fields.Field() + _min = fields.Raw(attribute="min", data_key="min") + _max = fields.Raw(attribute="max", data_key="max") + step = fields.Raw() level = fields.String(validate=[validate.OneOf(["warn", "error"])]) class ChoiceSchema(Schema): - choices = fields.List(fields.Field) + choices = fields.List(fields.Raw) level = fields.String(validate=[validate.OneOf(["warn", "error"])]) @@ -53,9 +54,9 @@ class ValueValidatorSchema(Schema): class IsSchema(Schema): - equal_to = fields.Field(required=False) - greater_than = fields.Field(required=False) - less_than = fields.Field(required=False) + equal_to = fields.Raw(required=False) + greater_than = fields.Raw(required=False) + less_than = fields.Raw(required=False) @validates_schema def just_one(self, data, **kwargs): @@ -107,15 +108,12 @@ class BaseParamSchema(Schema): data_key="type", ) number_dims = fields.Integer(required=False, load_default=0) - value = fields.Field(required=True) # will be specified later + value = fields.Raw(required=True) # will be specified later validators = fields.Nested( ValueValidatorSchema(), required=False, load_default={} ) indexed = fields.Boolean(required=False) - class Meta: - ordered = True - class EmptySchema(Schema): """ @@ -126,15 +124,6 @@ class EmptySchema(Schema): pass -class OrderedSchema(Schema): - """ - Same as `EmptySchema`, but preserves the order of its fields. - """ - - class Meta: - ordered = True - - class ValueObject(fields.Nested): """ Schema for value objects @@ -182,9 +171,6 @@ class BaseValidatorSchema(Schema): class. """ - class Meta: - ordered = True - WRAPPER_MAP = { "range": "_get_range_validator", "date_range": "_get_range_validator", @@ -192,6 +178,10 @@ class Meta: "when": "_get_when_validator", } + def __init__(self, *args, **kwargs): + self.pt_context = {} + super().__init__(*args, **kwargs) + def validate_only(self, data): """ Bypass deserialization and just run field validators. This is taken @@ -208,21 +198,23 @@ def validate_only(self, data): field_errors = bool(error_store.errors) self._invoke_schema_validators( error_store=error_store, - pass_many=True, + pass_collection=True, data=data, original_data=data, many=None, partial=None, field_errors=field_errors, + unknown=RAISEUNKNOWNOPTION, ) self._invoke_schema_validators( error_store=error_store, - pass_many=False, + pass_collection=False, data=data, original_data=data, many=None, partial=None, field_errors=field_errors, + unknown=RAISEUNKNOWNOPTION, ) errors = error_store.errors if errors: @@ -271,7 +263,7 @@ def validate_param(self, param_name, param_spec, raw_data): Do range validation for a parameter. """ validate_schema = not getattr( - self.context["spec"], "_defer_validation", False + self.pt_context["spec"], "_defer_validation", False ) validators = self.validators( param_name, param_spec, raw_data, validate_schema=validate_schema @@ -290,7 +282,7 @@ def validate_param(self, param_name, param_spec, raw_data): return warnings, errors def field_keyfunc(self, param_name): - data = self.context["spec"]._data[param_name] + data = self.pt_context["spec"]._data[param_name] field = get_type(data, self.validators(param_name)) try: return field.cmp_funcs()["key"] @@ -298,7 +290,7 @@ def field_keyfunc(self, param_name): return None def field(self, param_name): - data = self.context["spec"]._data[param_name] + data = self.pt_context["spec"]._data[param_name] return get_type(data, self.validators(param_name)) def validators( @@ -309,7 +301,7 @@ def validators( if raw_data is None: raw_data = {} - param_info = self.context["spec"]._data[param_name] + param_info = self.pt_context["spec"]._data[param_name] # sort keys to guarantee order. validator_spec = param_info.get("validators", {}) validators = [] @@ -347,7 +339,7 @@ def _get_when_validator( when_param = when_dict["param"] if ( - when_param not in self.context["spec"]._data.keys() + when_param not in self.pt_context["spec"]._data.keys() and when_param != "default" ): raise MarshmallowValidationError( @@ -382,8 +374,8 @@ def _get_when_validator( ) ) - _type = self.context["spec"]._data[oth_param]["type"] - number_dims = self.context["spec"]._data[oth_param]["number_dims"] + _type = self.pt_context["spec"]._data[oth_param]["type"] + number_dims = self.pt_context["spec"]._data[oth_param]["number_dims"] error_then = ( f"When {oth_param}{{when_labels}}{{ix}} is {{is_val}}, " @@ -469,9 +461,9 @@ def _get_range_validator( ) def _sort_by_label_to_extend(self, vos): - label_to_extend = self.context["spec"].label_to_extend + label_to_extend = self.pt_context["spec"].label_to_extend if label_to_extend is not None: - label_grid = self.context["spec"]._stateless_label_grid + label_grid = self.pt_context["spec"]._stateless_label_grid extend_vals = label_grid[label_to_extend] return sorted( vos, @@ -533,9 +525,9 @@ def _get_related_value( # If comparing against the "default" value then get the current # value of the parameter being updated. if oth_param_name == "default": - oth_param = self.context["spec"]._data[param_name] + oth_param = self.pt_context["spec"]._data[param_name] else: - oth_param = self.context["spec"]._data[oth_param_name] + oth_param = self.pt_context["spec"]._data[oth_param_name] vals = oth_param["value"] labs_to_check = {k for k in param_spec if k not in ("value", "_auto")} if labs_to_check: @@ -560,11 +552,11 @@ def _check_ndim_restriction( if other_param is None: continue if other_param == "default": - ndims = self.context["spec"]._data[param_name][ + ndims = self.pt_context["spec"]._data[param_name][ "number_dims" ] else: - ndims = self.context["spec"]._data[other_param][ + ndims = self.pt_context["spec"]._data[other_param][ "number_dims" ] if ndims > 0: diff --git a/paramtools/schema_factory.py b/paramtools/schema_factory.py index b841169..dc61dbb 100644 --- a/paramtools/schema_factory.py +++ b/paramtools/schema_factory.py @@ -1,7 +1,6 @@ -from marshmallow import fields +from marshmallow import fields, Schema from paramtools.schema import ( - OrderedSchema, BaseValidatorSchema, ValueObject, get_type, @@ -67,9 +66,7 @@ def schemas(self): # if not isinstance(v["value"], list): # v["value"] = [{"value": v["value"]}] - validator_dict[k] = type( - "ValidatorItem", (OrderedSchema,), classattrs - ) + validator_dict[k] = type("ValidatorItem", (Schema,), classattrs) classattrs = {"value": ValueObject(validator_dict[k], many=True)} param_dict[k] = type( @@ -77,7 +74,7 @@ def schemas(self): ) classattrs = {k: fields.Nested(v) for k, v in param_dict.items()} - DefaultsSchema = type("DefaultsSchema", (OrderedSchema,), classattrs) + DefaultsSchema = type("DefaultsSchema", (Schema,), classattrs) defaults_schema = DefaultsSchema() classattrs = { diff --git a/paramtools/tests/test_parameters.py b/paramtools/tests/test_parameters.py index 1c5a6ec..d84ef04 100644 --- a/paramtools/tests/test_parameters.py +++ b/paramtools/tests/test_parameters.py @@ -166,7 +166,6 @@ def get_defaults(self): assert params.hello_world == "hello world" assert params.label_grid == {"somelabel": [0, 1, 2, 3, 4, 5]} - def test_schema_not_dropped(self, defaults_spec_path): with open(defaults_spec_path, "r") as f: defaults_ = json.loads(f.read()) @@ -379,14 +378,6 @@ def test_specification(self, TestParams, defaults_spec_path): assert spec1["min_int_param"] == exp["min_int_param"]["value"] - def test_is_ordered(self, TestParams): - params = TestParams() - spec1 = params.specification() - assert isinstance(spec1, OrderedDict) - - spec2 = params.specification(meta_data=True, serializable=True) - assert isinstance(spec2, OrderedDict) - def test_specification_query(self, TestParams): params = TestParams() spec1 = params.specification() diff --git a/setup.py b/setup.py index 4cbb390..e2b0abb 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setuptools.setup( name="paramtools", - version=os.environ.get("VERSION", "0.19.0"), + version=os.environ.get("VERSION", "0.20.0"), author="Hank Doupe", author_email="henrymdoupe@gmail.com", description=( @@ -18,7 +18,7 @@ url="https://github.com/hdoupe/ParamTools", packages=setuptools.find_packages(), install_requires=[ - "marshmallow>=3.0.0", + "marshmallow>=4.0.0", "numpy", "python-dateutil>=2.8.0", "fsspec",