diff --git a/ax/adapter/tests/test_torch_moo_adapter.py b/ax/adapter/tests/test_torch_moo_adapter.py
index 337c6d85676..da7b941b0f0 100644
--- a/ax/adapter/tests/test_torch_moo_adapter.py
+++ b/ax/adapter/tests/test_torch_moo_adapter.py
@@ -32,7 +32,6 @@
 )
 from ax.core.parameter_constraint import ParameterConstraint
 from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
-from ax.generators.torch.botorch_moo import MultiObjectiveLegacyBoTorchGenerator
 from ax.generators.torch.botorch_moo_defaults import (
     infer_objective_thresholds,
     pareto_frontier_evaluator,
@@ -104,7 +103,7 @@ def helper_test_pareto_frontier(
     )
     adapter = TorchAdapter(
         experiment=exp,
-        generator=MultiObjectiveLegacyBoTorchGenerator(),
+        generator=BoTorchGenerator(),
     )
     with patch(
         PARETO_FRONTIER_EVALUATOR_PATH, wraps=pareto_frontier_evaluator
@@ -272,7 +271,7 @@ def test_get_pareto_frontier_and_configs_input_validation(self) -> None:
             experiment=exp,
             search_space=exp.search_space,
             data=exp.fetch_data(),
-            generator=MultiObjectiveLegacyBoTorchGenerator(),
+            generator=BoTorchGenerator(),
             transforms=[],
         )
         observation_features = [
@@ -351,7 +350,7 @@ def test_hypervolume(self, _, cuda: bool = False) -> None:
         )
         adapter = TorchAdapter(
             search_space=exp.search_space,
-            generator=MultiObjectiveLegacyBoTorchGenerator(),
+            generator=BoTorchGenerator(),
             optimization_config=optimization_config,
             transforms=[],
             experiment=exp,
@@ -437,7 +436,7 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
         data = exp.fetch_data()
         adapter = TorchAdapter(
             search_space=exp.search_space,
-            generator=MultiObjectiveLegacyBoTorchGenerator(),
+            generator=BoTorchGenerator(),
             optimization_config=exp.optimization_config,
             transforms=Cont_X_trans + Y_trans,
             torch_device=torch.device("cuda" if cuda else "cpu"),
@@ -569,7 +568,7 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
         set_rng_seed(0)  # make model fitting deterministic
         adapter = TorchAdapter(
             search_space=exp.search_space,
-            generator=MultiObjectiveLegacyBoTorchGenerator(),
+            generator=BoTorchGenerator(),
             optimization_config=exp.optimization_config,
             transforms=ST_MTGP_trans,
             experiment=exp,
@@ -633,7 +632,7 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
         exp._trials = get_hss_trials_with_fixed_parameter(exp=exp)
         adapter = TorchAdapter(
             search_space=hss,
-            generator=MultiObjectiveLegacyBoTorchGenerator(),
+            generator=BoTorchGenerator(),
             optimization_config=exp.optimization_config,
             transforms=Cont_X_trans + Y_trans,
             torch_device=torch.device("cuda" if cuda else "cpu"),
@@ -716,7 +715,7 @@ def test_best_point(self) -> None:
         )
         adapter = TorchAdapter(
             search_space=exp.search_space,
-            generator=MultiObjectiveLegacyBoTorchGenerator(),
+            generator=BoTorchGenerator(),
             optimization_config=exp.optimization_config,
             transforms=[],
             experiment=exp,
diff --git a/ax/adapter/torch.py b/ax/adapter/torch.py
index e0af708e570..66dd7ed289f 100644
--- a/ax/adapter/torch.py
+++ b/ax/adapter/torch.py
@@ -68,7 +68,6 @@
 from ax.exceptions.generation_strategy import OptimizationConfigRequired
 from ax.generators.torch.botorch import LegacyBoTorchGenerator
 from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
-from ax.generators.torch.botorch_moo import MultiObjectiveLegacyBoTorchGenerator
 from ax.generators.torch.botorch_moo_defaults import infer_objective_thresholds
 from ax.generators.torch_base import TorchGenerator, TorchOptConfig
 from ax.generators.types import TConfig
@@ -226,17 +225,13 @@ def infer_objective_thresholds(
                 "`infer_objective_thresholds` does not support risk measures."
             )
         # Infer objective thresholds.
-        if isinstance(self.generator, MultiObjectiveLegacyBoTorchGenerator):
-            model = self.generator.model
-            Xs = self.generator.Xs
-        elif isinstance(self.generator, BoTorchGenerator):
+        if isinstance(self.generator, BoTorchGenerator):
             model = self.generator.surrogate.model
             Xs = self.generator.surrogate.Xs
         else:
             raise UnsupportedError(
-                "Generator must be a MultiObjectiveLegacyBoTorchGenerator or an "
-                "appropriate Modular Botorch Generator to infer_objective_thresholds. "
-                f"Found {type(self.generator)}."
+                "Generator must be a Modular Botorch Generator to "
+                f"infer_objective_thresholds. Found {type(self.generator)}."
             )

         obj_thresholds = infer_objective_thresholds(
diff --git a/ax/generators/tests/test_botorch_moo_defaults.py b/ax/generators/tests/test_botorch_moo_defaults.py
index 6064ffe8a6e..1f0de5fc4d0 100644
--- a/ax/generators/tests/test_botorch_moo_defaults.py
+++ b/ax/generators/tests/test_botorch_moo_defaults.py
@@ -16,7 +16,6 @@
 from ax.core.search_space import SearchSpaceDigest
 from ax.generators.torch.botorch_defaults import NO_OBSERVED_POINTS_MESSAGE
 from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
-from ax.generators.torch.botorch_moo import MultiObjectiveLegacyBoTorchGenerator
 from ax.generators.torch.botorch_moo_defaults import (
     get_outcome_constraint_transforms,
     get_qLogEHVI,
@@ -31,12 +30,10 @@
 from ax.utils.common.testutils import TestCase
 from ax.utils.testing.mock import mock_botorch_optimize_context_manager
 from botorch.models.gp_regression import SingleTaskGP
-from botorch.models.model import Model
 from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.multi_objective.hypervolume import infer_reference_point
 from botorch.utils.testing import MockModel, MockPosterior
 from gpytorch.utils.warnings import NumericalWarning
-from torch._tensor import Tensor


 MOO_DEFAULTS_PATH: str = "ax.generators.torch.botorch_moo_defaults"
@@ -89,7 +86,7 @@ def setUp(self) -> None:
     def test_pareto_frontier_raise_error_when_missing_data(self) -> None:
         with self.assertRaises(ValueError):
             pareto_frontier_evaluator(
-                model=MultiObjectiveLegacyBoTorchGenerator(),
+                model=BoTorchGenerator(),
                 objective_thresholds=self.objective_thresholds,
                 objective_weights=self.objective_weights,
                 Yvar=self.Yvar,
@@ -151,17 +148,7 @@ def test_pareto_frontier_evaluator_raw(self) -> None:
         self.assertTrue(torch.equal(torch.tensor([], dtype=torch.long), indx))

     def test_pareto_frontier_evaluator_predict(self) -> None:
-        def dummy_predict(
-            model: Model,
-            X: Tensor,
-            use_posterior_predictive: bool = False,
-        ) -> tuple[Tensor, Tensor]:
-            # Add column to X that is a product of previous elements.
-            mean = torch.cat([X, torch.prod(X, dim=1).reshape(-1, 1)], dim=1)
-            cov = torch.zeros(mean.shape[0], mean.shape[1], mean.shape[1])
-            return mean, cov
-
-        model = MultiObjectiveLegacyBoTorchGenerator(model_predictor=dummy_predict)
+        model = BoTorchGenerator()
         _fit_model(model=model, X=self.X, Y=self.Y, Yvar=self.Yvar)

         Y, _, indx = pareto_frontier_evaluator(
@@ -171,11 +158,11 @@ def dummy_predict(
             X=self.X,
         )
         pred = self.Y[2:4]
-        self.assertAllClose(Y, pred)
+        self.assertAllClose(Y, pred, rtol=1e-3)
         self.assertTrue(torch.equal(torch.arange(2, 4), indx))

     def test_pareto_frontier_evaluator_with_outcome_constraints(self) -> None:
-        model = MultiObjectiveLegacyBoTorchGenerator()
+        model = BoTorchGenerator()
         Y, _, indx = pareto_frontier_evaluator(
             model=model,
             objective_weights=self.objective_weights,
diff --git a/ax/generators/tests/test_botorch_moo_model.py b/ax/generators/tests/test_botorch_moo_model.py
deleted file mode 100644
index e6b5ea38e90..00000000000
--- a/ax/generators/tests/test_botorch_moo_model.py
+++ /dev/null
@@ -1,882 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-import dataclasses
-from contextlib import ExitStack
-from typing import Any
-from unittest import mock
-
-import ax.generators.torch.botorch_moo_defaults as botorch_moo_defaults
-import botorch.utils.multi_objective.hypervolume as hypervolume
-import numpy as np
-import torch
-from ax.core.search_space import SearchSpaceDigest
-from ax.exceptions.core import AxError
-from ax.generators.torch.botorch_defaults import get_qLogNEI
-from ax.generators.torch.botorch_modular.optimizer_defaults import INIT_BATCH_LIMIT
-from ax.generators.torch.botorch_moo import MultiObjectiveLegacyBoTorchGenerator
-from ax.generators.torch.botorch_moo_defaults import (
-    get_EHVI,
-    get_NEHVI,
-    get_qLogEHVI,
-    get_qLogNEHVI,
-    infer_objective_thresholds,
-)
-from ax.generators.torch.utils import HYPERSPHERE
-from ax.generators.torch_base import TorchOptConfig
-from ax.utils.common.testutils import TestCase
-from ax.utils.testing.mock import mock_botorch_optimize
-from ax.utils.testing.torch_stubs import get_torch_test_data
-from botorch.acquisition.multi_objective import (
-    logei as moo_logei,
-    monte_carlo as moo_monte_carlo,
-)
-from botorch.models import ModelListGP
-from botorch.models.gp_regression import SingleTaskGP
-from botorch.models.transforms.input import Warp
-from botorch.optim.optimize import optimize_acqf_list
-from botorch.sampling.normal import IIDNormalSampler
-from botorch.utils.datasets import SupervisedDataset
-from botorch.utils.multi_objective.hypervolume import infer_reference_point
-from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization
-from botorch.utils.testing import MockPosterior
-
-
-FIT_MODEL_MO_PATH = "ax.generators.torch.botorch_defaults.fit_gpytorch_mll"
-SAMPLE_SIMPLEX_UTIL_PATH = "ax.generators.torch.utils.sample_simplex"
-SAMPLE_HYPERSPHERE_UTIL_PATH = "ax.generators.torch.utils.sample_hypersphere"
-CHEBYSHEV_SCALARIZATION_PATH = (
-    "ax.generators.torch.botorch_defaults.get_chebyshev_scalarization"
-)
-NEHVI_ACQF_PATH = (
-    "botorch.acquisition.factory.moo_monte_carlo.qNoisyExpectedHypervolumeImprovement"
-)
-EHVI_ACQF_PATH = (
-    "botorch.acquisition.factory.moo_monte_carlo.qExpectedHypervolumeImprovement"
-)
-LOG_NEHVI_ACQF_PATH = (
"botorch.acquisition.factory.moo_logei.qLogNoisyExpectedHypervolumeImprovement" -) -LOG_EHVI_ACQF_PATH = ( - "botorch.acquisition.factory.moo_logei.qLogExpectedHypervolumeImprovement" -) -NOISY_PARTITIONING_PATH = ( - "botorch.utils.multi_objective.hypervolume.FastNondominatedPartitioning" -) -PARTITIONING_PATH = "botorch.acquisition.factory.FastNondominatedPartitioning" - - -def dummy_func(X: torch.Tensor) -> torch.Tensor: - return X - - -class BotorchMOOModelTest(TestCase): - def test_BotorchMOOModel_double(self) -> None: - self.test_BotorchMOOModel_with_random_scalarization(dtype=torch.double) - - def test_BotorchMOOModel_cuda(self) -> None: - if torch.cuda.is_available(): - for dtype in (torch.float, torch.double): - self.test_BotorchMOOModel_with_random_scalarization( - dtype=dtype, cuda=True - ) - for use_noisy in (True, False): - # test qLog(N)EHVI - self.test_BotorchMOOModel_with_qehvi( - dtype=dtype, cuda=True, use_noisy=use_noisy, use_log=True - ) - self.test_BotorchMOOModel_with_qehvi_and_outcome_constraints( - dtype=dtype, cuda=True, use_noisy=use_noisy, use_log=True - ) - - def test_BotorchMOOModel_with_qnehvi(self) -> None: - # testing non-log version - for dtype in (torch.float, torch.double): - self.test_BotorchMOOModel_with_qehvi( - dtype=dtype, use_noisy=True, use_log=False - ) - self.test_BotorchMOOModel_with_qehvi_and_outcome_constraints( - dtype=dtype, use_noisy=True, use_log=False - ) - - def test_BotorchMOOModel_with_qlognehvi(self) -> None: - for dtype in (torch.float, torch.double): - self.test_BotorchMOOModel_with_qehvi( - dtype=dtype, use_noisy=True, use_log=True - ) - self.test_BotorchMOOModel_with_qehvi_and_outcome_constraints( - dtype=dtype, use_noisy=True, use_log=True - ) - - @mock_botorch_optimize - def test_BotorchMOOModel_with_random_scalarization( - self, - dtype: torch.dtype = torch.float, - cuda: bool = False, - ) -> None: - tkwargs: dict[str, Any] = { - "device": torch.device("cuda") if cuda else torch.device("cpu"), - "dtype": dtype, - } - ( - Xs, - Ys, - Yvars, - bounds, - tfs, - feature_names, - _, - ) = get_torch_test_data(dtype=dtype, cuda=cuda, constant_noise=True) - training_data = [ - SupervisedDataset( - X=Xs, - Y=Ys, - Yvar=Yvars, - feature_names=feature_names, - outcome_names=[name], - ) - for name in ["m1", "m2"] - ] - - n = 3 - objective_weights = torch.tensor([1.0, 1.0], **tkwargs) - obj_t = torch.tensor([1.0, 1.0], **tkwargs) - - search_space_digest = SearchSpaceDigest( - feature_names=feature_names, - bounds=bounds, - task_features=tfs, - ) - model = MultiObjectiveLegacyBoTorchGenerator(acqf_constructor=get_qLogNEI) - with mock.patch(FIT_MODEL_MO_PATH) as _mock_fit_model: - model.fit( - datasets=training_data, - search_space_digest=search_space_digest, - ) - _mock_fit_model.assert_called_once() - - torch_opt_config = TorchOptConfig( - objective_weights=objective_weights, - objective_thresholds=obj_t, - model_gen_options={ - "acquisition_function_kwargs": {"random_scalarization": True}, - "subset_model": False, - }, - is_moo=True, - ) - with self.assertRaisesRegex(NotImplementedError, "Best observed"): - model.best_point( - search_space_digest=search_space_digest, - torch_opt_config=torch_opt_config, - ) - with mock.patch( - SAMPLE_SIMPLEX_UTIL_PATH, - autospec=True, - return_value=torch.tensor([0.7, 0.3], **tkwargs), - ) as _mock_sample_simplex: - model.gen( - n, - search_space_digest=search_space_digest, - torch_opt_config=torch_opt_config, - ) - # Sample_simplex should be called once for generated candidate. 
-            self.assertEqual(n, _mock_sample_simplex.call_count)
-
-        torch_opt_config.model_gen_options["acquisition_function_kwargs"] = {
-            "random_scalarization": True,
-            "random_scalarization_distribution": HYPERSPHERE,
-        }
-        with mock.patch(
-            SAMPLE_HYPERSPHERE_UTIL_PATH,
-            autospec=True,
-            return_value=torch.tensor([0.6, 0.8], **tkwargs),
-        ) as _mock_sample_hypersphere:
-            model.gen(
-                n,
-                search_space_digest=search_space_digest,
-                torch_opt_config=torch_opt_config,
-            )
-            # Sample_simplex should be called once per generated candidate.
-            self.assertEqual(n, _mock_sample_hypersphere.call_count)
-
-        # test input warping
-        self.assertFalse(model.use_input_warping)
-        model = MultiObjectiveLegacyBoTorchGenerator(
-            acqf_constructor=get_qLogNEI,
-            use_input_warping=True,
-        )
-        model.fit(
-            datasets=training_data,
-            search_space_digest=search_space_digest,
-        )
-        self.assertTrue(model.use_input_warping)
-        self.assertIsInstance(model.model, ModelListGP)
-        # pyre-fixme[16]: Optional type has no attribute `models`.
-        for m in model.model.models:
-            self.assertTrue(hasattr(m, "input_transform"))
-            self.assertIsInstance(m.input_transform, Warp)
-        self.assertFalse(hasattr(model.model, "input_transform"))
-
-        # test loocv pseudo likelihood
-        self.assertFalse(model.use_loocv_pseudo_likelihood)
-        model = MultiObjectiveLegacyBoTorchGenerator(
-            acqf_constructor=get_qLogNEI,
-            use_loocv_pseudo_likelihood=True,
-        )
-        model.fit(
-            datasets=training_data,
-            search_space_digest=search_space_digest,
-        )
-        self.assertTrue(model.use_loocv_pseudo_likelihood)
-
-    @mock_botorch_optimize
-    def test_BotorchMOOModel_with_chebyshev_scalarization(
-        self,
-        dtype: torch.dtype = torch.float,
-        cuda: bool = False,
-    ) -> None:
-        tkwargs: dict[str, Any] = {
-            "device": torch.device("cuda") if cuda else torch.device("cpu"),
-            "dtype": dtype,
-        }
-        (
-            Xs,
-            Ys,
-            Yvars,
-            bounds,
-            tfs,
-            feature_names,
-            _,
-        ) = get_torch_test_data(dtype=dtype, cuda=cuda, constant_noise=True)
-        training_data = [
-            SupervisedDataset(
-                X=Xs,
-                Y=Ys,
-                Yvar=Yvars,
-                feature_names=feature_names,
-                outcome_names=[name],
-            )
-            for name in ["m1", "m2"]
-        ]
-
-        n = 3
-        objective_weights = torch.tensor([1.0, 1.0], **tkwargs)
-        obj_t = torch.tensor([1.0, 1.0], **tkwargs)
-
-        search_space_digest = SearchSpaceDigest(
-            feature_names=feature_names,
-            bounds=bounds,
-            task_features=tfs,
-        )
-        model = MultiObjectiveLegacyBoTorchGenerator(acqf_constructor=get_qLogNEI)
-        with mock.patch(FIT_MODEL_MO_PATH) as _mock_fit_model:
-            model.fit(
-                datasets=training_data,
-                search_space_digest=search_space_digest,
-            )
-            _mock_fit_model.assert_called_once()
-
-        torch_opt_config = TorchOptConfig(
-            objective_weights=objective_weights,
-            objective_thresholds=obj_t,
-            model_gen_options={
-                "acquisition_function_kwargs": {"chebyshev_scalarization": True},
-                "optimizer_kwargs": {"options": {"batch_limit": 1}},
-            },
-        )
-        with mock.patch(
-            CHEBYSHEV_SCALARIZATION_PATH, wraps=get_chebyshev_scalarization
-        ) as _mock_chebyshev_scalarization, mock.patch(
-            "ax.generators.torch.botorch_moo_defaults.optimize_acqf_list",
-            wraps=optimize_acqf_list,
-        ) as mock_optimize:
-            model.gen(
-                n,
-                search_space_digest=search_space_digest,
-                torch_opt_config=torch_opt_config,
-            )
-            # get_chebyshev_scalarization should be called once for generated candidate.
-            self.assertEqual(n, _mock_chebyshev_scalarization.call_count)
-        self.assertEqual(
-            mock_optimize.call_args.kwargs["options"]["init_batch_limit"],
-            INIT_BATCH_LIMIT,
-        )
-        self.assertEqual(mock_optimize.call_args.kwargs["options"]["batch_limit"], 1)
-
-    def test_BotorchMOOModel_with_qehvi(
-        self,
-        dtype: torch.dtype = torch.float,
-        cuda: bool = False,
-        use_noisy: bool = False,
-        use_log: bool = True,
-    ) -> None:
-        if use_log:
-            if use_noisy:
-                acqf_constructor = get_qLogNEHVI
-                acquisition_path = LOG_NEHVI_ACQF_PATH
-                acqf_class = moo_logei.qLogNoisyExpectedHypervolumeImprovement
-                partitioning_path = NOISY_PARTITIONING_PATH
-            else:
-                acqf_constructor = get_qLogEHVI
-                acquisition_path = LOG_EHVI_ACQF_PATH
-                acqf_class = moo_logei.qLogExpectedHypervolumeImprovement
-                partitioning_path = PARTITIONING_PATH
-        else:
-            if use_noisy:
-                acqf_constructor = get_NEHVI
-                acquisition_path = NEHVI_ACQF_PATH
-                acqf_class = moo_monte_carlo.qNoisyExpectedHypervolumeImprovement
-                partitioning_path = NOISY_PARTITIONING_PATH
-            else:
-                acqf_constructor = get_EHVI
-                acquisition_path = EHVI_ACQF_PATH
-                acqf_class = moo_monte_carlo.qExpectedHypervolumeImprovement
-                partitioning_path = PARTITIONING_PATH
-
-        tkwargs: dict[str, Any] = {
-            "device": torch.device("cuda") if cuda else torch.device("cpu"),
-            "dtype": dtype,
-        }
-        (
-            Xs,
-            Ys,
-            Yvars,
-            bounds,
-            tfs,
-            feature_names,
-            _,
-        ) = get_torch_test_data(dtype=dtype, cuda=cuda, constant_noise=True)
-        training_data = [
-            SupervisedDataset(
-                X=Xs,
-                Y=Ys,
-                Yvar=Yvars,
-                feature_names=feature_names,
-                outcome_names=[name],
-            )
-            for name in ["m1", "m2", "m3"]
-        ]
-
-        n = 3
-        objective_weights = torch.tensor([1.0, 1.0, 0.0], **tkwargs)
-        obj_t = torch.tensor([1.0, 1.0, float("nan")], **tkwargs)
-        # pyre-fixme[6]: For 1st param expected `(Model, Tensor, Optional[Tuple[Tenso...
-        model = MultiObjectiveLegacyBoTorchGenerator(acqf_constructor=acqf_constructor)
-
-        X_dummy = torch.tensor([[[1.0, 2.0, 3.0]]], **tkwargs)
-        acqfv_dummy = torch.tensor([[[1.0, 2.0, 3.0]]], **tkwargs)
-
-        search_space_digest = SearchSpaceDigest(
-            feature_names=feature_names,
-            bounds=bounds,
-            task_features=tfs,
-        )
-        with mock.patch(FIT_MODEL_MO_PATH) as _mock_fit_model:
-            model.fit(
-                datasets=training_data,
-                search_space_digest=search_space_digest,
-            )
-            _mock_fit_model.assert_called_once()
-        with ExitStack() as es:
-            _mock_acqf = es.enter_context(
-                mock.patch(
-                    acquisition_path,
-                    wraps=acqf_class,
-                )
-            )
-            mock_optimize = es.enter_context(
-                mock.patch(
-                    "ax.generators.torch.botorch_defaults.optimize_acqf",
-                    return_value=(X_dummy, acqfv_dummy),
-                )
-            )
-            _mock_partitioning = es.enter_context(
-                mock.patch(
-                    partitioning_path,
-                    wraps=hypervolume.FastNondominatedPartitioning,
-                )
-            )
-            torch_opt_config = TorchOptConfig(
-                objective_weights=objective_weights,
-                objective_thresholds=obj_t,
-                model_gen_options={
-                    "optimizer_kwargs": {"options": {"batch_limit": 1}},
-                },
-            )
-            gen_results = model.gen(
-                n,
-                search_space_digest=search_space_digest,
-                torch_opt_config=torch_opt_config,
-            )
-            # the NEHVI acquisition function should be created only once.
-            self.assertEqual(1, _mock_acqf.call_count)
-            # check partitioning strategy
-            # NEHVI should call FastNondominatedPartitioning 1 time
-            # since a batched partitioning is used for 2 objectives
-            _mock_partitioning.assert_called_once()
-            self.assertTrue(
-                torch.equal(
-                    gen_results.gen_metadata["objective_thresholds"][:2],
-                    obj_t[:2].cpu(),
-                )
-            )
-            self.assertTrue(
-                torch.isnan(gen_results.gen_metadata["objective_thresholds"][-1])
-            )
-            _mock_fit_model = es.enter_context(mock.patch(FIT_MODEL_MO_PATH))
-            # Optimizer options correctly passed through.
-            self.assertEqual(
-                mock_optimize.call_args.kwargs["options"]["init_batch_limit"],
-                INIT_BATCH_LIMIT,
-            )
-            self.assertEqual(
-                mock_optimize.call_args.kwargs["options"]["batch_limit"], 1
-            )
-            # 3 objective
-            training_data_m3 = training_data + [training_data[-1]]
-
-            model.fit(
-                datasets=training_data_m3,
-                search_space_digest=search_space_digest,
-            )
-            torch_opt_config = TorchOptConfig(
-                objective_weights=torch.tensor([1.0, 1.0, 1.0], **tkwargs),
-                objective_thresholds=torch.tensor([1.0, 1.0, 1.0], **tkwargs),
-            )
-            model.gen(
-                n,
-                search_space_digest=search_space_digest,
-                torch_opt_config=torch_opt_config,
-            )
-            # check partitioning strategy
-            # NEHVI should call FastNondominatedPartitioning 129 times because
-            # we have called gen twice: The first time, a batch partitioning is used
-            # so there is one call to _mock_partitioning. The second time gen() is
-            # called with three objectives so 128 calls are made to _mock_partitioning
-            # because a BoxDecompositionList is used. qLogEHVI will only make 2 calls.
-            self.assertEqual(
-                len(_mock_partitioning.mock_calls), 129 if use_noisy else 2
-            )
-
-            # test inferred objective thresholds in gen()
-            # create several data points
-            Xs = torch.cat([Xs, Xs - 0.1], dim=0)
-            Ys1 = torch.cat([Ys, Ys - 0.5], dim=0)
-            Ys2 = torch.cat([Ys, Ys + 0.5], dim=0)
-            Ys3 = torch.cat([Ys, Ys - 1.0], dim=0)
-            Yvars1 = torch.cat([Yvars, Yvars + 0.2], dim=0)
-            Yvars2 = torch.cat([Yvars, Yvars + 0.1], dim=0)
-            Yvars3 = torch.cat([Yvars, Yvars + 0.4], dim=0)
-            training_data_multiple = [
-                SupervisedDataset(
-                    X=Xs,
-                    Y=Y,
-                    Yvar=Yvar,
-                    feature_names=feature_names,
-                    outcome_names=[name],
-                )
-                for Y, Yvar, name in zip(
-                    [Ys1, Ys2, Ys3], [Yvars1, Yvars2, Yvars3], ["m1", "m2", "m3"]
-                )
-            ]
-            model.fit(
-                datasets=training_data_multiple,
-                search_space_digest=search_space_digest,
-            )
-            es.enter_context(
-                mock.patch(
-                    "ax.generators.torch.botorch_moo_defaults._check_posterior_type",
-                    wraps=lambda y: y,
-                )
-            )
-            _mock_model_infer_objective_thresholds = es.enter_context(
-                mock.patch(
-                    "ax.generators.torch.botorch_moo.infer_objective_thresholds",
-                    wraps=infer_objective_thresholds,
-                )
-            )
-            _mock_infer_reference_point = es.enter_context(
-                mock.patch(
-                    "ax.generators.torch.botorch_moo_defaults.infer_reference_point",
-                    wraps=infer_reference_point,
-                )
-            )
-            preds = torch.tensor(
-                [
-                    [11.0, 2.0],
-                    [9.0, 3.0],
-                    [12.0, 0.0],
-                    [13.0, 0.0],
-                ],
-                **tkwargs,
-            )
-            es.enter_context(
-                mock.patch.object(
-                    model.model,
-                    "posterior",
-                    return_value=MockPosterior(
-                        mean=preds,
-                        samples=preds,
-                    ),
-                )
-            )
-            es.enter_context(
-                mock.patch(
-                    "botorch.acquisition.factory.get_sampler",
-                    return_value=IIDNormalSampler(sample_shape=torch.Size([2])),
-                )
-            )
-            outcome_constraints = (
-                torch.tensor([[1.0, 0.0, 0.0]], **tkwargs),
-                torch.tensor([[10.0]], **tkwargs),
-            )
-            torch_opt_config = TorchOptConfig(
-                objective_weights=torch.tensor([-1.0, -1.0, 0.0], **tkwargs),
-                outcome_constraints=outcome_constraints,
-                model_gen_options={
-                    # do not used cached root decomposition since
-                    # MockPosterior does not have an mvn attribute
-                    "acquisition_function_kwargs": (
-                        {
-                            "cache_root": False,
-                            "prune_baseline": False,
-                        }
-                        if use_noisy
-                        else {}
-                    ),
-                },
-            )
-            gen_results = model.gen(
-                n,
-                search_space_digest=search_space_digest,
-                torch_opt_config=torch_opt_config,
-            )
-            # the NEHVI acquisition function should be created only once.
-            self.assertEqual(_mock_acqf.call_count, 3)
-            ckwargs = _mock_model_infer_objective_thresholds.call_args[1]
-            X_observed = ckwargs["X_observed"]
-            sorted_idcs = X_observed[:, 0].argsort()
-            sorted_idcs2 = Xs[:, 0].argsort()
-            self.assertTrue(torch.equal(X_observed[sorted_idcs], Xs[sorted_idcs2]))
-            self.assertTrue(
-                torch.equal(
-                    ckwargs["objective_weights"],
-                    torch.tensor([-1.0, -1.0, 0.0], **tkwargs),
-                )
-            )
-            oc = ckwargs["outcome_constraints"]
-            self.assertTrue(torch.equal(oc[0], outcome_constraints[0]))
-            self.assertTrue(torch.equal(oc[1], outcome_constraints[1]))
-            subset_model = ckwargs["model"]
-            self.assertIsInstance(subset_model, SingleTaskGP)
-            self.assertEqual(subset_model.num_outputs, 2)
-            self.assertTrue(
-                torch.equal(
-                    ckwargs["subset_idcs"],
-                    torch.tensor([0, 1], device=tkwargs["device"]),
-                )
-            )
-            _mock_infer_reference_point.assert_called_once()
-            ckwargs = _mock_infer_reference_point.call_args[1]
-            self.assertEqual(ckwargs["scale"], 0.1)
-            self.assertTrue(
-                torch.equal(
-                    ckwargs["pareto_Y"], torch.tensor([[-9.0, -3.0]], **tkwargs)
-                )
-            )
-            self.assertIn("objective_thresholds", gen_results.gen_metadata)
-            obj_t = gen_results.gen_metadata["objective_thresholds"]
-            self.assertTrue(
-                torch.equal(obj_t[:2], torch.tensor([9.9, 3.3], dtype=tkwargs["dtype"]))
-            )
-            self.assertTrue(np.isnan(obj_t[2]))
-            # test providing model with extra tracking metrics and objective thresholds
-            provided_obj_t = torch.tensor([10.0, 4.0, float("nan")], **tkwargs)
-            gen_results = model.gen(
-                n,
-                search_space_digest=search_space_digest,
-                torch_opt_config=dataclasses.replace(
-                    torch_opt_config,
-                    objective_thresholds=provided_obj_t,
-                ),
-            )
-            self.assertIn("objective_thresholds", gen_results.gen_metadata)
-            obj_t = gen_results.gen_metadata["objective_thresholds"]
-            self.assertTrue(torch.equal(obj_t[:2], provided_obj_t[:2].cpu()))
-            self.assertTrue(np.isnan(obj_t[2]))
-
-    @mock_botorch_optimize
-    def test_BotorchMOOModel_with_random_scalarization_and_outcome_constraints(
-        self,
-        dtype: torch.dtype = torch.float,
-        cuda: bool = False,
-    ) -> None:
-        tkwargs: dict[str, Any] = {
-            "device": torch.device("cuda") if cuda else torch.device("cpu"),
-            "dtype": dtype,
-        }
-        (
-            Xs,
-            Ys,
-            Yvars,
-            bounds,
-            tfs,
-            feature_names,
-            _,
-        ) = get_torch_test_data(dtype=dtype, cuda=cuda, constant_noise=True)
-        training_data = [
-            SupervisedDataset(
-                X=Xs,
-                Y=Ys,
-                Yvar=Yvars,
-                feature_names=feature_names,
-                outcome_names=[name],
-            )
-            for name in ["m1", "m2"]
-        ]
-
-        n = 2
-        objective_weights = torch.tensor([1.0, 1.0], **tkwargs)
-        obj_t = torch.tensor([1.0, 1.0], **tkwargs)
-        model = MultiObjectiveLegacyBoTorchGenerator(acqf_constructor=get_qLogNEI)
-
-        search_space_digest = SearchSpaceDigest(
-            feature_names=feature_names,
-            bounds=bounds,
-            task_features=tfs,
-        )
-        with mock.patch(FIT_MODEL_MO_PATH) as _mock_fit_model:
-            model.fit(
-                datasets=training_data,
-                search_space_digest=search_space_digest,
-            )
-            _mock_fit_model.assert_called_once()
-
-        with mock.patch(
-            SAMPLE_SIMPLEX_UTIL_PATH,
-            autospec=True,
-            return_value=torch.tensor([0.7, 0.3], **tkwargs),
-        ) as _mock_sample_simplex:
-            model.gen(
-                n,
-                search_space_digest=search_space_digest,
-                torch_opt_config=TorchOptConfig(
-                    objective_weights=objective_weights,
-                    outcome_constraints=(
-                        torch.tensor([[1.0, 1.0]], **tkwargs),
-                        torch.tensor([[10.0]], **tkwargs),
-                    ),
-                    model_gen_options={
-                        "acquisition_function_kwargs": {"random_scalarization": True},
-                    },
-                    objective_thresholds=obj_t,
-                ),
-            )
-        self.assertEqual(n, _mock_sample_simplex.call_count)
-
-    @mock_botorch_optimize
-    def test_BotorchMOOModel_with_chebyshev_scalarization_and_outcome_constraints(
-        self,
-        dtype: torch.dtype = torch.float,
-        cuda: bool = False,
-    ) -> None:
-        tkwargs: dict[str, Any] = {
-            "device": torch.device("cuda") if cuda else torch.device("cpu"),
-            "dtype": torch.float,
-        }
-        (
-            Xs,
-            Ys,
-            Yvars,
-            bounds,
-            tfs,
-            feature_names,
-            _,
-        ) = get_torch_test_data(dtype=dtype, cuda=cuda, constant_noise=True)
-        training_data = [
-            SupervisedDataset(
-                X=Xs,
-                Y=Ys,
-                Yvar=Yvars,
-                feature_names=feature_names,
-                outcome_names=[name],
-            )
-            for name in ["m1", "m2"]
-        ]
-
-        n = 2
-        objective_weights = torch.tensor([1.0, 1.0], **tkwargs)
-        obj_t = torch.tensor([1.0, 1.0], **tkwargs)
-        model = MultiObjectiveLegacyBoTorchGenerator(acqf_constructor=get_qLogNEI)
-
-        search_space_digest = SearchSpaceDigest(
-            feature_names=feature_names,
-            bounds=bounds,
-            task_features=tfs,
-        )
-        with mock.patch(FIT_MODEL_MO_PATH) as _mock_fit_model:
-            model.fit(
-                datasets=training_data,
-                search_space_digest=search_space_digest,
-            )
-            _mock_fit_model.assert_called_once()
-
-        torch_opt_config = TorchOptConfig(
-            objective_weights=objective_weights,
-            outcome_constraints=(
-                torch.tensor([[1.0, 1.0]], **tkwargs),
-                torch.tensor([[10.0]], **tkwargs),
-            ),
-            model_gen_options={
-                "acquisition_function_kwargs": {"chebyshev_scalarization": True},
-            },
-            objective_thresholds=obj_t,
-        )
-        with mock.patch(
-            CHEBYSHEV_SCALARIZATION_PATH, wraps=get_chebyshev_scalarization
-        ) as _mock_chebyshev_scalarization:
-            model.gen(
-                n,
-                search_space_digest=search_space_digest,
-                torch_opt_config=torch_opt_config,
-            )
-            # get_chebyshev_scalarization should be called once for generated candidate.
-            self.assertEqual(n, _mock_chebyshev_scalarization.call_count)
-
-    @mock_botorch_optimize
-    def test_BotorchMOOModel_with_qehvi_and_outcome_constraints(
-        self,
-        dtype: torch.dtype = torch.float,
-        cuda: bool = False,
-        use_noisy: bool = False,
-        use_log: bool = True,
-    ) -> None:
-        if use_log:
-            acqf_constructor = (
-                botorch_moo_defaults.get_qLogNEHVI
-                if use_noisy
-                else botorch_moo_defaults.get_qLogEHVI
-            )
-        else:
-            acqf_constructor = (
-                botorch_moo_defaults.get_NEHVI
-                if use_noisy
-                else botorch_moo_defaults.get_EHVI
-            )
-
-        tkwargs: dict[str, Any] = {
-            "device": torch.device("cuda") if cuda else torch.device("cpu"),
-            "dtype": dtype,
-        }
-        (
-            Xs,
-            Ys,
-            Yvars,
-            bounds,
-            tfs,
-            feature_names,
-            _,
-        ) = get_torch_test_data(dtype=dtype, cuda=cuda, constant_noise=True)
-        bounds[0] = (0.0, 1.0)  # make one data point out of bounds
-        training_data = [
-            SupervisedDataset(
-                X=Xs,
-                Y=Ys,
-                Yvar=Yvars,
-                feature_names=feature_names,
-                outcome_names=[name],
-            )
-            for name in ["m1", "m2", "m3"]
-        ]
-
-        n = 3
-        objective_weights = torch.tensor([1.0, 1.0, 0.0], **tkwargs)
-        obj_t = torch.tensor([1.0, 1.0, 1.0], **tkwargs)
-        # pyre-fixme[6]: For 1st param expected `(Model, Tensor, Optional[Tuple[Tenso...
-        model = MultiObjectiveLegacyBoTorchGenerator(acqf_constructor=acqf_constructor)
-
-        search_space_digest = SearchSpaceDigest(
-            feature_names=feature_names,
-            bounds=bounds,
-            task_features=tfs,
-        )
-        with mock.patch(FIT_MODEL_MO_PATH) as _mock_fit_model:
-            model.fit(
-                datasets=training_data,
-                search_space_digest=search_space_digest,
-            )
-            _mock_fit_model.assert_called_once()
-
-        # test wrong number of objective thresholds
-        torch_opt_config = TorchOptConfig(
-            objective_weights=objective_weights,
-            objective_thresholds=torch.tensor([1.0, 1.0], **tkwargs),
-        )
-        with self.assertRaises(AxError):
-            model.gen(
-                n,
-                search_space_digest=search_space_digest,
-                torch_opt_config=torch_opt_config,
-            )
-        # test that objective thresholds and weights are properly subsetted
-        obj_t = torch.tensor([1.0, 1.0, 1.0], **tkwargs)
-        torch_opt_config = dataclasses.replace(
-            torch_opt_config,
-            objective_thresholds=obj_t,
-        )
-        with mock.patch.object(
-            model,
-            "acqf_constructor",
-            wraps=acqf_constructor,  # botorch_moo_defaults.get_qLogNEHVI,
-        ) as mock_get_nehvi:
-            model.gen(
-                n,
-                search_space_digest=search_space_digest,
-                torch_opt_config=torch_opt_config,
-            )
-            mock_get_nehvi.assert_called_once()
-            _, ckwargs = mock_get_nehvi.call_args
-            self.assertEqual(ckwargs["model"].num_outputs, 2)
-            self.assertTrue(
-                torch.equal(ckwargs["objective_weights"], objective_weights[:-1])
-            )
-            self.assertTrue(torch.equal(ckwargs["objective_thresholds"], obj_t[:-1]))
-            self.assertIsNone(ckwargs["outcome_constraints"])
-            # the second datapoint is out of bounds
-            self.assertTrue(torch.equal(ckwargs["X_observed"], Xs[:1]))
-            self.assertIsNone(ckwargs["X_pending"])
-
-        # test that outcome constraints are passed properly
-        oc = (
-            torch.tensor([[0.0, 0.0, 1.0]], **tkwargs),
-            torch.tensor([[10.0]], **tkwargs),
-        )
-        torch_opt_config = dataclasses.replace(
-            torch_opt_config,
-            outcome_constraints=oc,
-        )
-        with mock.patch.object(
-            model,
-            "acqf_constructor",
-            wraps=acqf_constructor,  # botorch_moo_defaults.get_qLogNEHVI,
-        ) as mock_get_nehvi:
-            model.gen(
-                n,
-                search_space_digest=search_space_digest,
-                torch_opt_config=torch_opt_config,
-            )
-            mock_get_nehvi.assert_called_once()
-            _, ckwargs = mock_get_nehvi.call_args
-            self.assertEqual(ckwargs["model"].num_outputs, 3)
-            self.assertTrue(torch.equal(ckwargs["objective_weights"], objective_weights))
-            self.assertTrue(torch.equal(ckwargs["objective_thresholds"], obj_t))
-            self.assertTrue(torch.equal(ckwargs["outcome_constraints"][0], oc[0]))
-            self.assertTrue(torch.equal(ckwargs["outcome_constraints"][1], oc[1]))
-            # the second datapoint is out of bounds
-            self.assertTrue(torch.equal(ckwargs["X_observed"], Xs[:1]))
-            self.assertIsNone(ckwargs["X_pending"])
diff --git a/ax/generators/torch/botorch_moo.py b/ax/generators/torch/botorch_moo.py
deleted file mode 100644
index 74d5874481e..00000000000
--- a/ax/generators/torch/botorch_moo.py
+++ /dev/null
@@ -1,398 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-from collections.abc import Callable
-from typing import Any, Optional
-
-import torch
-from ax.core.search_space import SearchSpaceDigest
-from ax.exceptions.core import AxError
-from ax.generators.torch.botorch import (
-    get_rounding_func,
-    LegacyBoTorchGenerator,
-    TBestPointRecommender,
-    TModelConstructor,
-    TModelPredictor,
-    TOptimizer,
-)
-from ax.generators.torch.botorch_defaults import (
-    get_and_fit_model,
-    recommend_best_observed_point,
-    scipy_optimizer,
-    TAcqfConstructor,
-)
-from ax.generators.torch.botorch_moo_defaults import (
-    get_qLogNEHVI,
-    infer_objective_thresholds,
-    pareto_frontier_evaluator,
-    scipy_optimizer_list,
-    TFrontierEvaluator,
-)
-from ax.generators.torch.utils import (
-    _get_X_pending_and_observed,
-    _to_inequality_constraints,
-    predict_from_model,
-    randomize_objective_weights,
-    subset_model,
-)
-from ax.generators.torch_base import TorchGenerator, TorchGenResults, TorchOptConfig
-from ax.utils.common.constants import Keys
-from ax.utils.common.docutils import copy_doc
-from botorch.acquisition.acquisition import AcquisitionFunction
-from botorch.models.model import Model
-from pyre_extensions import assert_is_instance, none_throws
-from torch import Tensor
-
-
-TOptimizerList = Callable[
-    [
-        list[AcquisitionFunction],
-        Tensor,
-        Optional[list[tuple[Tensor, Tensor, float]]],
-        Optional[dict[int, float]],
-        Optional[Callable[[Tensor], Tensor]],
-        Any,
-    ],
-    tuple[Tensor, Tensor],
-]
-
-
-class MultiObjectiveLegacyBoTorchGenerator(LegacyBoTorchGenerator):
-    r"""
-    Customizable multi-objective model.
-
-    By default, this uses an Expected Hypervolume Improvment function to find the
-    pareto frontier of a function with multiple outcomes. This behavior
-    can be modified by providing custom implementations of the following
-    components:
-
-    - a `model_constructor` that instantiates and fits a model on data
-    - a `model_predictor` that predicts outcomes using the fitted model
-    - a `acqf_constructor` that creates an acquisition function from a fitted model
-    - a `acqf_optimizer` that optimizes the acquisition function
-
-    Args:
-        model_constructor: A callable that instantiates and fits a model on data,
-            with signature as described below.
-        model_predictor: A callable that predicts using the fitted model, with
-            signature as described below.
-        acqf_constructor: A callable that creates an acquisition function from a
-            fitted model, with signature as described below.
-        acqf_optimizer: A callable that optimizes an acquisition
-            function, with signature as described below.
-
-
-
-    Call signatures:
-
-    ::
-
-        model_constructor(
-            Xs,
-            Ys,
-            Yvars,
-            task_features,
-            fidelity_features,
-            metric_signatures,
-            state_dict,
-            **kwargs,
-        ) -> model
-
-    Here `Xs`, `Ys`, `Yvars` are lists of tensors (one element per outcome),
-    `task_features` identifies columns of Xs that should be modeled as a task,
-    `fidelity_features` is a list of ints that specify the positions of fidelity
-    parameters in 'Xs', `metric_signatures` provides the names of each `Y` in `Ys`,
-    `state_dict` is a pytorch module state dict, and `model` is a BoTorch `Model`.
-    Optional kwargs are being passed through from the `LegacyBoTorchGenerator`
-    constructor. This callable is assumed to return a fitted BoTorch model that has
-    the same dtype and lives on the same device as the input tensors.
-
-    ::
-
-        model_predictor(model, X) -> [mean, cov]
-
-    Here `model` is a fitted botorch model, `X` is a tensor of candidate points,
-    and `mean` and `cov` are the posterior mean and covariance, respectively.
-
-    ::
-
-        acqf_constructor(
-            model,
-            objective_weights,
-            outcome_constraints,
-            X_observed,
-            X_pending,
-            **kwargs,
-        ) -> acq_function
-
-
-    Here `model` is a botorch `Model`, `objective_weights` is a tensor of weights
-    for the model outputs, `outcome_constraints` is a tuple of tensors describing
-    the (linear) outcome constraints, `X_observed` are previously observed points,
-    and `X_pending` are points whose evaluation is pending. `acq_function` is a
-    BoTorch acquisition function crafted from these inputs. For additional
-    details on the arguments, see `get_qLogNEHVI`.
-
-    ::
-
-        acqf_optimizer(
-            acq_function,
-            bounds,
-            n,
-            inequality_constraints,
-            fixed_features,
-            rounding_func,
-            **kwargs,
-        ) -> candidates
-
-    Here `acq_function` is a BoTorch `AcquisitionFunction`, `bounds` is a tensor
-    containing bounds on the parameters, `n` is the number of candidates to be
-    generated, `inequality_constraints` are inequality constraints on parameter
-    values, `fixed_features` specifies features that should be fixed during
-    generation, and `rounding_func` is a callback that rounds an optimization
-    result appropriately. `candidates` is a tensor of generated candidates.
-    For additional details on the arguments, see `scipy_optimizer`.
-
-    ::
-
-        frontier_evaluator(
-            model,
-            objective_weights,
-            objective_thresholds,
-            X,
-            Y,
-            Yvar,
-            outcome_constraints,
-        )
-
-    Here `model` is a botorch `Model`, `objective_thresholds` is used in hypervolume
-    evaluations, `objective_weights` is a tensor of weights applied to the objectives
-    (sign represents direction), `X`, `Y`, `Yvar` are tensors, `outcome_constraints` is
-    a tuple of tensors describing the (linear) outcome constraints.
-    """
-
-    dtype: torch.dtype | None
-    device: torch.device | None
-    Xs: list[Tensor]
-    Ys: list[Tensor]
-    Yvars: list[Tensor]
-
-    def __init__(
-        self,
-        model_constructor: TModelConstructor = get_and_fit_model,
-        model_predictor: TModelPredictor = predict_from_model,
-        # pyre-fixme[9]: acqf_constructor has type `Callable[[Model, Tensor,
-        #  Optional[Tuple[Tensor, Tensor]], Optional[Tensor], Optional[Tensor], Any],
-        #  AcquisitionFunction]`; used as `Callable[[Model, Tensor,
-        #  Optional[Tuple[Tensor, Tensor]], Optional[Tensor], Optional[Tensor],
-        #  **(Any)], AcquisitionFunction]`.
-        acqf_constructor: TAcqfConstructor = get_qLogNEHVI,
-        # pyre-fixme[9]: acqf_optimizer has type `Callable[[AcquisitionFunction,
-        #  Tensor, int, Optional[Dict[int, float]], Optional[Callable[[Tensor],
-        #  Tensor]], Any], Tensor]`; used as `Callable[[AcquisitionFunction, Tensor,
-        #  int, Optional[Dict[int, float]], Optional[Callable[[Tensor], Tensor]],
-        #  **(Any)], Tensor]`.
-        acqf_optimizer: TOptimizer = scipy_optimizer,
-        # TODO: Remove best_point_recommender for botorch_moo. Used in adapter._gen.
-        best_point_recommender: TBestPointRecommender = recommend_best_observed_point,
-        frontier_evaluator: TFrontierEvaluator = pareto_frontier_evaluator,
-        refit_on_cv: bool = False,
-        warm_start_refitting: bool = False,
-        use_input_warping: bool = False,
-        use_loocv_pseudo_likelihood: bool = False,
-        prior: dict[str, Any] | None = None,
-        **kwargs: Any,
-    ) -> None:
-        self.model_constructor = model_constructor
-        self.model_predictor = model_predictor
-        self.acqf_constructor = acqf_constructor
-        self.acqf_optimizer = acqf_optimizer
-        self.best_point_recommender = best_point_recommender
-        self.frontier_evaluator = frontier_evaluator
-        # pyre-fixme[4]: Attribute must be annotated.
-        self._kwargs = kwargs
-        self.refit_on_cv = refit_on_cv
-        self.warm_start_refitting = warm_start_refitting
-        self.use_input_warping = use_input_warping
-        self.use_loocv_pseudo_likelihood = use_loocv_pseudo_likelihood
-        self.prior = prior
-        self.model: Model | None = None
-        self.Xs = []
-        self.Ys = []
-        self.Yvars = []
-        self.dtype = None
-        self.device = None
-        self.task_features: list[int] = []
-        self.fidelity_features: list[int] = []
-        self.metric_signatures: list[str] = []
-
-    @copy_doc(TorchGenerator.gen)
-    def gen(
-        self,
-        n: int,
-        search_space_digest: SearchSpaceDigest,
-        torch_opt_config: TorchOptConfig,
-    ) -> TorchGenResults:
-        options = torch_opt_config.model_gen_options or {}
-        acf_options = options.get("acquisition_function_kwargs", {})
-        optimizer_options = options.get("optimizer_kwargs", {})
-
-        if search_space_digest.fidelity_features:  # untested
-            raise NotImplementedError(
-                "fidelity_features not implemented for base LegacyBoTorchGenerator"
-            )
-        if (
-            torch_opt_config.objective_thresholds is not None
-            and torch_opt_config.objective_weights.shape[0]
-            != none_throws(torch_opt_config.objective_thresholds).shape[0]
-        ):
-            raise AxError(
-                "Objective weights and thresholds most both contain an element for"
-                " each modeled metric."
-            )
-
-        X_pending, X_observed = _get_X_pending_and_observed(
-            Xs=self.Xs,
-            objective_weights=torch_opt_config.objective_weights,
-            bounds=search_space_digest.bounds,
-            pending_observations=torch_opt_config.pending_observations,
-            outcome_constraints=torch_opt_config.outcome_constraints,
-            linear_constraints=torch_opt_config.linear_constraints,
-            fixed_features=torch_opt_config.fixed_features,
-        )
-
-        model = none_throws(self.model)
-        full_objective_thresholds = torch_opt_config.objective_thresholds
-        full_objective_weights = torch_opt_config.objective_weights
-        full_outcome_constraints = torch_opt_config.outcome_constraints
-        # subset model only to the outcomes we need for the optimization
-        if options.get(Keys.SUBSET_MODEL, True):
-            subset_model_results = subset_model(
-                model=model,
-                objective_weights=torch_opt_config.objective_weights,
-                outcome_constraints=torch_opt_config.outcome_constraints,
-                objective_thresholds=torch_opt_config.objective_thresholds,
-            )
-            model = subset_model_results.model
-            objective_weights = subset_model_results.objective_weights
-            outcome_constraints = subset_model_results.outcome_constraints
-            objective_thresholds = subset_model_results.objective_thresholds
-            idcs = subset_model_results.indices
-        else:
-            objective_weights = torch_opt_config.objective_weights
-            outcome_constraints = torch_opt_config.outcome_constraints
-            objective_thresholds = torch_opt_config.objective_thresholds
-            idcs = None
-
-        bounds_ = torch.tensor(
-            search_space_digest.bounds, dtype=self.dtype, device=self.device
-        )
-        bounds_ = bounds_.transpose(0, 1)
-        botorch_rounding_func = get_rounding_func(torch_opt_config.rounding_func)
-        if acf_options.pop("random_scalarization", False) or acf_options.get(
-            "chebyshev_scalarization", False
-        ):
-            # If using a list of acquisition functions, the algorithm to generate
-            # that list is configured by acquisition_function_kwargs.
-            if "random_scalarization_distribution" in acf_options:
-                randomize_weights_kws = {
-                    "random_scalarization_distribution": acf_options[
-                        "random_scalarization_distribution"
-                    ]
-                }
-                del acf_options["random_scalarization_distribution"]
-            else:
-                randomize_weights_kws = {}
-            objective_weights_list = [
-                randomize_objective_weights(objective_weights, **randomize_weights_kws)
-                for _ in range(n)
-            ]
-            acquisition_function_list = [
-                self.acqf_constructor(
-                    model=model,
-                    objective_weights=objective_weights,
-                    outcome_constraints=outcome_constraints,
-                    X_observed=X_observed,
-                    X_pending=X_pending,
-                    **acf_options,
-                )
-                for objective_weights in objective_weights_list
-            ]
-            acquisition_function_list = [
-                assert_is_instance(acq_function, AcquisitionFunction)
-                for acq_function in acquisition_function_list
-            ]
-            # Multiple acquisition functions require a sequential optimizer
-            # always use scipy_optimizer_list.
-            # TODO(jej): Allow any optimizer.
-            candidates, expected_acquisition_value = scipy_optimizer_list(
-                acq_function_list=acquisition_function_list,
-                bounds=bounds_,
-                inequality_constraints=_to_inequality_constraints(
-                    linear_constraints=torch_opt_config.linear_constraints
-                ),
-                fixed_features=torch_opt_config.fixed_features,
-                rounding_func=botorch_rounding_func,
-                **optimizer_options,
-            )
-        else:
-            if (
-                objective_thresholds is None
-                or objective_thresholds[objective_weights != 0].isnan().any()
-            ):
-                full_objective_thresholds = infer_objective_thresholds(
-                    model=model,
-                    X_observed=none_throws(X_observed),
-                    objective_weights=full_objective_weights,
-                    outcome_constraints=full_outcome_constraints,
-                    subset_idcs=idcs,
-                    objective_thresholds=objective_thresholds,
-                )
-                # subset the objective thresholds
-                objective_thresholds = (
-                    full_objective_thresholds
-                    if idcs is None
-                    else full_objective_thresholds[idcs].clone()
-                )
-            acquisition_function = self.acqf_constructor(
-                model=model,
-                objective_weights=objective_weights,
-                objective_thresholds=objective_thresholds,
-                outcome_constraints=outcome_constraints,
-                X_observed=X_observed,
-                X_pending=X_pending,
-                **acf_options,
-            )
-            acquisition_function = assert_is_instance(
-                acquisition_function, AcquisitionFunction
-            )
-            # pyre-ignore: [28]
-            candidates, expected_acquisition_value = self.acqf_optimizer(
-                acq_function=assert_is_instance(
-                    acquisition_function, AcquisitionFunction
-                ),
-                bounds=bounds_,
-                n=n,
-                inequality_constraints=_to_inequality_constraints(
-                    linear_constraints=torch_opt_config.linear_constraints
-                ),
-                fixed_features=torch_opt_config.fixed_features,
-                rounding_func=botorch_rounding_func,
-                **optimizer_options,
-            )
-        gen_metadata = {
-            "expected_acquisition_value": expected_acquisition_value.tolist(),
-            "objective_weights": full_objective_weights.cpu(),
-        }
-        if full_objective_thresholds is not None:
-            gen_metadata["objective_thresholds"] = full_objective_thresholds.cpu()
-        return TorchGenResults(
-            points=candidates.detach().cpu(),
-            weights=torch.ones(n, dtype=self.dtype),
-            gen_metadata=gen_metadata,
-        )