From e65ca98f0cd7d8043901fefb9f3097ffe11011ec Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Mon, 8 Jun 2026 17:25:45 +0200 Subject: [PATCH 1/9] Add GaussianProcessSurrogate.posterior_mean property - posterior_mean returns a mean factory that can be passed directly to a new GaussianProcessSurrogate via mean_or_factory - the new GP normalizes inputs before passing them to the mean module, so the factory undoes that normalization before querying the pretrained GP, which then applies its own normalization internally - Raises ModelNotTrainedError if the surrogate has not been fitted yet --- .../gaussian_process/components/mean.py | 77 ++++++++++++++++++- baybe/surrogates/gaussian_process/core.py | 24 ++++++ tests/test_gp.py | 71 +++++++++++++++++ 3 files changed, 171 insertions(+), 1 deletion(-) diff --git a/baybe/surrogates/gaussian_process/components/mean.py b/baybe/surrogates/gaussian_process/components/mean.py index e8c3408f91..e09fb70d20 100644 --- a/baybe/surrogates/gaussian_process/components/mean.py +++ b/baybe/surrogates/gaussian_process/components/mean.py @@ -2,21 +2,30 @@ from __future__ import annotations +import gc +from copy import deepcopy from typing import TYPE_CHECKING, Any import pandas as pd -from attrs import define +from attrs import define, field from typing_extensions import override from baybe.objectives.base import Objective from baybe.searchspace.core import SearchSpace +from baybe.serialization.core import ( + block_deserialization_hook, + block_serialization_hook, + converter, +) from baybe.surrogates.gaussian_process.components.generic import ( GPComponentFactoryProtocol, PlainGPComponentFactory, ) if TYPE_CHECKING: + from botorch.models import SingleTaskGP from gpytorch.means import Mean as GPyTorchMean + from torch import Tensor MeanFactoryProtocol = GPComponentFactoryProtocol[GPyTorchMean] PlainMeanFactory = PlainGPComponentFactory[GPyTorchMean] @@ -40,3 +49,69 @@ def __call__( from gpytorch.means import ConstantMean return 
ConstantMean() + + +@define +class _PosteriorMeanFactory(MeanFactoryProtocol): + """A mean factory producing a posterior mean from a trained BoTorch GP. + + The mean function uses the trained GP's posterior mean predictions. + The provided model is deep-copied and its parameters are frozen. + + Surrogates using this factory are not serializable because the underlying + BoTorch model is not covered by BayBE's serialization system. + """ + + _pretrained_gp = field(alias="pretrained_gp") + """The pretrained BoTorch GP whose posterior mean is used as the mean function.""" + + @override + def __call__( + self, + searchspace: SearchSpace, + objective: Objective, + measurements: pd.DataFrame, + ) -> GPyTorchMean: + import gpytorch + from botorch.models.transforms.input import Normalize + + from baybe.surrogates.gaussian_process.core import _ModelContext + + context = _ModelContext(searchspace, objective, measurements) + + # The new GP applies its input normalization before calling this mean module, + # so x arrives in the new GP's scaled coordinate system. Undo that scaling + # before calling the pretrained GP — it will apply its own normalization. 
+ input_transform = Normalize( + len(searchspace.comp_rep_columns), + bounds=context.parameter_bounds, + indices=context.numerical_indices, + ) + input_transform.eval() + + class _PosteriorMean(gpytorch.means.Mean): + """GPyTorch mean using a trained GP's posterior as the mean function.""" + + def __init__(self, gp: SingleTaskGP, input_transform: Normalize) -> None: + super().__init__() + self.gp: SingleTaskGP = deepcopy(gp) + for param in self.gp.parameters(): + param.requires_grad = False + self.gp.eval() + self.gp.likelihood.eval() + self.input_transform = input_transform + + def forward(self, x: Tensor) -> Tensor: + """Compute the mean using the wrapped GP's posterior.""" + with gpytorch.settings.fast_pred_var(): + x_raw = self.input_transform.untransform(x) + return self.gp.posterior(x_raw).mean.squeeze(-1) + + return _PosteriorMean(self._pretrained_gp, input_transform) + + +# Prevent (de-)serialization since it wraps a raw BoTorch model +converter.register_unstructure_hook(_PosteriorMeanFactory, block_serialization_hook) +converter.register_structure_hook(_PosteriorMeanFactory, block_deserialization_hook) + +gc.collect() # Collect leftover original slotted classes created by attrs diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index cec34a32a5..6f051823b3 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -211,6 +211,30 @@ class GaussianProcessSurrogate(Surrogate): _model = field(init=False, default=None, eq=False) """The actual model.""" + @property + def posterior_mean(self) -> MeanFactoryProtocol: + """A mean factory representing this surrogate's posterior mean. + + Examples: + >>> new_gp = GaussianProcessSurrogate( + ... mean_or_factory=prior_gp.posterior_mean + ... ) + + Raises: + ModelNotTrainedError: If this surrogate has not been fitted yet. 
+ """ + from baybe.exceptions import ModelNotTrainedError + from baybe.surrogates.gaussian_process.components.mean import ( + _PosteriorMeanFactory, + ) + + if self._model is None: + raise ModelNotTrainedError( + f"'{self.__class__.__name__}' must be fitted before accessing " + f"'posterior_mean'." + ) + return _PosteriorMeanFactory(self._model) + @classmethod def from_preset( cls, diff --git a/tests/test_gp.py b/tests/test_gp.py index 0da3f01918..ab5862d258 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -218,3 +218,74 @@ def test_botorch_preset(multitask: bool, preset: str): posterior2 = _posterior_stats_botorch(sp, data) assert_frame_equal(posterior1, posterior2) + + +def test_posterior_mean_correct_under_different_bounds(): + """Posterior mean evaluates at correct physical points when bounds differ.""" + from baybe.parameters.numerical import NumericalDiscreteParameter + + # Train a surrogate on a narrow search space [0, 5] + prior_params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0])] + prior_ss = SearchSpace.from_product(prior_params) + prior_obj = NumericalTarget(name="y").to_objective() + + prior_surrogate = GaussianProcessSurrogate() + prior_meas = pd.DataFrame({"x1": [0.0, 2.5, 5.0], "y": [0.0, 5.0, 10.0]}) + prior_surrogate.fit(prior_ss, prior_obj, prior_meas) + + # Get the surrogate's prediction at x1=2.5 + expected_mean = prior_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item() + + # New GP on a WIDER search space [0, 10], using the posterior_mean property + new_params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0, 7.5, 10.0])] + new_ss = SearchSpace.from_product(new_params) + + new_surrogate = GaussianProcessSurrogate( + mean_or_factory=prior_surrogate.posterior_mean + ) + new_meas = pd.DataFrame({"x1": [0.0, 10.0], "y": [0.0, 20.0]}) + new_surrogate.fit(new_ss, prior_obj, new_meas) + + # In the new space [0, 10], x1=2.5 normalizes to 0.25 + mean_module = new_surrogate._model.mean_module + x_normalized = 
torch.tensor([[0.25]]) + with torch.no_grad(): + actual_mean = mean_module(x_normalized).item() + + assert abs(actual_mean - expected_mean) < 1e-4 + + +def test_posterior_mean_same_bounds(): + """Posterior mean is correct when both search spaces have the same bounds.""" + from baybe.parameters.numerical import NumericalDiscreteParameter + + params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0])] + ss = SearchSpace.from_product(params) + obj = NumericalTarget(name="y").to_objective() + + prior_surrogate = GaussianProcessSurrogate() + meas = pd.DataFrame({"x1": [0.0, 2.5, 5.0], "y": [0.0, 5.0, 10.0]}) + prior_surrogate.fit(ss, obj, meas) + + expected_mean = prior_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item() + + new_surrogate = GaussianProcessSurrogate( + mean_or_factory=prior_surrogate.posterior_mean + ) + new_surrogate.fit(ss, obj, meas) + + # x1=2.5 normalizes to 0.5 in [0, 5] + mean_module = new_surrogate._model.mean_module + x_normalized = torch.tensor([[0.5]]) + with torch.no_grad(): + actual_mean = mean_module(x_normalized).item() + + assert abs(actual_mean - expected_mean) < 1e-4 + + +def test_posterior_mean_raises_if_not_fitted(): + """Accessing posterior_mean raises if the surrogate has not been fitted.""" + from baybe.exceptions import ModelNotTrainedError + + with pytest.raises(ModelNotTrainedError, match="must be fitted"): + GaussianProcessSurrogate().posterior_mean # noqa: B018 From 240d81d5b31dc02a6a6de33615cc2dfeff191113 Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Tue, 9 Jun 2026 12:48:15 +0200 Subject: [PATCH 2/9] Remove transfer learning related example from docstring --- baybe/surrogates/gaussian_process/core.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 6f051823b3..7e184e37bb 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -215,11 +215,6 @@ class 
GaussianProcessSurrogate(Surrogate): def posterior_mean(self) -> MeanFactoryProtocol: """A mean factory representing this surrogate's posterior mean. - Examples: - >>> new_gp = GaussianProcessSurrogate( - ... mean_or_factory=prior_gp.posterior_mean - ... ) - Raises: ModelNotTrainedError: If this surrogate has not been fitted yet. """ From 755a80bc71ba43427c13afb36063658bbf12362e Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Tue, 9 Jun 2026 18:20:59 +0200 Subject: [PATCH 3/9] Replace GaussianProcessSurrogate.posterior_mean by a method MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit   - Replace the posterior_mean property with get_posterior_mean() (with MeanFactoryProtocol signature)   - method can be passed directly as mean_or_factory to a new GP    - Remove _PosteriorMeanFactory from mean.py   - Move _PosteriorMean class and normalization into the method --- .../gaussian_process/components/mean.py | 73 +--------------- baybe/surrogates/gaussian_process/core.py | 86 ++++++++++++++----- tests/test_gp.py | 18 ++-- 3 files changed, 77 insertions(+), 100 deletions(-) diff --git a/baybe/surrogates/gaussian_process/components/mean.py b/baybe/surrogates/gaussian_process/components/mean.py index e09fb70d20..759f862a89 100644 --- a/baybe/surrogates/gaussian_process/components/mean.py +++ b/baybe/surrogates/gaussian_process/components/mean.py @@ -3,29 +3,21 @@ from __future__ import annotations import gc -from copy import deepcopy from typing import TYPE_CHECKING, Any import pandas as pd -from attrs import define, field +from attrs import define from typing_extensions import override from baybe.objectives.base import Objective from baybe.searchspace.core import SearchSpace -from baybe.serialization.core import ( - block_deserialization_hook, - block_serialization_hook, - converter, -) from baybe.surrogates.gaussian_process.components.generic import ( GPComponentFactoryProtocol, PlainGPComponentFactory, ) if TYPE_CHECKING: - from 
botorch.models import SingleTaskGP from gpytorch.means import Mean as GPyTorchMean - from torch import Tensor MeanFactoryProtocol = GPComponentFactoryProtocol[GPyTorchMean] PlainMeanFactory = PlainGPComponentFactory[GPyTorchMean] @@ -51,67 +43,4 @@ def __call__( return ConstantMean() -@define -class _PosteriorMeanFactory(MeanFactoryProtocol): - """A mean factory producing a posterior mean from a trained BoTorch GP. - - The mean function uses the trained GP's posterior mean predictions. - The provided model is deep-copied and its parameters are frozen. - - Surrogates using this factory are not serializable because the underlying - BoTorch model is not covered by BayBE's serialization system. - """ - - _pretrained_gp = field(alias="pretrained_gp") - """The pretrained BoTorch GP whose posterior mean is used as the mean function.""" - - @override - def __call__( - self, - searchspace: SearchSpace, - objective: Objective, - measurements: pd.DataFrame, - ) -> GPyTorchMean: - import gpytorch - from botorch.models.transforms.input import Normalize - - from baybe.surrogates.gaussian_process.core import _ModelContext - - context = _ModelContext(searchspace, objective, measurements) - - # The new GP applies its input normalization before calling this mean module, - # so x arrives in the new GP's scaled coordinate system. Undo that scaling - # before calling the pretrained GP — it will apply its own normalization. 
- input_transform = Normalize( - len(searchspace.comp_rep_columns), - bounds=context.parameter_bounds, - indices=context.numerical_indices, - ) - input_transform.eval() - - class _PosteriorMean(gpytorch.means.Mean): - """GPyTorch mean using a trained GP's posterior as the mean function.""" - - def __init__(self, gp: SingleTaskGP, input_transform: Normalize) -> None: - super().__init__() - self.gp: SingleTaskGP = deepcopy(gp) - for param in self.gp.parameters(): - param.requires_grad = False - self.gp.eval() - self.gp.likelihood.eval() - self.input_transform = input_transform - - def forward(self, x: Tensor) -> Tensor: - """Compute the mean using the wrapped GP's posterior.""" - with gpytorch.settings.fast_pred_var(): - x_raw = self.input_transform.untransform(x) - return self.gp.posterior(x_raw).mean.squeeze(-1) - - return _PosteriorMean(self._pretrained_gp, input_transform) - - -# Prevent (de-)serialization since it wraps a raw BoTorch model -converter.register_unstructure_hook(_PosteriorMeanFactory, block_serialization_hook) -converter.register_structure_hook(_PosteriorMeanFactory, block_deserialization_hook) - gc.collect() # Collect leftover original slotted classes created by attrs diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 7e184e37bb..05733f30a7 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -14,7 +14,7 @@ from attrs.validators import instance_of, is_callable from typing_extensions import Self, override -from baybe.exceptions import DeprecationError +from baybe.exceptions import DeprecationError, ModelNotTrainedError from baybe.kernels.base import Kernel from baybe.objectives.base import Objective from baybe.parameters.base import Parameter @@ -211,25 +211,6 @@ class GaussianProcessSurrogate(Surrogate): _model = field(init=False, default=None, eq=False) """The actual model.""" - @property - def posterior_mean(self) -> MeanFactoryProtocol: - 
"""A mean factory representing this surrogate's posterior mean. - - Raises: - ModelNotTrainedError: If this surrogate has not been fitted yet. - """ - from baybe.exceptions import ModelNotTrainedError - from baybe.surrogates.gaussian_process.components.mean import ( - _PosteriorMeanFactory, - ) - - if self._model is None: - raise ModelNotTrainedError( - f"'{self.__class__.__name__}' must be fitted before accessing " - f"'posterior_mean'." - ) - return _PosteriorMeanFactory(self._model) - @classmethod def from_preset( cls, @@ -265,6 +246,71 @@ def from_preset( gp._custom_kernel = False # preset are first-party features return gp + def get_posterior_mean( + self, + searchspace: SearchSpace, + objective: Objective, + measurements: pd.DataFrame, + ) -> GPyTorchMean: + """Return a GPyTorch mean module representing this surrogate's posterior mean. + + The bound method satisfies :class:`.MeanFactoryProtocol` and can be passed + directly as ``mean_or_factory`` to a new :class:`GaussianProcessSurrogate`. + + Args: + searchspace: The search space of the new GP being fitted. + objective: The objective of the new GP being fitted. + measurements: The training data of the new GP being fitted. + + Returns: + A GPyTorch mean module that evaluates this surrogate's posterior mean. + + Raises: + ModelNotTrainedError: If this surrogate has not been fitted yet. + """ + from copy import deepcopy + + import gpytorch + from botorch.models.transforms.input import Normalize + + if self._model is None: + raise ModelNotTrainedError( + f"'{self.__class__.__name__}' must be fitted before its " + f"'get_posterior_mean' can be used as a mean function." + ) + + context = _ModelContext(searchspace, objective, measurements) + + # The new GP applies its input normalization before calling this mean module, + # so x arrives in the new GP's scaled coordinate system. Undo that scaling + # before calling the pretrained GP — it will apply its own normalization. 
+ input_transform = Normalize( + len(searchspace.comp_rep_columns), + bounds=context.parameter_bounds, + indices=context.numerical_indices, + ) + input_transform.eval() + + class _PosteriorMean(gpytorch.means.Mean): + """GPyTorch mean using a trained GP's posterior as the mean function.""" + + def __init__(self, gp: GPyTorchModel, input_transform: Normalize) -> None: + super().__init__() + self.gp = deepcopy(gp) + for param in self.gp.parameters(): + param.requires_grad = False + self.gp.eval() + self.gp.likelihood.eval() + self.input_transform = input_transform + + def forward(self, x: Tensor) -> Tensor: + """Compute the mean using the wrapped GP's posterior.""" + with gpytorch.settings.fast_pred_var(): + x_raw = self.input_transform.untransform(x) + return self.gp.posterior(x_raw).mean.squeeze(-1) + + return _PosteriorMean(self._model, input_transform) + @override def to_botorch(self) -> GPyTorchModel: return self._model diff --git a/tests/test_gp.py b/tests/test_gp.py index ab5862d258..0ab219bca6 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -220,7 +220,7 @@ def test_botorch_preset(multitask: bool, preset: str): assert_frame_equal(posterior1, posterior2) -def test_posterior_mean_correct_under_different_bounds(): +def test_get_posterior_mean_correct_under_different_bounds(): """Posterior mean evaluates at correct physical points when bounds differ.""" from baybe.parameters.numerical import NumericalDiscreteParameter @@ -236,12 +236,12 @@ def test_posterior_mean_correct_under_different_bounds(): # Get the surrogate's prediction at x1=2.5 expected_mean = prior_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item() - # New GP on a WIDER search space [0, 10], using the posterior_mean property + # New GP on a WIDER search space [0, 10], using the get_posterior_mean method new_params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0, 7.5, 10.0])] new_ss = SearchSpace.from_product(new_params) new_surrogate = GaussianProcessSurrogate( - 
mean_or_factory=prior_surrogate.posterior_mean + mean_or_factory=prior_surrogate.get_posterior_mean ) new_meas = pd.DataFrame({"x1": [0.0, 10.0], "y": [0.0, 20.0]}) new_surrogate.fit(new_ss, prior_obj, new_meas) @@ -255,7 +255,7 @@ def test_posterior_mean_correct_under_different_bounds(): assert abs(actual_mean - expected_mean) < 1e-4 -def test_posterior_mean_same_bounds(): +def test_get_posterior_mean_same_bounds(): """Posterior mean is correct when both search spaces have the same bounds.""" from baybe.parameters.numerical import NumericalDiscreteParameter @@ -270,7 +270,7 @@ def test_posterior_mean_same_bounds(): expected_mean = prior_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item() new_surrogate = GaussianProcessSurrogate( - mean_or_factory=prior_surrogate.posterior_mean + mean_or_factory=prior_surrogate.get_posterior_mean ) new_surrogate.fit(ss, obj, meas) @@ -283,9 +283,11 @@ def test_posterior_mean_same_bounds(): assert abs(actual_mean - expected_mean) < 1e-4 -def test_posterior_mean_raises_if_not_fitted(): - """Accessing posterior_mean raises if the surrogate has not been fitted.""" +def test_get_posterior_mean_raises_if_not_fitted(): + """Calling get_posterior_mean raises if the surrogate has not been fitted.""" from baybe.exceptions import ModelNotTrainedError with pytest.raises(ModelNotTrainedError, match="must be fitted"): - GaussianProcessSurrogate().posterior_mean # noqa: B018 + GaussianProcessSurrogate().get_posterior_mean( + searchspace, objective, measurements + ) From cff91db4c316b7cc726521db601c60d7d08a1809 Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Wed, 10 Jun 2026 16:12:57 +0200 Subject: [PATCH 4/9] Handle outcome standardization in get_posterior_mean - Add output normalization to posterior mean - override train() on _PosteriorMean to prevent fit_gpytorch_mll from recursively switching nested submodules to training mode, which would change the learned Standardize parameters - improve tests to use points from posterior mean --- 
baybe/surrogates/gaussian_process/core.py | 46 +++++++++++++++++++---- tests/test_gp.py | 37 ++++++++++++------ 2 files changed, 63 insertions(+), 20 deletions(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 05733f30a7..b5af46678b 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -273,6 +273,8 @@ def get_posterior_mean( import gpytorch from botorch.models.transforms.input import Normalize + from baybe.utils.dataframe import to_tensor + if self._model is None: raise ModelNotTrainedError( f"'{self.__class__.__name__}' must be fitted before its " @@ -281,9 +283,7 @@ def get_posterior_mean( context = _ModelContext(searchspace, objective, measurements) - # The new GP applies its input normalization before calling this mean module, - # so x arrives in the new GP's scaled coordinate system. Undo that scaling - # before calling the pretrained GP — it will apply its own normalization. + # Undo the new GP's input normalization before querying the prior GP input_transform = Normalize( len(searchspace.comp_rep_columns), bounds=context.parameter_bounds, @@ -291,10 +291,30 @@ def get_posterior_mean( ) input_transform.eval() - class _PosteriorMean(gpytorch.means.Mean): - """GPyTorch mean using a trained GP's posterior as the mean function.""" + # Match the new GP's outcome standardization + from botorch.models.transforms.outcome import Standardize + + pre_transformed = objective._pre_transform(measurements, allow_extra=True) + train_y_tensor = to_tensor(pre_transformed) + if train_y_tensor.ndim == 1: + train_y_tensor = train_y_tensor.unsqueeze(-1) + outcome_transform = Standardize(m=train_y_tensor.shape[-1]) + outcome_transform(train_y_tensor) + outcome_transform.eval() - def __init__(self, gp: GPyTorchModel, input_transform: Normalize) -> None: + class _PosteriorMean(gpytorch.means.Mean): + """GPyTorch mean wrapping a trained GP's posterior. 
+ + Overrides ``train`` to keep all children in eval mode, preventing + ``fit_gpytorch_mll`` from corrupting learned transform parameters. + """ + + def __init__( + self, + gp: GPyTorchModel, + input_transform: Normalize, + outcome_transform: Standardize, + ) -> None: super().__init__() self.gp = deepcopy(gp) for param in self.gp.parameters(): @@ -302,14 +322,24 @@ def __init__(self, gp: GPyTorchModel, input_transform: Normalize) -> None: self.gp.eval() self.gp.likelihood.eval() self.input_transform = input_transform + self.outcome_transform = outcome_transform + + @override + def train(self, mode: bool = True) -> _PosteriorMean: + """Set training mode without propagating to children.""" + self.training = mode + return self + @override def forward(self, x: Tensor) -> Tensor: """Compute the mean using the wrapped GP's posterior.""" with gpytorch.settings.fast_pred_var(): x_raw = self.input_transform.untransform(x) - return self.gp.posterior(x_raw).mean.squeeze(-1) + posterior_mean = self.gp.posterior(x_raw).mean + standardized, _ = self.outcome_transform(posterior_mean) + return standardized.squeeze(-1) - return _PosteriorMean(self._model, input_transform) + return _PosteriorMean(self._model, input_transform, outcome_transform) @override def to_botorch(self) -> GPyTorchModel: diff --git a/tests/test_gp.py b/tests/test_gp.py index 0ab219bca6..58bac71d78 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -243,14 +243,20 @@ def test_get_posterior_mean_correct_under_different_bounds(): new_surrogate = GaussianProcessSurrogate( mean_or_factory=prior_surrogate.get_posterior_mean ) - new_meas = pd.DataFrame({"x1": [0.0, 10.0], "y": [0.0, 20.0]}) + # Train on data that lies exactly on the prior mean to avoid kernel effects + training_points = pd.DataFrame({"x1": [0.0, 10.0]}) + with torch.no_grad(): + training_targets = prior_surrogate.posterior(training_points).mean + new_meas = pd.DataFrame( + { + "x1": training_points["x1"], + "y": training_targets.numpy().ravel(), 
+ } + ) new_surrogate.fit(new_ss, prior_obj, new_meas) - # In the new space [0, 10], x1=2.5 normalizes to 0.25 - mean_module = new_surrogate._model.mean_module - x_normalized = torch.tensor([[0.25]]) - with torch.no_grad(): - actual_mean = mean_module(x_normalized).item() + # Test end-to-end: the posterior should match the prior mean + actual_mean = new_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item() assert abs(actual_mean - expected_mean) < 1e-4 @@ -272,13 +278,20 @@ def test_get_posterior_mean_same_bounds(): new_surrogate = GaussianProcessSurrogate( mean_or_factory=prior_surrogate.get_posterior_mean ) - new_surrogate.fit(ss, obj, meas) - - # x1=2.5 normalizes to 0.5 in [0, 5] - mean_module = new_surrogate._model.mean_module - x_normalized = torch.tensor([[0.5]]) + # Train on data that lies exactly on the prior mean + training_points = pd.DataFrame({"x1": [0.0, 5.0]}) with torch.no_grad(): - actual_mean = mean_module(x_normalized).item() + training_targets = prior_surrogate.posterior(training_points).mean + new_meas = pd.DataFrame( + { + "x1": training_points["x1"], + "y": training_targets.numpy().ravel(), + } + ) + new_surrogate.fit(ss, obj, new_meas) + + # Test end-to-end: the posterior should match the prior mean + actual_mean = new_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item() assert abs(actual_mean - expected_mean) < 1e-4 From 7098fe3ab2130d9eaa433a2688083524c47daef5 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 12 Jun 2026 10:17:00 +0200 Subject: [PATCH 5/9] Extract transform factories to couple _fit and get_posterior_mean Both methods must use identical transform logic: a change in one (e.g. replacing Normalize with a different input transform) must automatically apply to the other, or get_posterior_mean silently produces mismatched results. 
--- baybe/surrogates/gaussian_process/core.py | 57 +++++++++++++---------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index b5af46678b..a406b342f9 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -48,11 +48,12 @@ ) from baybe.utils.boolean import strtobool from baybe.utils.conversion import to_string +from baybe.utils.dataframe import to_tensor if TYPE_CHECKING: from botorch.models.gpytorch import GPyTorchModel - from botorch.models.transforms.input import InputTransform - from botorch.models.transforms.outcome import OutcomeTransform + from botorch.models.transforms.input import InputTransform, Normalize + from botorch.models.transforms.outcome import OutcomeTransform, Standardize from botorch.posteriors import Posterior from gpytorch.kernels import Kernel as GPyTorchKernel from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood @@ -211,6 +212,31 @@ class GaussianProcessSurrogate(Surrogate): _model = field(init=False, default=None, eq=False) """The actual model.""" + @staticmethod + def _make_input_transform(context: _ModelContext) -> Normalize: + """Create the input transform for the Gaussian process.""" + from botorch.models.transforms.input import Normalize + + return Normalize( + len(context.searchspace.comp_rep_columns), + bounds=context.parameter_bounds, + indices=context.numerical_indices, + ) + + @staticmethod + def _make_outcome_transform(context: _ModelContext) -> Standardize: + """Create the outcome transform for the Gaussian process.""" + from botorch.models.transforms.outcome import Standardize + + train_y = to_tensor( + context.objective._pre_transform(context.measurements, allow_extra=True) + ) + if train_y.ndim == 1: + train_y = train_y.unsqueeze(-1) + transform = Standardize(m=train_y.shape[-1]) + transform(train_y) # fits means/stdvs; GP will re-fit in train mode + return 
transform + @classmethod def from_preset( cls, @@ -271,9 +297,6 @@ def get_posterior_mean( from copy import deepcopy import gpytorch - from botorch.models.transforms.input import Normalize - - from baybe.utils.dataframe import to_tensor if self._model is None: raise ModelNotTrainedError( @@ -284,22 +307,11 @@ def get_posterior_mean( context = _ModelContext(searchspace, objective, measurements) # Undo the new GP's input normalization before querying the prior GP - input_transform = Normalize( - len(searchspace.comp_rep_columns), - bounds=context.parameter_bounds, - indices=context.numerical_indices, - ) + input_transform = self._make_input_transform(context) input_transform.eval() # Match the new GP's outcome standardization - from botorch.models.transforms.outcome import Standardize - - pre_transformed = objective._pre_transform(measurements, allow_extra=True) - train_y_tensor = to_tensor(pre_transformed) - if train_y_tensor.ndim == 1: - train_y_tensor = train_y_tensor.unsqueeze(-1) - outcome_transform = Standardize(m=train_y_tensor.shape[-1]) - outcome_transform(train_y_tensor) + outcome_transform = self._make_outcome_transform(context) outcome_transform.eval() class _PosteriorMean(gpytorch.means.Mean): @@ -366,7 +378,6 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior: @override def _fit(self, train_x: Tensor, train_y: Tensor) -> None: import botorch - from botorch.models.transforms import Normalize, Standardize assert self._searchspace is not None # provided by base class assert self._objective is not None # provided by base class @@ -393,12 +404,8 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: ### Input/output scaling # NOTE: For GPs, we let BoTorch handle scaling (see [Scaling Workaround] above) - input_transform = Normalize( - train_x.shape[-1], - bounds=context.parameter_bounds, - indices=context.numerical_indices, - ) - outcome_transform = Standardize(train_y.shape[-1]) + input_transform = 
self._make_input_transform(context) + outcome_transform = self._make_outcome_transform(context) ### Mean mean = self.mean_factory( From 3f19c10f0479ba9e61c5c3e29c4d3811d9af388b Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 12 Jun 2026 10:20:46 +0200 Subject: [PATCH 6/9] Simplify _PosteriorMean by closing over transforms --- baybe/surrogates/gaussian_process/core.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index a406b342f9..d311eed5f3 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -321,20 +321,13 @@ class _PosteriorMean(gpytorch.means.Mean): ``fit_gpytorch_mll`` from corrupting learned transform parameters. """ - def __init__( - self, - gp: GPyTorchModel, - input_transform: Normalize, - outcome_transform: Standardize, - ) -> None: + def __init__(self, gp: GPyTorchModel) -> None: super().__init__() self.gp = deepcopy(gp) for param in self.gp.parameters(): param.requires_grad = False self.gp.eval() self.gp.likelihood.eval() - self.input_transform = input_transform - self.outcome_transform = outcome_transform @override def train(self, mode: bool = True) -> _PosteriorMean: @@ -346,12 +339,12 @@ def train(self, mode: bool = True) -> _PosteriorMean: def forward(self, x: Tensor) -> Tensor: """Compute the mean using the wrapped GP's posterior.""" with gpytorch.settings.fast_pred_var(): - x_raw = self.input_transform.untransform(x) + x_raw = input_transform.untransform(x) posterior_mean = self.gp.posterior(x_raw).mean - standardized, _ = self.outcome_transform(posterior_mean) + standardized, _ = outcome_transform(posterior_mean) return standardized.squeeze(-1) - return _PosteriorMean(self._model, input_transform, outcome_transform) + return _PosteriorMean(self._model) @override def to_botorch(self) -> GPyTorchModel: From 688bd1a16262cb70abf713ae0f81f83ca9c5984b Mon Sep 17 00:00:00 
2001 From: AdrianSosic Date: Fri, 12 Jun 2026 10:26:10 +0200 Subject: [PATCH 7/9] Adjust docstrings and error messages --- .../gaussian_process/components/mean.py | 3 ++- baybe/surrogates/gaussian_process/core.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/baybe/surrogates/gaussian_process/components/mean.py b/baybe/surrogates/gaussian_process/components/mean.py index 759f862a89..5adefc4888 100644 --- a/baybe/surrogates/gaussian_process/components/mean.py +++ b/baybe/surrogates/gaussian_process/components/mean.py @@ -43,4 +43,5 @@ def __call__( return ConstantMean() -gc.collect() # Collect leftover original slotted classes created by attrs +# Collect leftover original slotted classes created by attrs +gc.collect() diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index d311eed5f3..56723c15b5 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -278,10 +278,11 @@ def get_posterior_mean( objective: Objective, measurements: pd.DataFrame, ) -> GPyTorchMean: - """Return a GPyTorch mean module representing this surrogate's posterior mean. + """Return a GPyTorch mean module representing the surrogate's posterior mean. - The bound method satisfies :class:`.MeanFactoryProtocol` and can be passed - directly as ``mean_or_factory`` to a new :class:`GaussianProcessSurrogate`. + The bound method satisfies + :class:`~baybe.surrogates.gaussian_process.components.mean.MeanFactoryProtocol` + and can be passed directly to a new :class:`GaussianProcessSurrogate`. Args: searchspace: The search space of the new GP being fitted. @@ -289,10 +290,10 @@ def get_posterior_mean( measurements: The training data of the new GP being fitted. Returns: - A GPyTorch mean module that evaluates this surrogate's posterior mean. + The posterior mean. Raises: - ModelNotTrainedError: If this surrogate has not been fitted yet. 
+ ModelNotTrainedError: If the surrogate has not been fitted yet. """ from copy import deepcopy @@ -301,7 +302,7 @@ def get_posterior_mean( if self._model is None: raise ModelNotTrainedError( f"'{self.__class__.__name__}' must be fitted before its " - f"'get_posterior_mean' can be used as a mean function." + f"'{self.get_posterior_mean.__name__}' can be used as a mean function." ) context = _ModelContext(searchspace, objective, measurements) @@ -317,8 +318,8 @@ def get_posterior_mean( class _PosteriorMean(gpytorch.means.Mean): """GPyTorch mean wrapping a trained GP's posterior. - Overrides ``train`` to keep all children in eval mode, preventing - ``fit_gpytorch_mll`` from corrupting learned transform parameters. + Overrides ``train`` to keep all children in eval mode, preventing optimizers + from corrupting learned transform parameters. """ def __init__(self, gp: GPyTorchModel) -> None: From 45b82b6a4972d3746987675508d76110dac40c36 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 12 Jun 2026 10:35:12 +0200 Subject: [PATCH 8/9] Rename get_posterior_mean to posterior_mean_function --- baybe/surrogates/gaussian_process/core.py | 5 +++-- tests/test_gp.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 56723c15b5..d40482b8ce 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -272,7 +272,7 @@ def from_preset( gp._custom_kernel = False # preset are first-party features return gp - def get_posterior_mean( + def posterior_mean_function( self, searchspace: SearchSpace, objective: Objective, @@ -302,7 +302,8 @@ def get_posterior_mean( if self._model is None: raise ModelNotTrainedError( f"'{self.__class__.__name__}' must be fitted before its " - f"'{self.get_posterior_mean.__name__}' can be used as a mean function." + f"'{self.posterior_mean_function.__name__}' can be used as a " + f"mean function." 
) context = _ModelContext(searchspace, objective, measurements) diff --git a/tests/test_gp.py b/tests/test_gp.py index 58bac71d78..77feb417b8 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -241,7 +241,7 @@ def test_get_posterior_mean_correct_under_different_bounds(): new_ss = SearchSpace.from_product(new_params) new_surrogate = GaussianProcessSurrogate( - mean_or_factory=prior_surrogate.get_posterior_mean + mean_or_factory=prior_surrogate.posterior_mean_function ) # Train on data that lies exactly on the prior mean to avoid kernel effects training_points = pd.DataFrame({"x1": [0.0, 10.0]}) @@ -276,7 +276,7 @@ def test_get_posterior_mean_same_bounds(): expected_mean = prior_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item() new_surrogate = GaussianProcessSurrogate( - mean_or_factory=prior_surrogate.get_posterior_mean + mean_or_factory=prior_surrogate.posterior_mean_function ) # Train on data that lies exactly on the prior mean training_points = pd.DataFrame({"x1": [0.0, 5.0]}) @@ -301,6 +301,6 @@ def test_get_posterior_mean_raises_if_not_fitted(): from baybe.exceptions import ModelNotTrainedError with pytest.raises(ModelNotTrainedError, match="must be fitted"): - GaussianProcessSurrogate().get_posterior_mean( + GaussianProcessSurrogate().posterior_mean_function( searchspace, objective, measurements ) From 37bb6a6829942850612cecd4a655ba384f39fb18 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 12 Jun 2026 10:37:15 +0200 Subject: [PATCH 9/9] Update CHANGELOG.md --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c17d6cf2e7..8dc303e51f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+## [Unreleased] +### Added +- `posterior_mean_function` method to `GaussianProcessSurrogate` for reusing a fitted surrogate's posterior mean as the mean function of a new Gaussian process + ## [0.15.0] - 2026-06-11 ### Breaking Changes - `GaussianProcessSurrogate` no longer automatically adds a task kernel in multi-task