-
Notifications
You must be signed in to change notification settings - Fork 79
Add outcome constraints #792
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0019a8f
0d7714c
8a12ff8
cc05d25
d7f29cb
9550bb0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| _ExpectedHypervolumeImprovement, | ||
| qExpectedHypervolumeImprovement, | ||
| qLogExpectedHypervolumeImprovement, | ||
| qLogNoisyExpectedImprovement, | ||
| qNegIntegratedPosteriorVariance, | ||
| qThompsonSampling, | ||
| ) | ||
|
|
@@ -75,6 +76,7 @@ class BotorchAcquisitionArgs: | |
| # Optional, depending on the specific acquisition function being used | ||
| best_f: float | None = _OPT_FIELD | ||
| beta: float | None = _OPT_FIELD | ||
| constraints: list | None = _OPT_FIELD | ||
| maximize: bool | None = _OPT_FIELD | ||
| mc_points: Tensor | None = _OPT_FIELD | ||
| num_fantasies: int | None = _OPT_FIELD | ||
|
|
@@ -197,6 +199,7 @@ def build(self) -> BoAcquisitionFunction: | |
| # Set context-specific parameters | ||
| self._set_best_f() | ||
| self._set_target_transformation() | ||
| self._set_constraints() | ||
| self._set_X_baseline() | ||
| self._set_X_pending() | ||
| self._set_mc_points() | ||
|
|
@@ -222,6 +225,18 @@ def _set_target_transformation(self) -> None: | |
| return | ||
|
|
||
| if self.acqf.is_analytic: | ||
| # TODO: Certain analytic acquisition functions (e.g. analytic EI with | ||
| # constraints) do support outcome constraints and will be added to BayBE | ||
| # in the future. Once available, this guard should be scoped to only | ||
| # those analytic acqfs that do NOT support constraints, and | ||
| # `to_botorch_posterior_transform()` must be fixed to pad the weight | ||
| # vector to length `n_models`. | ||
| if self.objective.outcome_constraints: | ||
| raise IncompatibilityError( | ||
| f"Analytical acquisition function '{type(self.acqf).__name__}' " | ||
| f"does not support outcome constraints. Use an MC-based " | ||
| f"acquisition function instead." | ||
| ) | ||
| try: | ||
| transform = self.objective.to_botorch_posterior_transform() | ||
| except NonGaussianityError as ex: | ||
|
|
@@ -253,17 +268,111 @@ def _set_target_transformation(self) -> None: | |
|
|
||
| self._args.objective = self.objective.to_botorch() | ||
|
|
||
| def _set_constraints(self) -> None: | ||
| """Set BoTorch's ``constraints`` argument from outcome constraints. | ||
|
|
||
| Outcome constraint compatibility check — Layer 2 (acquisition function level). | ||
| Raises IncompatibilityError if the acqf's BoTorch __init__ signature does not | ||
| include a ``constraints`` parameter. | ||
| """ | ||
| if not self.objective.outcome_constraints: | ||
| return | ||
|
|
||
| if flds.constraints.name not in self._signature: | ||
| raise IncompatibilityError( | ||
| f"The selected acquisition function " | ||
| f"'{type(self.acqf).__name__}' does not support outcome " | ||
| f"constraints. Use a compatible acquisition function such as " | ||
| f"'{qLogNoisyExpectedImprovement.__name__}' instead." | ||
| ) | ||
| constraints = self.objective.to_botorch_constraints() | ||
| if constraints: | ||
| self._args.constraints = constraints | ||
|
|
||
| def _set_best_f(self) -> None: | ||
| """Set BoTorch's ``best_f`` argument.""" | ||
| """Set BoTorch's ``best_f`` argument. | ||
|
|
||
| best_f is a constant reference value (not differentiable). When outcome | ||
| constraints are present, only feasible training points are considered. | ||
| """ | ||
| if flds.best_f.name not in self._signature: | ||
| return | ||
|
|
||
| match self.objective: | ||
| case SingleTargetObjective() | DesirabilityObjective(): | ||
| self._args.best_f = self._posterior_mean_comp.max().item() | ||
| if not (constraints := self.objective.to_botorch_constraints()): | ||
| self._args.best_f = self._posterior_mean_comp.max().item() | ||
| else: | ||
| self._args.best_f = self._compute_best_f_with_constraints( | ||
| constraints | ||
| ) | ||
| case _: | ||
| raise NotImplementedError("This line should be impossible to reach.") | ||
|
|
||
| def _compute_best_f_with_constraints( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| self, constraints: list[Callable[[Tensor], Tensor]] | ||
| ) -> float: | ||
| """Compute the best objective value considering outcome constraints. | ||
|
|
||
| Falls back to the global maximum if no feasible training point exists. | ||
|
|
||
| Args: | ||
| constraints: Constraint functions from | ||
| :meth:`~baybe.objectives.base.Objective.to_botorch_constraints`. | ||
|
|
||
| Returns: | ||
| The best feasible objective value, or the global maximum as fallback. | ||
| """ | ||
| # Get objective values for all training points | ||
| objective_values = self._posterior_mean_comp | ||
|
|
||
| # Get raw model predictions for constraint evaluation | ||
| batched = to_tensor(self._train_x).unsqueeze(-2) | ||
| posterior = self._botorch_surrogate.posterior(batched) | ||
| model_predictions = posterior.mean.squeeze(-2) | ||
|
|
||
| # Apply constraint functions to filter feasible points | ||
| feasible_mask = self._compute_feasible_mask(model_predictions, constraints) | ||
|
|
||
| if not feasible_mask.any(): | ||
| # TODO: other mechanisms, e.g. steer towards feasible region? | ||
| # No feasible training points - fall back to global maximum | ||
| return objective_values.max().item() | ||
|
|
||
| # Return maximum among feasible points | ||
| feasible_objectives = objective_values[feasible_mask] | ||
| return feasible_objectives.max().item() | ||
|
|
||
| def _compute_feasible_mask( | ||
| self, | ||
| model_predictions: Tensor, | ||
| constraints: list[Callable[[Tensor], Tensor]], | ||
| ) -> Tensor: | ||
| """Compute boolean mask indicating which points satisfy all constraints. | ||
|
|
||
| Uses hard thresholding (feasible when constraint value <= 0) combined | ||
| via boolean AND across all constraints. | ||
|
|
||
| Args: | ||
| model_predictions: Raw model predictions [n_points, n_outputs] | ||
| constraints: Constraint functions from to_botorch_constraints() | ||
|
|
||
| Returns: | ||
| Boolean mask [n_points] where True = feasible, False = infeasible | ||
| """ | ||
| n_points = model_predictions.shape[0] | ||
| feasible_mask = torch.ones(n_points, dtype=torch.bool) | ||
|
|
||
| for constraint_func in constraints: | ||
| # Constraint func: [batch, q, m] -> [batch, q]; we insert q=1 | ||
| # via unsqueeze(-2), so output is [n_points, 1]; squeeze q dim. | ||
| constraint_violations = constraint_func( | ||
| model_predictions.unsqueeze(-2) | ||
| ).squeeze(-1) | ||
| feasible_mask &= constraint_violations <= 0 | ||
|
|
||
| return feasible_mask | ||
|
Comment on lines
+363
to
+374
|
||
|
|
||
| def set_default_sample_shape(self, acqf: BoAcquisitionFunction, /): | ||
| """Apply temporary workaround for Thompson sampling.""" | ||
| # TODO: Needs redesign once bandits are supported more generally | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| """Functionality for outcome constraints.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Callable | ||
| from typing import Literal | ||
|
|
||
| import pandas as pd | ||
| import torch | ||
| from attrs import define, field | ||
| from attrs.validators import in_, instance_of | ||
|
|
||
| from baybe.serialization.mixin import SerialMixin | ||
| from baybe.targets.base import Target | ||
|
|
||
|
|
||
| @define(frozen=True, slots=False) | ||
| class OutcomeConstraint(SerialMixin): | ||
| """A constraint applied to target outcomes in the output space. | ||
|
|
||
| Outcome constraints restrict the feasible region based on target predictions, | ||
| different from parameter constraints which restrict the input space. | ||
| """ | ||
|
|
||
| target: Target = field(validator=instance_of(Target)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm wondering if we should really store a |
||
| """The target to be constrained.""" | ||
|
|
||
| operator: Literal["<=", ">=", "=="] = field(validator=in_(["<=", ">=", "=="])) | ||
| """The constraint operator.""" | ||
|
|
||
| threshold: float = field(validator=instance_of((int, float)), converter=float) | ||
| """The constraint threshold value in experimental units.""" | ||
|
|
||
|
Comment on lines
+25
to
+33
|
||
| def __str__(self) -> str: | ||
| """Return string representation.""" | ||
| return f"{self.target.name} {self.operator} {self.threshold}" | ||
|
|
||
| def get_computational_threshold(self) -> float: | ||
| """Convert experimental threshold to computational units. | ||
|
|
||
| Returns: | ||
| The threshold value in computational units. | ||
| """ | ||
| # Create dummy series with threshold value in experimental units | ||
| experimental_series = pd.Series([self.threshold], name=self.target.name) | ||
|
|
||
| # Apply the same transformations as the target | ||
| computational_series = self.target.transform(experimental_series) | ||
|
|
||
| return computational_series.iloc[0] | ||
|
|
||
| def to_botorch_constraint_func( | ||
| self, target_idx: int | ||
| ) -> Callable[[torch.Tensor], torch.Tensor]: | ||
| """Create a botorch-compatible constraint function. | ||
|
|
||
| Args: | ||
| target_idx: Index of the target in model output. | ||
|
|
||
| Returns: | ||
| A constraint function that returns <= 0 for feasible region. | ||
| """ | ||
| computational_threshold = self.get_computational_threshold() | ||
|
|
||
| def constraint_func(samples: torch.Tensor) -> torch.Tensor: | ||
| """Constraint function operating on computational level. | ||
|
|
||
| Args: | ||
| samples: Model output samples in computational units. | ||
|
|
||
| Returns: | ||
| Constraint values where <= 0 indicates feasible region. | ||
|
|
||
| Raises: | ||
| ValueError: If the constraint operator is not supported. | ||
| """ | ||
| if self.operator == "<=": | ||
| return samples[..., target_idx] - computational_threshold | ||
| elif self.operator == ">=": | ||
| return computational_threshold - samples[..., target_idx] | ||
| elif self.operator == "==": | ||
| # Equality constraint with small tolerance | ||
| return ( | ||
| torch.abs(samples[..., target_idx] - computational_threshold) - 1e-6 | ||
| ) | ||
| else: | ||
| raise ValueError(f"Unsupported constraint operator: {self.operator}") | ||
|
|
||
| return constraint_func | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest to do the following:
I don't know when you branched this branch off of main, but we only recently added the
AGENTS.mdfiles (#769) which auto-inject instructions to achieve consistent code using agentic development. So if you haven't done that yet, please rebase this PR on main immediately and ask the agent to "replay" the commits with the new rules fromAGENTS.mdfiles in mind, the force pushThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's fine/better if you rebase and add an additional commit that applies the changes – then we can also easily see if the AGENTS.md does its job 👍🏼