Add GaussianProcessSurrogate.posterior_mean property#823
Conversation
There was a problem hiding this comment.
Pull request overview
This PR adds a posterior_mean property to GaussianProcessSurrogate to enable transfer-learning-style initialization of a new GP’s mean function from a previously trained GP’s posterior mean, including correct handling of differing search space bounds by undoing the new GP’s normalization before querying the pretrained GP.
Changes:
- Add
GaussianProcessSurrogate.posterior_meanproperty that returns a mean factory backed by the surrogate’s trained BoTorch model. - Implement
_PosteriorMeanFactorymean factory that wraps a frozen copy of the pretrained GP and untransforms inputs based on the new search space bounds before evaluation. - Add tests validating correctness under same/different bounds and that accessing
posterior_meanbefore fitting raisesModelNotTrainedError.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
baybe/surrogates/gaussian_process/core.py |
Exposes posterior_mean as a mean factory derived from a fitted GP surrogate. |
baybe/surrogates/gaussian_process/components/mean.py |
Adds _PosteriorMeanFactory that builds a GPyTorch mean module from a pretrained BoTorch GP and blocks serialization. |
tests/test_gp.py |
Adds regression tests for correctness across differing bounds and unfitted access behavior. |
CHANGELOG.md |
Documents the new posterior_mean property addition. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| _pretrained_gp = field(alias="pretrained_gp") | ||
| """The pretrained BoTorch GP whose posterior mean is used as the mean function.""" |
| mean_module = new_surrogate._model.mean_module | ||
| x_normalized = torch.tensor([[0.25]]) | ||
| with torch.no_grad(): | ||
| actual_mean = mean_module(x_normalized).item() |
| mean_module = new_surrogate._model.mean_module | ||
| x_normalized = torch.tensor([[0.5]]) | ||
| with torch.no_grad(): | ||
| actual_mean = mean_module(x_normalized).item() |
AdrianSosic
left a comment
There was a problem hiding this comment.
Hey @kalama-ai, here two high-level design questions before I go into the detailed review
| """The actual model.""" | ||
|
|
||
| @property | ||
| def posterior_mean(self) -> MeanFactoryProtocol: |
There was a problem hiding this comment.
Now that I see the code, I think a method would actually be the better choice, for mainly three reasons:
- Does something rather non-trivial (as opposed to something like, for example,
def n_dimsthat only fetches a number) - Potentially raises an exception --> generally suboptimal for properties
- I can foresee that we might want to make this configurable, potentially even in this PR. For example, remember the dimensions we've generally discussed for transferring information from one GP to another? E.g. we can transfer hyperparameters only, or transfer data points, or transfer both. And we can freeze the hyperparameters or keep them learnable. Right now, you have implemented one specific choice here, but why not directly make this flexible?
There was a problem hiding this comment.
I initially went with a property because it felt more natural. posterior_mean feels like something the surrogate has rather than something it does. But I see your points and I am open to changing it to a method. This might also eliminate the need for _PosteriorMeanFactory. I will implment that variant and ping you.
There was a problem hiding this comment.
Hi @AdrianSosic , I implemented the method variant here: c8b138e. Could you please have a look?
|
|
||
|
|
||
| @define | ||
| class _PosteriorMeanFactory(MeanFactoryProtocol): |
There was a problem hiding this comment.
Need your input here: I honestly don't see the purpose of having a factory here, but perhaps I'm overlooking some aspect that you have in mind. IMO, returning the PosteriorMean directly would be the better approach, for the reasons below, but please convince me otherwise:
- In essence, the
posterior_meanfunction is already a factory! It's a callable that (depending on optional arguments, see other thread) produces a posterior mean object for you. In your approach, you therefore have a factory producing a factory that produces the mean --> why do we need that additional step, where there is zero additional configurability added in the chain? - Also, the
_PosteriorMeanFactoryclass as such – when considered in isolation – does not bring much additional value to the table since it's not even serializable.
So put together: why do we need it?
There was a problem hiding this comment.
This design is motivated by the intended use case of passing the posterior mean directly as the mean function of a new GP: GaussianProcessSurrogate(mean_or_factory=prior_gp.posterior_mean) and implementing posterior_mean as a property. This required returning a factory object rather than a gpytorch mean directly. The reason is that _PosteriorMean.forward(x) receives inputs already normalized by the new GP's input_transform and needs to undo that normalization before querying the pretrained GP's posterior mean. THis requires knowing the new GP's search space bounds. Those bounds are only available at fit time of the new GP, not at property access time.
That said, if posterior_mean were a method with signature (self, searchspace, objective, measurements) -> GPyTorchMean, the prior_gp.posterior_mean would satisfy MeanFactoryProtocol and could probably be passed directly. I'll draft that variant and ping you here once it is worth having a second look. .
There was a problem hiding this comment.
See above thread.
There was a problem hiding this comment.
Ah, thanks for the explanation, I didn't see the issue with the scaling. OK, then let's see how we tackle the factory situation later. For now, we need to get the logic right – the current one is unfortunately flawed. I played around with it yesterday and could already fix parts of it but not entirely. I think we need a call to discuss the problem 🙃
There was a problem hiding this comment.
Hi @AdrianSosic, thanks for providing a fix for the normaliaztion error. There was just one small issue remaining: when passing the posterior mean as a mean function to a new GP, training it would put the torch modules for normalization in training mode. I fixed this by overwriting train() on the posterior mean class: cb34884.
- posterior_mean returns a mean factory that can be passed directly to a new GaussianProcessSurrogate via mean_or_factory - the new GP normalizes inputs before passing them to the mean module, so the factory undoes that normalization before querying the pretrained GP, which then applies its own normalization internally - Raises ModelNotTrainedError if the surrogate has not been fitted yet
- Replace the posterior_mean property with get_posterior_mean() (with MeanFactoryProtocol signature) - method can be passed directly as mean_or_factory to a new GP - Remove _PosteriorMeanFactory from mean.py - Move _PosteriorMean class and normalization into the method
- Add output normalization to posterior mean - override train() on _PosteriorMean to prevent fit_gpytorch_mll from recursively switching nested submodules to training mode, which would change the learned Standardize parameters - improve tests to use points from posterior mean
Both methods must use identical transform logic: a change in one (e.g. replacing Normalize with a different input transform) must automatically apply to the other, or get_posterior_mean silently produces mismatched results.
a9e53b1 to
37bb6a6
Compare
Adds a
posterior_meanproperty toGaussianProcessSurrogatethat can be used as the mean function of a new GP.Normalization
GPs in BayBE normalize their inputs internally based on the bounds of the search space they were trained on. When the pretrained GP's mean module is used in a new GP, the new GP will normalize inputs according to its own bounds before passing them to the mean module. If the two search spaces have different bounds, this would cause the mean to be evaluated at the wrong point.
Example: A pretrained GP was trained on
x ∈ [0, 5]. It normalizesx=2.5to0.5. A new GP is trained on a wider spacex ∈ [0, 10]— it normalizesx=2.5to0.25. If the mean module received0.25and passed it straight to the pretrained GP, the pretrained GP would re-normalize it to0.25 * 5 = 1.25, which is the wrong physical point.The fix is to undo the new GP's normalization inside the mean module before querying the pretrained GP, which then applies its own normalization correctly. The factory builds this untransform at fit time, when the new search space bounds are known. Similarly, the factory matches the new GP's outcome standardization, ensuring the mean module returns values in the same standardized y-space that the new GP operates in.
Implementation
posterior_meanis backed by a private_PosteriorMeanFactoryinmean.py. When the new GP is fitted, the factory is called with the new search space and builds a frozen copy of the pretrained GP wrapped in a GPyTorch mean module._PosteriorMeanFactoryis not serializable, since it wraps a live BoTorch model. Attempting to serialize a surrogate using it will raise an error.posterior_meanbefore the surrogate has been fitted raisesModelNotTrainedError.Tests
Three tests were added to tests/test_gp.py:
[0, 5], then creates a new GP on[0, 10]usingget_posterior_mean. The new GP is trained on data lying exactly on the prior mean. Verifies that the new GP's posterior atx=2.5matches the pretrained GP's prediction atx=2.5.get_posterior_meanon an unfitted surrogate raisesModelNotTrainedError.