Skip to content
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `posterior_mean_function` method to `GaussianProcessSurrogate`

## [0.15.0] - 2026-06-11
### Breaking Changes
- `GaussianProcessSurrogate` no longer automatically adds a task kernel in multi-task
Expand Down
5 changes: 5 additions & 0 deletions baybe/surrogates/gaussian_process/components/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
from typing import TYPE_CHECKING, Any

import pandas as pd
Expand Down Expand Up @@ -40,3 +41,7 @@ def __call__(
from gpytorch.means import ConstantMean

return ConstantMean()


# Collect leftover original slotted classes created by attrs
gc.collect()
117 changes: 107 additions & 10 deletions baybe/surrogates/gaussian_process/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from attrs.validators import instance_of, is_callable
from typing_extensions import Self, override

from baybe.exceptions import DeprecationError
from baybe.exceptions import DeprecationError, ModelNotTrainedError
from baybe.kernels.base import Kernel
from baybe.objectives.base import Objective
from baybe.parameters.base import Parameter
Expand Down Expand Up @@ -48,11 +48,12 @@
)
from baybe.utils.boolean import strtobool
from baybe.utils.conversion import to_string
from baybe.utils.dataframe import to_tensor

if TYPE_CHECKING:
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.transforms.input import InputTransform, Normalize
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.posteriors import Posterior
from gpytorch.kernels import Kernel as GPyTorchKernel
from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
Expand Down Expand Up @@ -211,6 +212,31 @@ class GaussianProcessSurrogate(Surrogate):
_model = field(init=False, default=None, eq=False)
"""The actual model."""

@staticmethod
def _make_input_transform(context: _ModelContext) -> Normalize:
"""Create the input transform for the Gaussian process."""
from botorch.models.transforms.input import Normalize

return Normalize(
len(context.searchspace.comp_rep_columns),
bounds=context.parameter_bounds,
indices=context.numerical_indices,
)

@staticmethod
def _make_outcome_transform(context: _ModelContext) -> Standardize:
"""Create the outcome transform for the Gaussian process."""
from botorch.models.transforms.outcome import Standardize

train_y = to_tensor(
context.objective._pre_transform(context.measurements, allow_extra=True)
)
if train_y.ndim == 1:
train_y = train_y.unsqueeze(-1)
transform = Standardize(m=train_y.shape[-1])
transform(train_y) # fits means/stdvs; GP will re-fit in train mode
return transform

@classmethod
def from_preset(
cls,
Expand Down Expand Up @@ -246,6 +272,82 @@ def from_preset(
gp._custom_kernel = False # preset are first-party features
return gp

def posterior_mean_function(
self,
searchspace: SearchSpace,
objective: Objective,
measurements: pd.DataFrame,
) -> GPyTorchMean:
"""Return a GPyTorch mean module representing the surrogate's posterior mean.

The bound method satisfies
:class:`~baybe.surrogates.gaussian_process.components.mean.MeanFactoryProtocol`
and can be passed directly to a new :class:`GaussianProcessSurrogate`.

Args:
searchspace: The search space of the new GP being fitted.
objective: The objective of the new GP being fitted.
measurements: The training data of the new GP being fitted.

Returns:
The posterior mean.

Raises:
ModelNotTrainedError: If the surrogate has not been fitted yet.
"""
from copy import deepcopy

import gpytorch

if self._model is None:
raise ModelNotTrainedError(
f"'{self.__class__.__name__}' must be fitted before its "
f"'{self.posterior_mean_function.__name__}' can be used as a "
f"mean function."
)

context = _ModelContext(searchspace, objective, measurements)

# Undo the new GP's input normalization before querying the prior GP
input_transform = self._make_input_transform(context)
input_transform.eval()

# Match the new GP's outcome standardization
outcome_transform = self._make_outcome_transform(context)
outcome_transform.eval()

class _PosteriorMean(gpytorch.means.Mean):
"""GPyTorch mean wrapping a trained GP's posterior.

Overrides ``train`` to keep all children in eval mode, preventing optimizers
from corrupting learned transform parameters.
"""

def __init__(self, gp: GPyTorchModel) -> None:
super().__init__()
self.gp = deepcopy(gp)
for param in self.gp.parameters():
param.requires_grad = False
self.gp.eval()
self.gp.likelihood.eval()

@override
def train(self, mode: bool = True) -> _PosteriorMean:
"""Set training mode without propagating to children."""
self.training = mode
return self

@override
def forward(self, x: Tensor) -> Tensor:
"""Compute the mean using the wrapped GP's posterior."""
with gpytorch.settings.fast_pred_var():
x_raw = input_transform.untransform(x)
posterior_mean = self.gp.posterior(x_raw).mean
standardized, _ = outcome_transform(posterior_mean)
return standardized.squeeze(-1)

return _PosteriorMean(self._model)

@override
def to_botorch(self) -> GPyTorchModel:
return self._model
Expand All @@ -271,7 +373,6 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior:
@override
def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
import botorch
from botorch.models.transforms import Normalize, Standardize

assert self._searchspace is not None # provided by base class
assert self._objective is not None # provided by base class
Expand All @@ -298,12 +399,8 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:

### Input/output scaling
# NOTE: For GPs, we let BoTorch handle scaling (see [Scaling Workaround] above)
input_transform = Normalize(
train_x.shape[-1],
bounds=context.parameter_bounds,
indices=context.numerical_indices,
)
outcome_transform = Standardize(train_y.shape[-1])
input_transform = self._make_input_transform(context)
outcome_transform = self._make_outcome_transform(context)

### Mean
mean = self.mean_factory(
Expand Down
86 changes: 86 additions & 0 deletions tests/test_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,89 @@ def test_botorch_preset(multitask: bool, preset: str):
posterior2 = _posterior_stats_botorch(sp, data)

assert_frame_equal(posterior1, posterior2)


def test_get_posterior_mean_correct_under_different_bounds():
"""Posterior mean evaluates at correct physical points when bounds differ."""
from baybe.parameters.numerical import NumericalDiscreteParameter

# Train a surrogate on a narrow search space [0, 5]
prior_params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0])]
prior_ss = SearchSpace.from_product(prior_params)
prior_obj = NumericalTarget(name="y").to_objective()

