Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 113 additions & 2 deletions baybe/searchspace/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from __future__ import annotations

import gc
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Collection, Iterable, Iterator, Sequence
from enum import Enum
from itertools import product
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, ClassVar, cast

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -433,6 +433,23 @@ def get_comp_rep_parameter_indices(
if col in p.comp_rep_columns
)

def _get_n_comp_rep_columns(
    self,
    name_or_selector: str | ParameterSelectorProtocol,
    /,
) -> int:
    """Count the comp-rep columns belonging to a parameter selection.

    Args:
        name_or_selector: Either the name of a single parameter or a selector
            that filters parameters to be included.

    Returns:
        The number of columns in the computational representation that are
        associated with the selected parameter(s).
    """
    # The column count is simply the number of comp-rep indices of the selection.
    selected_indices = self.get_comp_rep_parameter_indices(name_or_selector)
    return len(selected_indices)

@staticmethod
def estimate_product_space_size(parameters: Iterable[Parameter]) -> MemorySize:
"""Estimate an upper bound for the memory size of a product space.
Expand Down Expand Up @@ -515,6 +532,100 @@ def get_parameters_by_name(self, names: Sequence[str]) -> tuple[Parameter, ...]:
names
) + self.continuous.get_parameters_by_name(names)

def _drop_parameters(self, names: Collection[str], /) -> _ReducedSearchSpace:
    """Return a reduced search space without the named parameters.

    The returned object exposes only parameter information. Every attribute
    that is not on the allow-list of :class:`_ReducedSearchSpace` is blocked.

    Args:
        names: The names of the parameters to remove.

    Raises:
        ValueError: If any name does not match a parameter in the space.

    Returns:
        A reduced search space containing only parameter information.
    """
    # Build the set once: it is needed both for validation and for filtering.
    names_to_drop = set(names)
    current_names = {p.name for p in self.parameters}
    if unknown := names_to_drop - current_names:
        raise ValueError(
            f"Parameter name(s) {unknown} not found in the search space. "
            f"Available: {current_names}."
        )
    remaining = [p for p in self.parameters if p.name not in names_to_drop]

    disc_params = [p for p in remaining if p.is_discrete]
    cont_params = [p for p in remaining if p.is_continuous]

    # Explicit comp_rep needed because transform() drops columns for empty inputs.
    discrete = (
        SubspaceDiscrete(
            parameters=disc_params,
            exp_rep=pd.DataFrame(columns=[p.name for p in disc_params]),
            comp_rep=pd.DataFrame(
                columns=[c for p in disc_params for c in p.comp_rep_columns]
            ),
        )
        if disc_params
        else SubspaceDiscrete.empty()
    )

    continuous = (
        SubspaceContinuous(
            parameters=cont_params,
        )
        if cont_params
        else SubspaceContinuous.empty()
    )

    return _ReducedSearchSpace(discrete=discrete, continuous=continuous)


@define(slots=False)
class _ReducedSearchSpace(SearchSpace):
    """A parameter-only view of a search space.

    Exposes the parameter-related attributes listed in ``_ALLOWED_ATTRIBUTES``,
    which is what kernel factory calls need. Accessing any other non-dunder
    attribute raises a ``NotImplementedError``.

    Note:
        Subclassing :class:`SearchSpace` while rejecting most of its interface
        breaks substitutability. This is a deliberate shortcut that is meant
        to be temporary.

    This class is not intended for direct construction. Use
    :meth:`SearchSpace._drop_parameters` instead.
    """

    _ALLOWED_ATTRIBUTES: ClassVar[frozenset[str]] = frozenset(
        (
            "discrete",
            "continuous",
            "parameters",
            "parameter_names",
            "comp_rep_columns",
            "constraints",
            "type",
            "_task_parameter",
            "n_tasks",
            "_get_n_comp_rep_columns",
            # Allowed because `_get_n_comp_rep_columns` delegates to it.
            "get_comp_rep_parameter_indices",
            "get_parameters_by_name",
            "_ALLOWED_ATTRIBUTES",
        )
    )
    """Attributes accessible on this reduced search space."""

    @override
    def __getattribute__(self, name: str):
        """Return dunder and allow-listed attributes, reject everything else."""
        # Look attributes up via ``object`` so that the lookup bypasses this guard.
        fetch = object.__getattribute__
        if name.startswith("__") or name in fetch(self, "_ALLOWED_ATTRIBUTES"):
            return fetch(self, name)
        cls_name = fetch(self, "__class__").__name__
        raise NotImplementedError(
            f"'{cls_name}' does not "
            f"support attribute '{name}'. Only parameter information is available."
        )


def to_searchspace(
x: Parameter | SubspaceDiscrete | SubspaceContinuous | SearchSpace, /
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class _MLLForNonTLFitCriterionFactory(FitCriterionFactoryProtocol):
def __call__(
self, searchspace: SearchSpace, objective: Objective, measurements: pd.DataFrame
) -> FitCriterion:
if searchspace.task_idx is None:
if searchspace.n_tasks == 1:
return FitCriterion.MARGINAL_LOG_LIKELIHOOD

from baybe.surrogates.gaussian_process.presets.baybe import (
Expand Down
10 changes: 4 additions & 6 deletions baybe/surrogates/gaussian_process/components/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,8 @@ def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...]:

def _get_effective_dimensionality(self, searchspace: SearchSpace) -> int:
"""Get the number of computational columns for the selected parameters."""
return len(
searchspace.get_comp_rep_parameter_indices(
self.parameter_selector or (lambda _: True)
)
return searchspace._get_n_comp_rep_columns(
self.parameter_selector or (lambda _: True)
)

def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None:
Expand Down Expand Up @@ -213,7 +211,7 @@ def __call__(
target_cls._supported_parameter_kinds = broadened_kinds
self.parameter_selector = original_selector

if searchspace.task_idx is not None:
if searchspace.n_tasks > 1:
icm = ICMKernelFactory(base_kernel_or_factory=base_kernel)
return icm(searchspace, objective, measurements)
return base_kernel
Expand Down Expand Up @@ -301,7 +299,7 @@ def _validate_task_kernel_factory(self, _, factory: KernelFactoryProtocol):
def __call__(
self, searchspace: SearchSpace, objective: Objective, measurements: pd.DataFrame
) -> Kernel | GPyTorchKernel:
if searchspace.task_idx is None:
if searchspace.n_tasks == 1:
raise IncompatibleSearchSpaceError(
f"'{type(self).__name__}' can only be used with a searchspace that "
f"contains a '{TaskParameter.__name__}'."
Expand Down
2 changes: 2 additions & 0 deletions baybe/surrogates/gaussian_process/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
)

### Kernel
# TODO: When calling a factory on a `_ReducedSearchSpace`, validate that

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be sure, this is intended to be done in a follow-up PR I assume and not in this one?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this was only meant to highlight where the new `_ReducedSearchSpace` class would be used for the first time. Our plan is to pass it only to BayBE factories and block gpytorch kernels. This is why the current allow list in the class blocks everything index-related.

# it returns a BayBE Kernel (not a raw gpytorch kernel).
kernel = self.kernel_factory(
context.searchspace, context.objective, context.measurements
)
Expand Down
2 changes: 1 addition & 1 deletion baybe/surrogates/gaussian_process/presets/baybe.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def __call__(
) -> FitCriterion:
return (
FitCriterion.MARGINAL_LOG_LIKELIHOOD
if searchspace.task_idx is None
if searchspace.n_tasks == 1
else FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD
)

Expand Down
158 changes: 158 additions & 0 deletions tests/test_searchspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,3 +533,161 @@ def test_sample_from_polytope_mixed_constraints_with_interpoint():
# Verify interpoint constraint is satisfied across the batch
interpoint_constraint_result = samples["Conti_finite1"].sum()
assert np.isclose(interpoint_constraint_result, 0.6, atol=1e-6)


@pytest.fixture(name="reduced_searchspace")
def fixture_reduced_searchspace():
    """Provide a search space from which the task parameter has been dropped."""
    parameters = [
        CategoricalParameter("Color", ["red", "blue"]),
        NumericalContinuousParameter("x", (0.0, 1.0)),
        TaskParameter("Task", ["A", "B"]),
    ]
    full_space = SearchSpace.from_product(parameters)
    return full_space._drop_parameters({"Task"})


def test_reduced_parameters(reduced_searchspace):
    """Check that only the parameters that were not dropped remain."""
    remaining = {param.name for param in reduced_searchspace.parameters}
    assert remaining == {"Color", "x"}


def test_reduced_parameter_names(reduced_searchspace):
    """Check that ``parameter_names`` reflects the remaining parameters."""
    observed = set(reduced_searchspace.parameter_names)
    assert observed == {"Color", "x"}


def test_reduced_comp_rep_columns(reduced_searchspace):
    """Check that ``comp_rep_columns`` is the union of the parameters' columns."""
    expected = {
        column
        for param in reduced_searchspace.parameters
        for column in param.comp_rep_columns
    }
    assert set(reduced_searchspace.comp_rep_columns) == expected


