diff --git a/ax/benchmark/methods/modular_botorch.py b/ax/benchmark/methods/modular_botorch.py index 338b30cc37f..cbca5e0e507 100644 --- a/ax/benchmark/methods/modular_botorch.py +++ b/ax/benchmark/methods/modular_botorch.py @@ -41,6 +41,7 @@ def get_sobol_mbm_generation_strategy( acquisition_cls: type[AcquisitionFunction] | None = None, name: str | None = None, num_sobol_trials: int = 5, + model_kwargs_override: dict[str, Any] | None = None, model_gen_kwargs: dict[str, Any] | None = None, batch_size: int = 1, ) -> GenerationStrategy: @@ -53,6 +54,8 @@ def get_sobol_mbm_generation_strategy( name: Name that will be attached to the `GenerationStrategy`. num_sobol_trials: Number of Sobol trials; can refer to the number of `BatchTrial`s. + model_kwargs_override: Passed to the MBM BoTorch `GenerationStep` inside + `model_kwargs`. model_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately to the BoTorch `Model`. @@ -79,6 +82,9 @@ def get_sobol_mbm_generation_strategy( else: acqf_name = "" + if model_kwargs_override is not None: + model_kwargs.update(model_kwargs_override) + model_name = model_names_abbrevations.get(model_cls.__name__, model_cls.__name__) # Historically all benchmarks were sequential, so sequential benchmarks # don't get anything added to their name, for continuity @@ -113,6 +119,7 @@ def get_sobol_botorch_modular_acquisition( acquisition_cls: type[AcquisitionFunction] | None = None, name: str | None = None, num_sobol_trials: int = 5, + model_kwargs_override: dict[str, Any] | None = None, model_gen_kwargs: dict[str, Any] | None = None, batch_size: int = 1, ) -> BenchmarkMethod: @@ -126,6 +133,8 @@ def get_sobol_botorch_modular_acquisition( num_sobol_trials: Number of Sobol trials; if the orchestrator_options specify to use `BatchTrial`s, then this refers to the number of `BatchTrial`s. + model_kwargs_override: Passed to the MBM BoTorch `GenerationStep` inside + `model_kwargs`. model_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately to the BoTorch `Model`. batch_size: Passed to the created ``BenchmarkMethod``. @@ -158,6 +167,7 @@ def get_sobol_botorch_modular_acquisition( acquisition_cls=acquisition_cls, name=name, num_sobol_trials=num_sobol_trials, + model_kwargs_override=model_kwargs_override, model_gen_kwargs=model_gen_kwargs, batch_size=batch_size, )