Add GaussianProcessSurrogate.posterior_mean property#823

Open
kalama-ai wants to merge 9 commits into main from feature/posterior_mean_property

@kalama-ai kalama-ai commented Jun 8, 2026

Collaborator

Adds a posterior_mean property to GaussianProcessSurrogate that can be used as the mean function of a new GP.

pretrained_gp = GaussianProcessSurrogate()
pretrained_gp.fit(searchspace, objective, measurements)

new_gp = GaussianProcessSurrogate(mean_or_factory=pretrained_gp.posterior_mean)
new_gp.fit(new_searchspace, objective, new_measurements)

Normalization

GPs in BayBE normalize their inputs internally based on the bounds of the search space they were trained on. When the pretrained GP's mean module is used in a new GP, the new GP will normalize inputs according to its own bounds before passing them to the mean module. If the two search spaces have different bounds, this would cause the mean to be evaluated at the wrong point.

Example: A pretrained GP was trained on x ∈ [0, 5]. It normalizes x=2.5 to 0.5. A new GP is trained on a wider space x ∈ [0, 10] and normalizes x=2.5 to 0.25. If the mean module received 0.25 and passed it straight to the pretrained GP, the pretrained GP would interpret it within its own range, i.e. as x = 0.25 * 5 = 1.25, which is the wrong physical point.

The fix is to undo the new GP's normalization inside the mean module before querying the pretrained GP, which then applies its own normalization correctly. The factory builds this untransform at fit time, when the new search space bounds are known. Similarly, the factory matches the new GP's outcome standardization, ensuring the mean module returns values in the same standardized y-space that the new GP operates in.
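
The arithmetic in the example can be replayed in a few lines of plain Python. This is only a minimal sketch under stated assumptions: normalize, unnormalize, OLD_BOUNDS and NEW_BOUNDS are hypothetical names introduced for illustration, and the actual code performs these steps through the GP's own input transforms rather than through helpers like these.

```python
def normalize(x, bounds):
    """Map a physical value into [0, 1] for the given (lower, upper) bounds."""
    lo, hi = bounds
    return (x - lo) / (hi - lo)


def unnormalize(u, bounds):
    """Map a value from [0, 1] back to physical units for the given bounds."""
    lo, hi = bounds
    return lo + u * (hi - lo)


OLD_BOUNDS = (0.0, 5.0)   # search space of the pretrained GP (assumed)
NEW_BOUNDS = (0.0, 10.0)  # wider search space of the new GP (assumed)

x = 2.5
u_new = normalize(x, NEW_BOUNDS)            # what the mean module receives

# Without the fix: the value is read in the pretrained GP's range.
x_misread = unnormalize(u_new, OLD_BOUNDS)

# With the fix: undo the new GP's normalization first, then let the
# pretrained GP apply its own normalization.
x_recovered = unnormalize(u_new, NEW_BOUNDS)
u_old = normalize(x_recovered, OLD_BOUNDS)

print(u_new, x_misread, x_recovered, u_old)
```

The printed values correspond to the numbers in the example: 0.25 for the new GP's normalized input, 1.25 for the misread point, and 2.5 and 0.5 once the untransform is applied.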

Implementation

  • posterior_mean is backed by a private _PosteriorMeanFactory in mean.py. When the new GP is fitted, the factory is called with the new search space and builds a frozen copy of the pretrained GP wrapped in a GPyTorch mean module.
  • The pretrained GP's parameters are frozen (no gradients) and set to eval mode, so it only contributes its posterior mean, not learnable parameters.
  • _PosteriorMeanFactory is not serializable, since it wraps a live BoTorch model. Attempting to serialize a surrogate using it will raise an error.
  • Accessing posterior_mean before the surrogate has been fitted raises ModelNotTrainedError.
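
The bullets above describe a two-step construction: the property hands out a factory, and the factory builds the actual mean only once the new bounds are known. A dependency-free toy sketch of that pattern might look as follows. ToySurrogate, its dict-based "model", and the local ModelNotTrainedError class are hypothetical stand-ins, not BayBE code, and the linear "posterior mean" is purely illustrative.

```python
import copy
from typing import Callable, Optional, Tuple


class ModelNotTrainedError(Exception):
    """Local stand-in for the error named in the PR description."""


class ToySurrogate:
    """Hypothetical surrogate whose 'model' is just a dict of learned numbers."""

    def __init__(self) -> None:
        self._model: Optional[dict] = None

    def fit(self, slope: float, offset: float) -> None:
        self._model = {"slope": slope, "offset": offset}

    @property
    def posterior_mean(self) -> Callable[[Tuple[float, float]], Callable[[float], float]]:
        if self._model is None:
            raise ModelNotTrainedError("Fit the surrogate before accessing posterior_mean.")
        frozen = copy.deepcopy(self._model)  # snapshot, unaffected by later refits

        def factory(new_bounds):
            lo, hi = new_bounds  # only known once the new GP is fitted

            def mean(u: float) -> float:
                x = lo + u * (hi - lo)  # undo the new GP's normalization
                return frozen["slope"] * x + frozen["offset"]

            return mean

        return factory


surrogate = ToySurrogate()
try:
    surrogate.posterior_mean
    raised = False
except ModelNotTrainedError:
    raised = True

surrogate.fit(2.0, 1.0)
factory = surrogate.posterior_mean
surrogate.fit(0.0, 0.0)            # refitting does not change the snapshot
mean = factory((0.0, 10.0))        # bounds of the new search space
mean_value = mean(0.25)            # evaluated at physical x = 2.5
```

In this toy, the guard and the snapshot live in the property, while everything that depends on the new search space is deferred to the factory call.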

Tests

Three tests were added to tests/test_gp.py:

  • Different bounds: Trains a surrogate on [0, 5], then creates a new GP on [0, 10] using get_posterior_mean. The new GP is trained on data lying exactly on the prior mean. Verifies that the new GP's posterior at x=2.5 matches the pretrained GP's prediction at x=2.5.
  • Same bounds: Same check when both search spaces share the same bounds.
  • Unfitted: Verifies that calling get_posterior_mean on an unfitted surrogate raises ModelNotTrainedError.
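
The first two tests rely on a standard GP identity: the posterior mean is m(x*) + k(x*, X) K^{-1} (y - m(X)), so when the training targets lie exactly on the prior mean m, the residuals vanish and the posterior mean equals m everywhere. The following toy sketch illustrates this with two training points and an RBF kernel in plain Python. All names are hypothetical, the kernel and noise level are arbitrary assumptions, and none of this is taken from tests/test_gp.py.

```python
import math


def rbf(a, b, lengthscale=1.0):
    """Squared-exponential kernel between two scalars."""
    return math.exp(-0.5 * ((a - b) / lengthscale) ** 2)


def pretrained_mean(x):
    """Stand-in for the pretrained GP's posterior mean (arbitrary function)."""
    return math.sin(x) + 0.5 * x


def gp_posterior_mean(x_star, xs, ys, prior_mean, noise=1e-6):
    """Posterior mean of a GP with two training points (closed-form 2x2 solve)."""
    (x1, x2), (y1, y2) = xs, ys
    k11 = rbf(x1, x1) + noise
    k22 = rbf(x2, x2) + noise
    k12 = rbf(x1, x2)
    det = k11 * k22 - k12 * k12
    # Residuals of the data with respect to the prior mean.
    r1 = y1 - prior_mean(x1)
    r2 = y2 - prior_mean(x2)
    # alpha = K^{-1} r for the 2x2 kernel matrix.
    a1 = (k22 * r1 - k12 * r2) / det
    a2 = (-k12 * r1 + k11 * r2) / det
    return prior_mean(x_star) + rbf(x_star, x1) * a1 + rbf(x_star, x2) * a2


xs = (1.0, 7.0)

# Data exactly on the prior mean: the posterior mean reproduces the prior mean.
ys_on_mean = tuple(pretrained_mean(x) for x in xs)
posterior_on_mean = gp_posterior_mean(2.5, xs, ys_on_mean, pretrained_mean)

# Data shifted off the prior mean: the posterior mean moves away from it.
ys_off_mean = (ys_on_mean[0] + 1.0, ys_on_mean[1])
posterior_off_mean = gp_posterior_mean(2.5, xs, ys_off_mean, pretrained_mean)
```

This is why comparing the new GP's posterior at x=2.5 against the pretrained prediction at x=2.5 is a meaningful check: any mismatch must come from the mean function being evaluated at the wrong point or on the wrong scale.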

@kalama-ai kalama-ai marked this pull request as ready for review June 9, 2026 11:25
@kalama-ai kalama-ai requested a review from Scienfitz as a code owner June 9, 2026 11:25
Copilot AI review requested due to automatic review settings June 9, 2026 11:25

Copilot AI left a comment

Contributor

Pull request overview

This PR adds a posterior_mean property to GaussianProcessSurrogate to enable transfer-learning-style initialization of a new GP’s mean function from a previously trained GP’s posterior mean, including correct handling of differing search space bounds by undoing the new GP’s normalization before querying the pretrained GP.

Changes:

  • Add GaussianProcessSurrogate.posterior_mean property that returns a mean factory backed by the surrogate’s trained BoTorch model.
  • Implement _PosteriorMeanFactory mean factory that wraps a frozen copy of the pretrained GP and untransforms inputs based on the new search space bounds before evaluation.
  • Add tests validating correctness under same/different bounds and that accessing posterior_mean before fitting raises ModelNotTrainedError.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| baybe/surrogates/gaussian_process/core.py | Exposes posterior_mean as a mean factory derived from a fitted GP surrogate. |
| baybe/surrogates/gaussian_process/components/mean.py | Adds _PosteriorMeanFactory that builds a GPyTorch mean module from a pretrained BoTorch GP and blocks serialization. |
| tests/test_gp.py | Adds regression tests for correctness across differing bounds and unfitted access behavior. |
| CHANGELOG.md | Documents the new posterior_mean property addition. |


Comment on lines +65 to +66
_pretrained_gp = field(alias="pretrained_gp")
"""The pretrained BoTorch GP whose posterior mean is used as the mean function."""
Comment thread tests/test_gp.py Outdated
Comment on lines +250 to +253
mean_module = new_surrogate._model.mean_module
x_normalized = torch.tensor([[0.25]])
with torch.no_grad():
    actual_mean = mean_module(x_normalized).item()
Comment thread tests/test_gp.py Outdated
Comment on lines +278 to +281
mean_module = new_surrogate._model.mean_module
x_normalized = torch.tensor([[0.5]])
with torch.no_grad():
    actual_mean = mean_module(x_normalized).item()

@AdrianSosic AdrianSosic left a comment

Collaborator

Hey @kalama-ai, here are two high-level design questions before I go into the detailed review.

"""The actual model."""

@property
def posterior_mean(self) -> MeanFactoryProtocol:

Collaborator

Now that I see the code, I think a method would actually be the better choice, mainly for three reasons:

  1. Does something rather non-trivial (as opposed to something like, for example, def n_dims that only fetches a number)
  2. Potentially raises an exception --> generally suboptimal for properties
  3. I can foresee that we might want to make this configurable, potentially even in this PR. For example, remember the dimensions we've generally discussed for transferring information from one GP to another? E.g. we can transfer hyperparameters only, or transfer data points, or transfer both. And we can freeze the hyperparameters or keep them learnable. Right now, you have implemented one specific choice here, but why not directly make this flexible?

Collaborator Author

I initially went with a property because it felt more natural: posterior_mean feels like something the surrogate has rather than something it does. But I see your points and I am open to changing it to a method. This might also eliminate the need for _PosteriorMeanFactory. I will implement that variant and ping you.

Collaborator Author

Hi @AdrianSosic, I implemented the method variant here: c8b138e. Could you please have a look?



@define
class _PosteriorMeanFactory(MeanFactoryProtocol):

Collaborator

Need your input here: I honestly don't see the purpose of having a factory here, but perhaps I'm overlooking some aspect that you have in mind. IMO, returning the PosteriorMean directly would be the better approach, for the reasons below, but please convince me otherwise:

  • In essence, the posterior_mean function is already a factory! It's a callable that (depending on optional arguments, see other thread) produces a posterior mean object for you. In your approach, you therefore have a factory producing a factory that produces the mean --> why do we need that additional step, where there is zero additional configurability added in the chain?
  • Also, the _PosteriorMeanFactory class as such – when considered in isolation – does not bring much additional value to the table since it's not even serializable.

So put together: why do we need it?

Collaborator Author

This design is motivated by the intended use case of passing the posterior mean directly as the mean function of a new GP, i.e. GaussianProcessSurrogate(mean_or_factory=prior_gp.posterior_mean), with posterior_mean implemented as a property. This required returning a factory object rather than a GPyTorch mean directly. The reason is that _PosteriorMean.forward(x) receives inputs already normalized by the new GP's input_transform and needs to undo that normalization before querying the pretrained GP's posterior mean. This requires knowing the new GP's search space bounds, which are only available at fit time of the new GP, not at property access time.

That said, if posterior_mean were a method with signature (self, searchspace, objective, measurements) -> GPyTorchMean, then prior_gp.posterior_mean would satisfy MeanFactoryProtocol and could probably be passed directly. I'll draft that variant and ping you here once it is worth a second look.
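
To make the method variant concrete, here is a rough, dependency-free sketch of how a bound method could itself act as the factory. MeanFactory, ToyPretrainedGP and ToyNewGP are hypothetical toy classes, not the BayBE types; only the argument names (searchspace, objective, measurements) are taken from the signature mentioned above, and the search space is reduced to a pair of bounds for illustration.

```python
from typing import Callable, List, Optional, Protocol, Tuple

Bounds = Tuple[float, float]
Mean = Callable[[float], float]


class MeanFactory(Protocol):
    """Hypothetical stand-in for a mean-factory protocol: a callable used at fit time."""

    def __call__(self, searchspace: Bounds, objective: str, measurements: List[float]) -> Mean: ...


class ToyPretrainedGP:
    """Toy model with a linear 'posterior mean' in physical units."""

    def __init__(self, slope: float, offset: float) -> None:
        self.slope, self.offset = slope, offset

    def get_posterior_mean(self, searchspace: Bounds, objective: str, measurements: List[float]) -> Mean:
        lo, hi = searchspace  # bounds of the NEW search space, passed in at fit time

        def mean(u: float) -> float:
            x = lo + u * (hi - lo)  # undo the new GP's normalization
            return self.slope * x + self.offset

        return mean


class ToyNewGP:
    """Toy consumer that calls the factory with its own search space when fitted."""

    def __init__(self, mean_or_factory: MeanFactory) -> None:
        self._mean_factory = mean_or_factory
        self.mean: Optional[Mean] = None

    def fit(self, searchspace: Bounds, objective: str, measurements: List[float]) -> None:
        self.mean = self._mean_factory(searchspace, objective, measurements)


pretrained = ToyPretrainedGP(slope=2.0, offset=1.0)
new_gp = ToyNewGP(mean_or_factory=pretrained.get_posterior_mean)  # bound method as factory
new_gp.fit((0.0, 10.0), "yield", [])
value = new_gp.mean(0.25)  # evaluated at physical x = 2.5
```

In this arrangement no intermediate factory class is needed: the bound method already has the call signature the consumer expects, and the new bounds arrive as an argument when the new GP is fitted.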

Collaborator Author

See above thread.

Collaborator

Ah, thanks for the explanation, I didn't see the issue with the scaling. OK, then let's see how we tackle the factory situation later. For now, we need to get the logic right – the current one is unfortunately flawed. I played around with it yesterday and could already fix parts of it but not entirely. I think we need a call to discuss the problem 🙃

Collaborator Author

Hi @AdrianSosic, thanks for providing a fix for the normalization error. There was just one small issue remaining: when the posterior mean is passed as the mean function of a new GP, training that GP would put the torch modules for normalization into training mode. I fixed this by overriding train() on the posterior mean class: cb34884.
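
For readers following along, the train() issue can be illustrated without torch. The sketch below uses a tiny hypothetical ToyModule that mimics how torch.nn.Module.train() sets its own flag and then recurses into child modules; FrozenMean is an illustrative stand-in for the posterior mean class, not the code from the commit.

```python
class ToyModule:
    """Tiny stand-in mimicking the recursive train()/eval() of torch.nn.Module."""

    def __init__(self):
        self.training = True
        self._children = []

    def add_child(self, child):
        self._children.append(child)
        return child

    def train(self, mode=True):
        self.training = mode
        for child in self._children:  # recursion into nested submodules
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


class FrozenMean(ToyModule):
    """Wraps a pretrained module that must stay in eval mode."""

    def __init__(self, pretrained):
        super().__init__()
        self.pretrained = self.add_child(pretrained)
        self.pretrained.eval()

    def train(self, mode=True):
        # Flip only this module's flag; do not propagate to the frozen child.
        self.training = mode
        return self


new_gp = ToyModule()
pretrained = ToyModule()
frozen_mean = new_gp.add_child(FrozenMean(pretrained))

new_gp.train()  # what a fitting routine would do before optimizing
states = (new_gp.training, frozen_mean.training, pretrained.training)
```

With the override in place, switching the outer model to training mode leaves the wrapped pretrained module in eval mode, which is the behavior the fix is after.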

kalama-ai and others added 9 commits June 12, 2026 10:36
  - posterior_mean returns a mean factory that can be passed directly to a new GaussianProcessSurrogate via mean_or_factory
  - The new GP normalizes inputs before passing them to the mean module, so the factory undoes that normalization before querying the pretrained GP, which then applies its own normalization internally
  - Raises ModelNotTrainedError if the surrogate has not been fitted yet
  - Replace the posterior_mean property with get_posterior_mean() (with MeanFactoryProtocol signature)
  - The method can be passed directly as mean_or_factory to a new GP
  - Remove _PosteriorMeanFactory from mean.py
  - Move _PosteriorMean class and normalization into the method
- Add output normalization to posterior mean
- Override train() on _PosteriorMean to prevent fit_gpytorch_mll from recursively switching nested submodules to training mode, which would change the learned Standardize parameters
- Improve tests to use points from posterior mean
Both methods must use identical transform logic: a change in one
(e.g. replacing Normalize with a different input transform) must
automatically apply to the other, or get_posterior_mean silently
produces mismatched results.
@AdrianSosic AdrianSosic force-pushed the feature/posterior_mean_property branch from a9e53b1 to 37bb6a6 Compare June 12, 2026 08:37

3 participants