def test_reduced_n_tasks(reduced_searchspace):
    """Check that a space without a task parameter reports a single task."""
    n_tasks = reduced_searchspace.n_tasks
    assert n_tasks == 1


def test_reduced_get_n_comp_rep_columns(reduced_searchspace):
    """Check the comp-rep column counts reported for each remaining parameter."""
    expected_counts = {"x": 1, "Color": 2}
    for name, count in expected_counts.items():
        assert reduced_searchspace._get_n_comp_rep_columns(name) == count


def test_reduced_get_comp_rep_parameter_indices(reduced_searchspace):
    """Check that comp-rep index lookup is available on the reduced space."""
    expected_lengths = {"x": 1, "Color": 2}
    for name, length in expected_lengths.items():
        indices = reduced_searchspace.get_comp_rep_parameter_indices(name)
        assert len(indices) == length


def test_reduced_blocked_attributes(reduced_searchspace):
    """Check that public attributes outside the allow-list cannot be accessed."""
    from baybe.searchspace.core import _ReducedSearchSpace

    # Public (non-underscore) attributes of SearchSpace that are not allow-listed
    public_attrs = {attr for attr in dir(SearchSpace) if not attr.startswith("_")}
    blocked = public_attrs - _ReducedSearchSpace._ALLOWED_ATTRIBUTES

    for attr in sorted(blocked):
        with pytest.raises(NotImplementedError, match=attr):
            getattr(reduced_searchspace, attr)


