Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]
### Added
- `coefficients` attribute for `DiscreteSumConstraint`, enabling weighted sums. Follows
Comment thread
Scienfitz marked this conversation as resolved.
the same pattern as `ContinuousLinearConstraint.coefficients`
- `simplex_coefficients` keyword argument to `SubspaceDiscrete.from_simplex` for
weighted simplex sum constraints
- Support for Python 3.14
- Support for pandas 3
- `Settings` class for unified and streamlined settings management
Expand Down Expand Up @@ -40,6 +44,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Breaking Changes
- `parameter_cartesian_prod_pandas` and `parameter_cartesian_prod_polars` moved
from `baybe.searchspace.discrete` to `baybe.searchspace.utils`
- All optional arguments of `SubspaceDiscrete.from_simplex` after `simplex_parameters`
Comment thread
AVHopp marked this conversation as resolved.
are now keyword-only
- `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples
instead of a single tuple (needed for interpoint constraints)
- `Kernel.to_gpytorch` now takes a `SearchSpace` instead of explicit `ard_num_dims`,
Expand All @@ -61,6 +67,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
equality check

### Changed
- `DiscreteSumConstraint`, `ContinuousLinearConstraint`, and
`SubspaceDiscrete.from_simplex` now forbid 0 as coefficients
- "User Guide" section has been split into "Components" and "Concepts"
- Default transfer learning kernel changed from `IndexKernel` to `PositiveIndexKernel`,
enforcing positive task correlations
Expand Down
2 changes: 2 additions & 0 deletions baybe/constraints/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def _validate_coefficients( # noqa: DOC101, DOC103
"The given 'coefficients' list must have one floating point entry for "
"each entry in 'parameters'."
)
if any(c == 0.0 for c in coefficients):
raise ValueError("All entries in 'coefficients' must be non-zero.")

@coefficients.default
def _default_coefficients(self) -> tuple[float, ...]:
Expand Down
53 changes: 48 additions & 5 deletions baybe/constraints/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
from __future__ import annotations

import gc
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import reduce
from typing import TYPE_CHECKING, Any, ClassVar, cast

import cattrs
import numpy as np
import numpy.typing as npt
import pandas as pd
from attrs import define, field
from attrs.validators import in_, min_len
from attrs.validators import deep_iterable, in_, min_len
from typing_extensions import override

from baybe.constraints.base import CardinalityConstraint, DiscreteConstraint
Expand All @@ -26,6 +27,7 @@
block_serialization_hook,
converter,
)
from baybe.utils.validation import finite_float

if TYPE_CHECKING:
import polars as pl
Expand Down Expand Up @@ -77,7 +79,11 @@ def get_invalid_polars(self) -> pl.Expr:

@define
class DiscreteSumConstraint(DiscreteConstraint):
"""Class for modelling sum constraints."""
"""Class for modelling sum constraints.

The constraint evaluates whether the (optionally weighted) sum of the specified
parameters satisfies the given threshold condition.
"""

# IMPROVE: refactor `SumConstraint` and `ProdConstraint` to avoid code copying

Expand All @@ -94,9 +100,45 @@ class DiscreteSumConstraint(DiscreteConstraint):
condition: ThresholdCondition = field()
"""The condition modeled by this constraint."""

coefficients: tuple[float, ...] = field(
converter=lambda x: cattrs.structure(x, tuple[float, ...]),
validator=deep_iterable(member_validator=finite_float),
)
"""The coefficients for the weighted sum, one per entry in ``parameters``.

Defaults to all-ones, i.e. an unweighted sum."""

@coefficients.default
def _default_coefficients(self) -> tuple[float, ...]:
"""Return equal weight coefficients as default."""
return (1.0,) * len(self.parameters)

@coefficients.validator
def _validate_coefficients( # noqa: DOC101, DOC103
self, _: Any, coefficients: Sequence[float]
) -> None:
"""Validate the coefficients.

Raises:
ValueError: If the number of coefficients does not match the number of
parameters.
"""
if len(self.parameters) != len(coefficients):
raise ValueError(
"The given 'coefficients' list must have one floating point entry for "
"each entry in 'parameters'."
)
if any(c == 0.0 for c in coefficients):
raise ValueError("All entries in 'coefficients' must be non-zero.")

@override
def _get_invalid(self, df: pd.DataFrame, /) -> pd.Index:
evaluate_df = df[self.parameters].sum(axis=1)
evaluate_df = pd.Series(
sum(
df[p].to_numpy() * c for p, c in zip(self.parameters, self.coefficients)
),
index=df.index,
)
Comment on lines +136 to +141

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
evaluate_df = pd.Series(
sum(
df[p].to_numpy() * c for p, c in zip(self.parameters, self.coefficients)
),
index=df.index,
)
evaluate_df = df[self.parameters] @ self.coefficients

@Scienfitz Scienfitz Jun 10, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you see this comment in the PR description
image

i am prioritizing not doing copy operations here at the cost of having to do several computations instead of one big vectorized one. in the limit of few parameters (generally the case for us) this should be the better choice

mask_bad = ~self.condition.evaluate(evaluate_df)

return df.index[mask_bad]
Expand All @@ -105,7 +147,8 @@ def _get_invalid(self, df: pd.DataFrame, /) -> pd.Index:
def get_invalid_polars(self) -> pl.Expr:
from baybe._optional.polars import polars as pl

return self.condition.to_polars(pl.sum_horizontal(self.parameters)).not_()
weighted = [pl.col(p) * c for p, c in zip(self.parameters, self.coefficients)]
return self.condition.to_polars(pl.sum_horizontal(weighted)).not_()


@define
Expand Down
Loading
Loading