diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d53e6e742..892049a6b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- `coefficients` attribute for `DiscreteSumConstraint`, enabling weighted sums. Follows + the same pattern as `ContinuousLinearConstraint.coefficients` +- `simplex_coefficients` keyword argument to `SubspaceDiscrete.from_simplex` for + weighted simplex sum constraints - Support for Python 3.14 - Support for pandas 3 - `Settings` class for unified and streamlined settings management @@ -40,6 +44,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Breaking Changes - `parameter_cartesian_prod_pandas` and `parameter_cartesian_prod_polars` moved from `baybe.searchspace.discrete` to `baybe.searchspace.utils` +- All optional arguments of `SubspaceDiscrete.from_simplex` after `simplex_parameters` + are now keyword-only - `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples instead of a single tuple (needed for interpoint constraints) - `Kernel.to_gpytorch` now takes a `SearchSpace` instead of explicit `ard_num_dims`, @@ -61,6 +67,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 equality check ### Changed +- `DiscreteSumConstraint`, `ContinuousLinearConstraint`, and + `SubspaceDiscrete.from_simplex` now forbid 0 as coefficients - "User Guide" section has been split into "Components" and "Concepts" - Default transfer learning kernel changed from `IndexKernel` to `PositiveIndexKernel`, enforcing positive task correlations diff --git a/baybe/constraints/continuous.py b/baybe/constraints/continuous.py index a16210cca1..6f6fe7f1e2 100644 --- a/baybe/constraints/continuous.py +++ b/baybe/constraints/continuous.py @@ -81,6 +81,8 @@ def _validate_coefficients( # noqa: DOC101, DOC103 "The given 'coefficients' list must have one floating point entry for " "each entry in 'parameters'." ) + if any(c == 0.0 for c in coefficients): + raise ValueError("All entries in 'coefficients' must be non-zero.") @coefficients.default def _default_coefficients(self) -> tuple[float, ...]: diff --git a/baybe/constraints/discrete.py b/baybe/constraints/discrete.py index b9635faffa..d43e2f3fbe 100644 --- a/baybe/constraints/discrete.py +++ b/baybe/constraints/discrete.py @@ -3,15 +3,16 @@ from __future__ import annotations import gc -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import reduce from typing import TYPE_CHECKING, Any, ClassVar, cast +import cattrs import numpy as np import numpy.typing as npt import pandas as pd from attrs import define, field -from attrs.validators import in_, min_len +from attrs.validators import deep_iterable, in_, min_len from typing_extensions import override from baybe.constraints.base import CardinalityConstraint, DiscreteConstraint @@ -26,6 +27,7 @@ block_serialization_hook, converter, ) +from baybe.utils.validation import finite_float if TYPE_CHECKING: import polars as pl @@ -77,7 +79,11 @@ def get_invalid_polars(self) -> pl.Expr: @define class DiscreteSumConstraint(DiscreteConstraint): - """Class for modelling sum constraints.""" + """Class for modelling sum constraints. + + The constraint evaluates whether the (optionally weighted) sum of the specified + parameters satisfies the given threshold condition. + """ # IMPROVE: refactor `SumConstraint` and `ProdConstraint` to avoid code copying @@ -94,9 +100,45 @@ class DiscreteSumConstraint(DiscreteConstraint): condition: ThresholdCondition = field() """The condition modeled by this constraint.""" + coefficients: tuple[float, ...] = field( + converter=lambda x: cattrs.structure(x, tuple[float, ...]), + validator=deep_iterable(member_validator=finite_float), + ) + """The coefficients for the weighted sum, one per entry in ``parameters``. + + Defaults to all-ones, i.e. an unweighted sum.""" + + @coefficients.default + def _default_coefficients(self) -> tuple[float, ...]: + """Return equal weight coefficients as default.""" + return (1.0,) * len(self.parameters) + + @coefficients.validator + def _validate_coefficients( # noqa: DOC101, DOC103 + self, _: Any, coefficients: Sequence[float] + ) -> None: + """Validate the coefficients. + + Raises: + ValueError: If the number of coefficients does not match the number of + parameters. + """ + if len(self.parameters) != len(coefficients): + raise ValueError( + "The given 'coefficients' list must have one floating point entry for " + "each entry in 'parameters'." + ) + if any(c == 0.0 for c in coefficients): + raise ValueError("All entries in 'coefficients' must be non-zero.") + @override def _get_invalid(self, df: pd.DataFrame, /) -> pd.Index: - evaluate_df = df[self.parameters].sum(axis=1) + evaluate_df = pd.Series( + sum( + df[p].to_numpy() * c for p, c in zip(self.parameters, self.coefficients) + ), + index=df.index, + ) mask_bad = ~self.condition.evaluate(evaluate_df) return df.index[mask_bad] @@ -105,7 +147,8 @@ def _get_invalid(self, df: pd.DataFrame, /) -> pd.Index: def get_invalid_polars(self) -> pl.Expr: from baybe._optional.polars import polars as pl - return self.condition.to_polars(pl.sum_horizontal(self.parameters)).not_() + weighted = [pl.col(p) * c for p, c in zip(self.parameters, self.coefficients)] + return self.condition.to_polars(pl.sum_horizontal(weighted)).not_() @define diff --git a/baybe/searchspace/discrete.py b/baybe/searchspace/discrete.py index b8c3e1c0ae..5dcc700671 100644 --- a/baybe/searchspace/discrete.py +++ b/baybe/searchspace/discrete.py @@ -269,6 +269,8 @@ def from_simplex( cls, max_sum: float, simplex_parameters: Sequence[NumericalDiscreteParameter], + *, + simplex_coefficients: Sequence[float] | None = None, product_parameters: Sequence[DiscreteParameter] | None = None, constraints: Sequence[DiscreteConstraint] | None = None, min_nonzero: int = 0, @@ -290,8 +292,13 @@ def from_simplex( significantly faster construction. Args: - max_sum: The maximum sum of the parameter values defining the simplex size. + max_sum: The maximum (weighted) sum of the parameter values defining the + simplex size. simplex_parameters: The parameters to be used for the simplex construction. + Their values are required to be non-negative. + simplex_coefficients: Optional coefficients for the weighted sum, one per + entry in ``simplex_parameters``. Defaults to all-ones, i.e. an + unweighted sum. product_parameters: Optional parameters that enter in form of a Cartesian product. constraints: See :class:`baybe.searchspace.core.SearchSpace`. @@ -306,6 +313,8 @@ def from_simplex( Raises: ValueError: If the passed simplex parameters are not suitable for a simplex construction. + ValueError: If the length of ``simplex_coefficients`` does not match the + number of ``simplex_parameters``. ValueError: If the passed product parameters are not discrete. ValueError: If the passed simplex parameters and product parameters are not disjoint. @@ -325,6 +334,8 @@ def from_simplex( constraints = [] if max_nonzero is None: max_nonzero = len(simplex_parameters) + if simplex_coefficients is None: + simplex_coefficients = [1.0] * len(simplex_parameters) # Validate constraints validate_constraints(constraints, [*simplex_parameters, *product_parameters]) @@ -343,6 +354,18 @@ def from_simplex( f"must be of subclasses of '{DiscreteParameter.__name__}'." ) + # Validate coefficients length + if len(simplex_coefficients) != len(simplex_parameters): + raise ValueError( + f"'simplex_coefficients' must have one entry per 'simplex_parameters' " + f"entry, but got {len(simplex_coefficients)} coefficient(s) for " + f"{len(simplex_parameters)} parameter(s)." + ) + + # Validate no zero coefficients + if any(c == 0.0 for c in simplex_coefficients): + raise ValueError("All entries in 'simplex_coefficients' must be non-zero.") + # Validate no overlap between simplex parameters and product parameters simplex_parameters_names = {p.name for p in simplex_parameters} product_parameters_names = {p.name for p in product_parameters} @@ -364,79 +387,54 @@ def from_simplex( if len(simplex_parameters) < 1: return cls.from_product(product_parameters, constraints) - # Validate non-negativity - min_values = [min(p.values) for p in simplex_parameters] - max_values = [max(p.values) for p in simplex_parameters] - if not (min(min_values) >= 0.0): + # Validate non-negativity of raw parameter values (required by the algorithm) + min_raw = [min(p.values) for p in simplex_parameters] + max_raw = [max(p.values) for p in simplex_parameters] + if any(v < 0.0 for v in min_raw): raise ValueError( f"All simplex_parameters passed to '{cls.from_simplex.__name__}' " f"must have non-negative values only." ) - def drop_invalid( - df: pd.DataFrame, - max_sum: float, - boundary_only: bool, - min_nonzero: int | None = None, - max_nonzero: int | None = None, - ) -> None: - """Drop rows that violate the specified simplex constraint. - - Args: - df: The dataframe whose rows should satisfy the simplex constraint. - max_sum: The maximum row sum defining the simplex size. - boundary_only: Flag to control if the points represented by the rows - may lie inside the simplex or on its boundary only. - min_nonzero: Minimum number of nonzero parameters required per row. - max_nonzero: Maximum number of nonzero parameters allowed per row. - """ - # Apply sum constraints - row_sums = df.sum(axis=1) - mask_violated = row_sums > max_sum + tolerance - if boundary_only: - mask_violated |= row_sums < max_sum - tolerance - - # Apply optional nonzero constraints - if (min_nonzero is not None) or (max_nonzero is not None): - n_nonzero = (df != 0.0).sum(axis=1) - if min_nonzero is not None: - mask_violated |= n_nonzero < min_nonzero - if max_nonzero is not None: - mask_violated |= n_nonzero > max_nonzero - - # Remove violating rows - idxs_to_drop = df[mask_violated].index - df.drop(index=idxs_to_drop, inplace=True) - - # Get the minimum sum contributions to come in the upcoming joins (the - # first item is the minimum possible sum of all parameters starting from the - # second parameter, the second item is the minimum possible sum starting from - # the third parameter, and so on ...) - min_sum_upcoming = np.cumsum(min_values[:0:-1])[::-1] - - # Get the min/max number of nonzero values to come in the upcoming joins (the - # first item is the min/max number of nonzero parameters starting from the - # second parameter, the second item is the min/max number starting from - # the third parameter, and so on ...) - min_nonzero_upcoming = np.cumsum((np.asarray(min_values) > 0.0)[:0:-1])[::-1] - max_nonzero_upcoming = np.cumsum((np.asarray(max_values) > 0.0)[:0:-1])[::-1] - - # Incrementally build up the space, dropping invalid configuration along the - # way. More specifically: - # * After having cross-joined a new parameter, there must - # be enough "room" left for the remaining parameters to fit. That is, - # configurations of the current parameter subset that exceed the desired - # total value minus the minimum contribution to come from the yet-to-be-added - # parameters can be already discarded, because it is already clear that - # the total sum will be exceeded once all joins are completed. - # * Analogously, there must be enough "nonzero slots" left for the yet to be - # joined parameters, i.e. parameter subset configurations can be discarded - # where the number of nonzero parameters already exceeds the maximum number - # of nonzeros minus the number of nonzeros to come, because it is already - # clear that the maximum will be exceeded once all joins are completed. - # * Similarly, it can be verified for each parameter that there are still - # enough nonzero parameters to come to even reach the minimum - # desired number of nonzero after all joins. + # Compute per-parameter minimum weighted contributions. + # For a positive coefficient c the minimum contribution is c*min_raw; for a + # negative coefficient the ordering flips and it becomes c*max_raw. Taking + # min of both products handles any real coefficient correctly. + coeffs = np.asarray(simplex_coefficients, dtype=float) + if not np.isfinite(coeffs).all(): + raise ValueError( + f"All simplex_coefficients passed to '{cls.from_simplex.__name__}' " + f"must be finite numbers." + ) + min_weighted = np.array( + [min(c * lo, c * hi) for c, lo, hi in zip(coeffs, min_raw, max_raw)] + ) + + # Get the minimum weighted sum contributions to come in the upcoming joins (the + # first item is the minimum possible weighted sum of all parameters starting + # from the second parameter, the second item is the minimum possible weighted + # sum starting from the third parameter, and so on ...) + min_sum_upcoming = np.cumsum(min_weighted[:0:-1])[::-1] + + # Get the min/max number of nonzero values to come in the upcoming joins. + # Nonzero counting is based on raw parameter values, not weighted values, + # because the cardinality constraint counts zero/nonzero entries regardless + # of the coefficient signs. + min_nonzero_upcoming = np.cumsum((np.asarray(min_raw) > 0.0)[:0:-1])[::-1] + max_nonzero_upcoming = np.cumsum((np.asarray(max_raw) > 0.0)[:0:-1])[::-1] + + # Incrementally build up the space as a numpy array, dropping invalid + # configurations along the way. Working with raw numpy avoids pandas overhead + # (index management, BlockManager, merge machinery) in the hot loop. + # + # After having cross-joined a new parameter, there must be enough "room" left + # for the remaining parameters to fit. That is, configurations of the current + # parameter subset that exceed the desired total value minus the minimum + # contribution to come from the yet-to-be-added parameters can be already + # discarded, because it is already clear that the total sum will be exceeded + # once all joins are completed. Analogously, nonzero cardinality bounds are + # checked at each step. + arr: np.ndarray for i, ( param, min_sum_to_go, @@ -450,27 +448,44 @@ def drop_invalid( np.append(max_nonzero_upcoming, 0), ) ): + values = np.asarray(param.values, dtype=float) + if i == 0: - exp_rep = pd.DataFrame({param.name: param.values}) + arr = values.reshape(-1, 1) else: - exp_rep = pd.merge( - exp_rep, pd.DataFrame({param.name: param.values}), how="cross" + n_old = arr.shape[0] + n_new = len(values) + arr = np.column_stack( + [ + np.repeat(arr, n_new, axis=0), + np.tile(values, n_old), + ] ) - drop_invalid( - exp_rep, - max_sum=max_sum - min_sum_to_go, - # the maximum possible number of nonzeros to come dictates if we - # can achieve our minimum constraint in the end: - min_nonzero=min_nonzero - max_nonzero_to_go, - # the minimum possible number of nonzeros to come dictates if we - # can stay below the targeted maximum in the end: - max_nonzero=max_nonzero - min_nonzero_to_go, - boundary_only=False, - ) + + # Compute weighted row sums and build validity mask + row_sums = arr @ coeffs[: i + 1] + mask = row_sums <= (max_sum - min_sum_to_go) + tolerance + + # Apply nonzero cardinality bounds + effective_min = min_nonzero - max_nonzero_to_go + effective_max = max_nonzero - min_nonzero_to_go + if effective_min > 0 or effective_max < len(simplex_parameters): + n_nz = np.count_nonzero(arr, axis=1) + if effective_min > 0: + mask &= n_nz >= effective_min + if effective_max < len(simplex_parameters): + mask &= n_nz <= effective_max + + arr = arr[mask] # If requested, keep only the boundary values if boundary_only: - drop_invalid(exp_rep, max_sum, boundary_only=True) + row_sums = arr @ coeffs + mask = np.abs(row_sums - max_sum) <= tolerance + arr = arr[mask] + + # Wrap in DataFrame + exp_rep = pd.DataFrame(arr, columns=[p.name for p in simplex_parameters]) # Merge product parameters and apply constraints incrementally exp_rep = build_constrained_product( @@ -772,6 +787,30 @@ def validate_simplex_subspace_from_config(specs: dict, _) -> None: f"values only." ) + simplex_coefficients = specs.get("simplex_coefficients", None) + if simplex_coefficients is not None: + try: + simplex_coefficients = converter.structure( + simplex_coefficients, list[float] + ) + except (IterableValidationError, TypeError, ValueError) as exc: + raise ValueError( + "'simplex_coefficients' must be a list of numeric values." + ) from exc + + if len(simplex_coefficients) != len(simplex_parameters): + raise ValueError( + f"'simplex_coefficients' must have one entry per " + f"'simplex_parameters' entry, but got " + f"{len(simplex_coefficients)} coefficient(s) for " + f"{len(simplex_parameters)} parameter(s)." + ) + + if any(c == 0.0 for c in simplex_coefficients): + raise ValueError( + "All entries in 'simplex_coefficients' must be non-zero." + ) + product_parameters = specs.get("product_parameters", []) if product_parameters: product_parameters = converter.structure( diff --git a/tests/constraints/test_constraints_discrete.py b/tests/constraints/test_constraints_discrete.py index 9273ae13bf..fb850aae91 100644 --- a/tests/constraints/test_constraints_discrete.py +++ b/tests/constraints/test_constraints_discrete.py @@ -1,8 +1,14 @@ """Test for imposing discrete constraints.""" +import itertools import math +import pandas as pd import pytest +from pytest import param + +from baybe.constraints.conditions import ThresholdCondition +from baybe.constraints.discrete import DiscreteSumConstraint @pytest.fixture( @@ -275,3 +281,31 @@ def test_cardinality(campaign): min_cardinality = 1 max_cardinality = 2 assert non_zeros.between(min_cardinality, max_cardinality).all() + + +@pytest.mark.parametrize( + ("coefficients", "threshold", "operator", "n_invalid"), + [ + param(None, 1.0, "<=", 3, id="default"), + param((1.0, 1.0), 1.0, "<=", 3, id="all-ones"), + param((2.0, 1.0), 1.0, "<=", 5, id="scaled"), + param((1.0, -1.0), 0.5, "<=", 1, id="negative"), + param((1.0, 1.0), 1.0, "=", 6, id="equality"), + ], +) +def test_sum_constraint_coefficients(coefficients, threshold, operator, n_invalid): + """DiscreteSumConstraint filters correctly with default and custom coefficients.""" + kwargs = {} if coefficients is None else {"coefficients": coefficients} + constraint = DiscreteSumConstraint( + parameters=["A", "B"], + condition=ThresholdCondition(threshold=threshold, operator=operator), + **kwargs, + ) + df = pd.DataFrame( + list(itertools.product([0.0, 0.5, 1.0], repeat=2)), columns=["A", "B"] + ) + coeffs = coefficients or (1.0, 1.0) + weighted = df["A"] * coeffs[0] + df["B"] * coeffs[1] + expected = df.index[~ThresholdCondition(threshold, operator).evaluate(weighted)] + assert list(constraint.get_invalid(df)) == list(expected) + assert len(constraint.get_invalid(df)) == n_invalid diff --git a/tests/constraints/test_constraints_polars.py b/tests/constraints/test_constraints_polars.py index adbb1c5b2c..f5d772ae5c 100644 --- a/tests/constraints/test_constraints_polars.py +++ b/tests/constraints/test_constraints_polars.py @@ -2,6 +2,7 @@ import pytest from pandas.testing import assert_frame_equal +from pytest import param from baybe._optional.info import POLARS_INSTALLED from baybe.constraints import ( @@ -51,25 +52,10 @@ def _lazyframe_from_product(parameters): return res -@pytest.mark.parametrize("parameter_names", [["Fraction_1", "Fraction_2"]]) -@pytest.mark.parametrize("constraint_names", [["Constraint_8"]]) -def test_polars_prodsum1(parameters, constraints): - """Tests Polars implementation of sum constraint.""" - ldf = _lazyframe_from_product(parameters) - ldf = _apply_constraint_filter_polars(ldf, constraints) - - # Number of entries with 1,2-sum above 150 - ldf = ldf.with_columns(sum=pl.sum_horizontal(["Fraction_1", "Fraction_2"])) - ldf = ldf.filter(pl.col("sum") > 150) - num_entries = len(ldf.collect()) - - assert num_entries == 0 - - @pytest.mark.parametrize("parameter_names", [["Fraction_1", "Fraction_2"]]) @pytest.mark.parametrize("constraint_names", [["Constraint_9"]]) -def test_polars_prodsum2(parameters, constraints): - """Tests Polars implementation of product constrain.""" +def test_polars_product_constraint(parameters, constraints): + """Tests Polars implementation of product constraint.""" ldf = _lazyframe_from_product(parameters) ldf = _apply_constraint_filter_polars(ldf, constraints) @@ -85,20 +71,44 @@ def test_polars_prodsum2(parameters, constraints): assert num_entries == 0 +@pytest.mark.parametrize( + ("coefficients", "threshold", "operator"), + [ + param(None, 150.0, "<=", id="unweighted-le"), + param(None, 100.0, "=", id="unweighted-eq"), + param((2.0, 1.0), 150.0, "<=", id="weighted-le"), + param((1.0, -1.0), 50.0, "<=", id="negative-le"), + param((0.5, 0.5), 50.0, "=", id="weighted-eq"), + ], +) @pytest.mark.parametrize("parameter_names", [["Fraction_1", "Fraction_2"]]) -@pytest.mark.parametrize("constraint_names", [["Constraint_10"]]) -def test_polars_prodsum3(parameters, constraints): - """Tests Polars implementation of exact sum constraint.""" +def test_polars_sum_constraint(parameters, coefficients, threshold, operator): + """Polars and Pandas paths produce correct and identical results.""" + names = [p.name for p in parameters] + kwargs = {} if coefficients is None else {"coefficients": coefficients} + condition = ThresholdCondition(threshold=threshold, operator=operator) + constraint = DiscreteSumConstraint(parameters=names, condition=condition, **kwargs) + coeffs = coefficients or (1.0,) * len(parameters) + ldf = _lazyframe_from_product(parameters) - ldf = _apply_constraint_filter_polars(ldf, constraints) + df_pd = parameter_cartesian_prod_pandas(parameters) - # Number of entries with sum unequal to 100 - ldf = ldf.with_columns(sum=pl.sum_horizontal(["Fraction_1", "Fraction_2"])) - df = ldf.select(abs(pl.col("sum") - 100)).filter(pl.col("sum") > 0.01).collect() + _apply_constraint_filter_pandas(df_pd, [constraint]) + df_pl = _apply_constraint_filter_polars(ldf, [constraint]).collect().to_pandas() - num_entries = len(df) + # Correctness: all remaining rows satisfy the constraint + weighted_pd = sum(df_pd[n] * c for n, c in zip(names, coeffs)) + assert condition.evaluate(weighted_pd).all() - assert num_entries == 0 + weighted_pl = sum(df_pl[n] * c for n, c in zip(names, coeffs)) + assert condition.evaluate(weighted_pl).all() + + # Consistency: both paths agree + cols = df_pd.columns.tolist() + assert_frame_equal( + df_pd.sort_values(cols).reset_index(drop=True), + df_pl.sort_values(cols).reset_index(drop=True), + ) @pytest.mark.parametrize( diff --git a/tests/hypothesis_strategies/alternative_creation/test_searchspace.py b/tests/hypothesis_strategies/alternative_creation/test_searchspace.py index 662e898134..1dbc67f892 100644 --- a/tests/hypothesis_strategies/alternative_creation/test_searchspace.py +++ b/tests/hypothesis_strategies/alternative_creation/test_searchspace.py @@ -1,5 +1,7 @@ """Test alternative ways of creation not considered in the strategies.""" +import itertools + import hypothesis.strategies as st import numpy as np import pandas as pd @@ -8,6 +10,8 @@ from pandas.testing import assert_frame_equal from pytest import param +from baybe.constraints.conditions import ThresholdCondition +from baybe.constraints.discrete import DiscreteSumConstraint from baybe.parameters import ( CategoricalParameter, NumericalContinuousParameter, @@ -196,3 +200,98 @@ def test_discrete_space_creation_from_simplex_restricted(boundary_only): assert n_nonzero.max() == 4 assert len(subspace.parameters) == len(subspace.exp_rep.columns) assert all(p.name in subspace.exp_rep.columns for p in subspace.parameters) + + +_simplex_params = [ + NumericalDiscreteParameter(name="A", values=[0.0, 0.5, 1.0]), + NumericalDiscreteParameter(name="B", values=[0.0, 0.5, 1.0]), + NumericalDiscreteParameter(name="C", values=[0.0, 0.5, 1.0]), +] + + +def _brute_force_weighted_simplex( + params, max_sum, coefficients, *, boundary_only=False, tol=1e-9 +): + """Return all combinations satisfying the weighted simplex constraint.""" + df = pd.DataFrame( + list(itertools.product(*[p.values for p in params])), + columns=[p.name for p in params], + ) + weighted = sum(df[p.name] * c for p, c in zip(params, coefficients)) + mask = weighted <= max_sum + tol + if boundary_only: + mask &= weighted >= max_sum - tol + return df[mask].reset_index(drop=True) + + +@pytest.mark.parametrize( + ("coefficients", "max_sum", "boundary_only"), + [ + param(None, 1.0, False, id="default"), + param([1.0, 1.0, 1.0], 1.0, False, id="explicit-ones"), + param([2.0, 1.0, 0.5], 1.5, False, id="positive"), + param([2.0, 1.0, 0.5], 1.5, True, id="positive-boundary"), + param([1.0, -0.5, 2.0], 1.0, False, id="mixed-sign"), + ], +) +def test_discrete_space_creation_from_simplex_coefficients( + coefficients, max_sum, boundary_only +): + """Simplex subspace with coefficients matches brute-force and from_product.""" + coeffs = coefficients or [1.0, 1.0, 1.0] + cols = [p.name for p in _simplex_params] + + # Ground truth via brute force + expected = _brute_force_weighted_simplex( + _simplex_params, max_sum, coeffs, boundary_only=boundary_only + ) + expected = expected.sort_values(cols).reset_index(drop=True) + + # from_simplex + result_simplex = ( + SubspaceDiscrete.from_simplex( + max_sum, + _simplex_params, + simplex_coefficients=coefficients, + boundary_only=boundary_only, + ) + .exp_rep.sort_values(cols) + .reset_index(drop=True) + ) + assert_frame_equal(result_simplex, expected, check_dtype=False) + + # from_product with equivalent constraint + operator = "=" if boundary_only else "<=" + constraint = DiscreteSumConstraint( + parameters=cols, + condition=ThresholdCondition(threshold=max_sum, operator=operator), + coefficients=tuple(coeffs), + ) + result_product = ( + SubspaceDiscrete.from_product(_simplex_params, constraints=[constraint]) + .exp_rep.sort_values(cols) + .reset_index(drop=True) + ) + assert_frame_equal(result_product, expected, check_dtype=False) + + +@pytest.mark.parametrize( + ("simplex_coefficients", "match"), + [ + param( + [1.0], "'simplex_coefficients' must have one entry", id="length-mismatch" + ), + param([1.0, 0.0], "'simplex_coefficients' must be non-zero", id="zero-coeff"), + ], +) +def test_from_simplex_invalid_coefficients(simplex_coefficients, match): + """Invalid simplex_coefficients raise a ValueError.""" + with pytest.raises(ValueError, match=match): + SubspaceDiscrete.from_simplex( + 1.0, + [ + NumericalDiscreteParameter(name="x", values=[0.0, 0.5, 1.0]), + NumericalDiscreteParameter(name="y", values=[0.0, 0.5, 1.0]), + ], + simplex_coefficients=simplex_coefficients, + ) diff --git a/tests/hypothesis_strategies/constraints.py b/tests/hypothesis_strategies/constraints.py index e1f1014833..822c5fe2c3 100644 --- a/tests/hypothesis_strategies/constraints.py +++ b/tests/hypothesis_strategies/constraints.py @@ -27,6 +27,9 @@ from baybe.parameters.numerical import NumericalDiscreteParameter from tests.hypothesis_strategies.basic import finite_floats +_nonzero_finite_floats = finite_floats().filter(lambda x: x != 0.0) +"""A strategy producing non-zero finite floats.""" + def sub_selection_conditions(superset: list[Any] | None = None): """Generate :class:`baybe.constraints.conditions.SubSelectionCondition`.""" @@ -174,7 +177,9 @@ def discrete_permutation_invariance_constraints( return DiscretePermutationInvarianceConstraint(parameter_names, dependencies) +@st.composite def _discrete_constraints( + draw: st.DrawFn, constraint_type: ( type[DiscreteSumConstraint] | type[DiscreteProductConstraint] @@ -185,16 +190,22 @@ def _discrete_constraints( ): """Generate discrete constraints.""" if parameter_names is None: - parameters = st.lists(st.text(), unique=True, min_size=1) + params = draw(st.lists(st.text(), unique=True, min_size=1)) else: assert len(parameter_names) > 0 assert len(parameter_names) == len(set(parameter_names)) - parameters = st.just(parameter_names) - - if constraint_type in [DiscreteSumConstraint, DiscreteProductConstraint]: - return st.builds(constraint_type, parameters, threshold_conditions()) + params = parameter_names + + if constraint_type is DiscreteSumConstraint: + condition = draw(threshold_conditions()) + if draw(st.booleans()): + coefficients = draw(st.tuples(*([_nonzero_finite_floats] * len(params)))) + return DiscreteSumConstraint(params, condition, coefficients) + return DiscreteSumConstraint(params, condition) + elif constraint_type is DiscreteProductConstraint: + return DiscreteProductConstraint(params, draw(threshold_conditions())) else: - return st.builds(constraint_type, parameters) + return constraint_type(params) discrete_sum_constraints = partial(_discrete_constraints, DiscreteSumConstraint) @@ -227,7 +238,7 @@ def continuous_linear_constraints( assert len(parameter_names) > 0 assert len(parameter_names) == len(set(parameter_names)) - coefficients = draw(st.tuples(*([finite_floats()] * len(parameter_names)))) + coefficients = draw(st.tuples(*([_nonzero_finite_floats] * len(parameter_names)))) rhs = draw(finite_floats()) is_interpoint = draw(st.booleans()) diff --git a/tests/validation/test_constraint_validation.py b/tests/validation/test_constraint_validation.py index 2bee6bdd8f..db0aec3876 100644 --- a/tests/validation/test_constraint_validation.py +++ b/tests/validation/test_constraint_validation.py @@ -3,7 +3,12 @@ import pytest from pytest import param -from baybe.constraints.continuous import ContinuousCardinalityConstraint +from baybe.constraints.conditions import ThresholdCondition +from baybe.constraints.continuous import ( + ContinuousCardinalityConstraint, + ContinuousLinearConstraint, +) +from baybe.constraints.discrete import DiscreteSumConstraint @pytest.mark.parametrize( @@ -21,3 +26,26 @@ def test_invalid_cardinalities(cardinalities, error, match): """Providing an invalid parameter name raises an exception.""" with pytest.raises(error, match=match): ContinuousCardinalityConstraint(["x", "y"], *cardinalities) + + +@pytest.mark.parametrize( + ("coefficients", "match"), + [ + param((1.0, 2.0), "'coefficients' list must have one", id="length-mismatch"), + param((1.0, 0.0, 1.0), "'coefficients' must be non-zero", id="zero-coeff"), + ], +) +def test_invalid_coefficients(coefficients, match): + """Invalid coefficients raise a ValueError.""" + with pytest.raises(ValueError, match=match): + DiscreteSumConstraint( + parameters=["A", "B", "C"], + condition=ThresholdCondition(threshold=1.0, operator="<="), + coefficients=coefficients, + ) + with pytest.raises(ValueError, match=match): + ContinuousLinearConstraint( + parameters=["A", "B", "C"], + operator="<=", + coefficients=coefficients, + )