prior_surrogate = GaussianProcessSurrogate()
prior_meas = pd.DataFrame({"x1": [0.0, 2.5, 5.0], "y": [0.0, 5.0, 10.0]})
prior_surrogate.fit(prior_ss, prior_obj, prior_meas)

# Get the surrogate's prediction at x1=2.5
expected_mean = prior_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item()

# New GP on a WIDER search space [0, 10], using the get_posterior_mean method
new_params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0, 7.5, 10.0])]
new_ss = SearchSpace.from_product(new_params)

new_surrogate = GaussianProcessSurrogate(
mean_or_factory=prior_surrogate.posterior_mean_function
)
# Train on data that lies exactly on the prior mean to avoid kernel effects
training_points = pd.DataFrame({"x1": [0.0, 10.0]})
with torch.no_grad():
training_targets = prior_surrogate.posterior(training_points).mean
new_meas = pd.DataFrame(
{
"x1": training_points["x1"],
"y": training_targets.numpy().ravel(),
}
)
new_surrogate.fit(new_ss, prior_obj, new_meas)

# Test end-to-end: the posterior should match the prior mean
actual_mean = new_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item()

assert abs(actual_mean - expected_mean) < 1e-4


def test_get_posterior_mean_same_bounds():
"""Posterior mean is correct when both search spaces have the same bounds."""
from baybe.parameters.numerical import NumericalDiscreteParameter

params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0])]
ss = SearchSpace.from_product(params)
obj = NumericalTarget(name="y").to_objective()

prior_surrogate = GaussianProcessSurrogate()
meas = pd.DataFrame({"x1": [0.0, 2.5, 5.0], "y": [0.0, 5.0, 10.0]})
prior_surrogate.fit(ss, obj, meas)

expected_mean = prior_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item()

new_surrogate = GaussianProcessSurrogate(
mean_or_factory=prior_surrogate.posterior_mean_function
)
# Train on data that lies exactly on the prior mean
training_points = pd.DataFrame({"x1": [0.0, 5.0]})
with torch.no_grad():
training_targets = prior_surrogate.posterior(training_points).mean
new_meas = pd.DataFrame(
{
"x1": training_points["x1"],
"y": training_targets.numpy().ravel(),
}
)
new_surrogate.fit(ss, obj, new_meas)

# Test end-to-end: the posterior should match the prior mean
actual_mean = new_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item()

assert abs(actual_mean - expected_mean) < 1e-4


def test_get_posterior_mean_raises_if_not_fitted():
"""Calling get_posterior_mean raises if the surrogate has not been fitted."""
from baybe.exceptions import ModelNotTrainedError

with pytest.raises(ModelNotTrainedError, match="must be fitted"):
GaussianProcessSurrogate().posterior_mean_function(
searchspace, objective, measurements
)
Loading