def test_reduced_repr(reduced_searchspace):
    """Check that ``repr`` works and mentions the class name."""
    assert "_ReducedSearchSpace" in repr(reduced_searchspace)


def test_reduced_str(reduced_searchspace):
    """Check that ``str`` works on a reduced search space."""
    assert isinstance(str(reduced_searchspace), str)


def test_reduced_eq(reduced_searchspace):
    """Check that two identically built reduced spaces compare equal."""
    parameters = [
        CategoricalParameter("Color", ["red", "blue"]),
        NumericalContinuousParameter("x", (0.0, 1.0)),
        TaskParameter("Task", ["A", "B"]),
    ]
    other = SearchSpace.from_product(parameters)._drop_parameters({"Task"})
    assert reduced_searchspace == other


def test_reduced_unknown_parameter_raises():
    """Check that dropping a name that is not in the space raises an error."""
    parameters = [
        CategoricalParameter("Color", ["red", "blue"]),
        NumericalContinuousParameter("x", (0.0, 1.0)),
    ]
    searchspace = SearchSpace.from_product(parameters)
    with pytest.raises(ValueError, match="not found"):
        searchspace._drop_parameters({"nonexistent"})


def test_reduced_multiple_parameters_removed():
    """Check that several parameters can be dropped in a single call."""
    parameters = [
        CategoricalParameter("Color", ["red", "blue"]),
        NumericalContinuousParameter("x", (0.0, 1.0)),
        TaskParameter("Task", ["A", "B"]),
    ]
    searchspace = SearchSpace.from_product(parameters)
    reduced = searchspace._drop_parameters({"Task", "Color"})
    assert reduced.parameter_names == ("x",)


def test_reduced_kernel_product_matches_default_factory():
    """Verify that manual kernel split matches BayBE's default factory.

    The default factory is called on the full search space. The manual variant
    builds the numerical kernel on the reduced space (task parameter dropped)
    and the task kernel on a task-only space, then multiplies both.
    """
    import torch

    from baybe.objectives.single import SingleTargetObjective
    from baybe.surrogates.gaussian_process.presets.baybe import (
        BayBEKernelFactory,
        _BayBENumericalKernelFactory,
        _BayBETaskKernelFactory,
    )
    from baybe.targets import NumericalTarget

    task_param = TaskParameter("Task", ["A", "B", "C"])
    cat_param = CategoricalParameter("Color", ["red", "blue", "green"])
    cont_param = NumericalContinuousParameter("x", (0.0, 1.0))
    searchspace = SearchSpace.from_product([cat_param, cont_param, task_param])

    objective = SingleTargetObjective(NumericalTarget("y"))
    measurements = pd.DataFrame(
        {"Color": ["red"], "x": [0.5], "Task": ["A"], "y": [1.0]}
    )

    kernel_default = BayBEKernelFactory()(searchspace, objective, measurements)

    reduced_ss = searchspace._drop_parameters({"Task"})
    base_baybe = _BayBENumericalKernelFactory()(reduced_ss, objective, measurements)
    base_gpytorch = base_baybe.to_gpytorch(searchspace)

    task_only_ss = SearchSpace.from_product([task_param])
    task_baybe = _BayBETaskKernelFactory()(task_only_ss, objective, measurements)
    task_gpytorch = task_baybe.to_gpytorch(searchspace)

    kernel_manual = base_gpytorch * task_gpytorch

    assert type(kernel_default) is type(kernel_manual)

    # `zip` stops at the shorter input, so compare the number of factors first.
    default_parts = kernel_default.kernels
    manual_parts = kernel_manual.kernels
    assert len(default_parts) == len(manual_parts)

    for k_def, k_man in zip(default_parts, manual_parts):
        assert type(k_def) is type(k_man)
        assert torch.equal(k_def.active_dims, k_man.active_dims)
Loading