Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions ax/adapter/tests/test_torch_moo_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
)
from ax.core.parameter_constraint import ParameterConstraint
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
from ax.generators.torch.botorch_moo import MultiObjectiveLegacyBoTorchGenerator
from ax.generators.torch.botorch_moo_defaults import (
infer_objective_thresholds,
pareto_frontier_evaluator,
Expand Down Expand Up @@ -104,7 +103,7 @@ def helper_test_pareto_frontier(
)
adapter = TorchAdapter(
experiment=exp,
generator=MultiObjectiveLegacyBoTorchGenerator(),
generator=BoTorchGenerator(),
)
with patch(
PARETO_FRONTIER_EVALUATOR_PATH, wraps=pareto_frontier_evaluator
Expand Down Expand Up @@ -272,7 +271,7 @@ def test_get_pareto_frontier_and_configs_input_validation(self) -> None:
experiment=exp,
search_space=exp.search_space,
data=exp.fetch_data(),
generator=MultiObjectiveLegacyBoTorchGenerator(),
generator=BoTorchGenerator(),
transforms=[],
)
observation_features = [
Expand Down Expand Up @@ -351,7 +350,7 @@ def test_hypervolume(self, _, cuda: bool = False) -> None:
)
adapter = TorchAdapter(
search_space=exp.search_space,
generator=MultiObjectiveLegacyBoTorchGenerator(),
generator=BoTorchGenerator(),
optimization_config=optimization_config,
transforms=[],
experiment=exp,
Expand Down Expand Up @@ -437,7 +436,7 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
data = exp.fetch_data()
adapter = TorchAdapter(
search_space=exp.search_space,
generator=MultiObjectiveLegacyBoTorchGenerator(),
generator=BoTorchGenerator(),
optimization_config=exp.optimization_config,
transforms=Cont_X_trans + Y_trans,
torch_device=torch.device("cuda" if cuda else "cpu"),
Expand Down Expand Up @@ -569,7 +568,7 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
set_rng_seed(0) # make model fitting deterministic
adapter = TorchAdapter(
search_space=exp.search_space,
generator=MultiObjectiveLegacyBoTorchGenerator(),
generator=BoTorchGenerator(),
optimization_config=exp.optimization_config,
transforms=ST_MTGP_trans,
experiment=exp,
Expand Down Expand Up @@ -633,7 +632,7 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
exp._trials = get_hss_trials_with_fixed_parameter(exp=exp)
adapter = TorchAdapter(
search_space=hss,
generator=MultiObjectiveLegacyBoTorchGenerator(),
generator=BoTorchGenerator(),
optimization_config=exp.optimization_config,
transforms=Cont_X_trans + Y_trans,
torch_device=torch.device("cuda" if cuda else "cpu"),
Expand Down Expand Up @@ -716,7 +715,7 @@ def test_best_point(self) -> None:
)
adapter = TorchAdapter(
search_space=exp.search_space,
generator=MultiObjectiveLegacyBoTorchGenerator(),
generator=BoTorchGenerator(),
optimization_config=exp.optimization_config,
transforms=[],
experiment=exp,
Expand Down
11 changes: 3 additions & 8 deletions ax/adapter/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
from ax.exceptions.generation_strategy import OptimizationConfigRequired
from ax.generators.torch.botorch import LegacyBoTorchGenerator
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
from ax.generators.torch.botorch_moo import MultiObjectiveLegacyBoTorchGenerator
from ax.generators.torch.botorch_moo_defaults import infer_objective_thresholds
from ax.generators.torch_base import TorchGenerator, TorchOptConfig
from ax.generators.types import TConfig
Expand Down Expand Up @@ -226,17 +225,13 @@ def infer_objective_thresholds(
"`infer_objective_thresholds` does not support risk measures."
)
# Infer objective thresholds.
if isinstance(self.generator, MultiObjectiveLegacyBoTorchGenerator):
model = self.generator.model
Xs = self.generator.Xs
elif isinstance(self.generator, BoTorchGenerator):
if isinstance(self.generator, BoTorchGenerator):
model = self.generator.surrogate.model
Xs = self.generator.surrogate.Xs
else:
raise UnsupportedError(
"Generator must be a MultiObjectiveLegacyBoTorchGenerator or an "
"appropriate Modular Botorch Generator to infer_objective_thresholds. "
f"Found {type(self.generator)}."
"Generator must be a Modular Botorch Generator to "
f"infer_objective_thresholds. Found {type(self.generator)}."
)

obj_thresholds = infer_objective_thresholds(
Expand Down
21 changes: 4 additions & 17 deletions ax/generators/tests/test_botorch_moo_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ax.core.search_space import SearchSpaceDigest
from ax.generators.torch.botorch_defaults import NO_OBSERVED_POINTS_MESSAGE
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
from ax.generators.torch.botorch_moo import MultiObjectiveLegacyBoTorchGenerator
from ax.generators.torch.botorch_moo_defaults import (
get_outcome_constraint_transforms,
get_qLogEHVI,
Expand All @@ -31,12 +30,10 @@
from ax.utils.common.testutils import TestCase
from ax.utils.testing.mock import mock_botorch_optimize_context_manager
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model import Model
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.multi_objective.hypervolume import infer_reference_point
from botorch.utils.testing import MockModel, MockPosterior
from gpytorch.utils.warnings import NumericalWarning
from torch._tensor import Tensor


MOO_DEFAULTS_PATH: str = "ax.generators.torch.botorch_moo_defaults"
Expand Down Expand Up @@ -89,7 +86,7 @@ def setUp(self) -> None:
def test_pareto_frontier_raise_error_when_missing_data(self) -> None:
with self.assertRaises(ValueError):
pareto_frontier_evaluator(
model=MultiObjectiveLegacyBoTorchGenerator(),
model=BoTorchGenerator(),
objective_thresholds=self.objective_thresholds,
objective_weights=self.objective_weights,
Yvar=self.Yvar,
Expand Down Expand Up @@ -151,17 +148,7 @@ def test_pareto_frontier_evaluator_raw(self) -> None:
self.assertTrue(torch.equal(torch.tensor([], dtype=torch.long), indx))

def test_pareto_frontier_evaluator_predict(self) -> None:
def dummy_predict(
model: Model,
X: Tensor,
use_posterior_predictive: bool = False,
) -> tuple[Tensor, Tensor]:
# Add column to X that is a product of previous elements.
mean = torch.cat([X, torch.prod(X, dim=1).reshape(-1, 1)], dim=1)
cov = torch.zeros(mean.shape[0], mean.shape[1], mean.shape[1])
return mean, cov

model = MultiObjectiveLegacyBoTorchGenerator(model_predictor=dummy_predict)
model = BoTorchGenerator()
_fit_model(model=model, X=self.X, Y=self.Y, Yvar=self.Yvar)

Y, _, indx = pareto_frontier_evaluator(
Expand All @@ -171,11 +158,11 @@ def dummy_predict(
X=self.X,
)
pred = self.Y[2:4]
self.assertAllClose(Y, pred)
self.assertAllClose(Y, pred, rtol=1e-3)
self.assertTrue(torch.equal(torch.arange(2, 4), indx))

def test_pareto_frontier_evaluator_with_outcome_constraints(self) -> None:
model = MultiObjectiveLegacyBoTorchGenerator()
model = BoTorchGenerator()
Y, _, indx = pareto_frontier_evaluator(
model=model,
objective_weights=self.objective_weights,
Expand Down
Loading