Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ax/benchmark/methods/modular_botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def get_sobol_mbm_generation_strategy(
acquisition_cls: type[AcquisitionFunction] | None = None,
name: str | None = None,
num_sobol_trials: int = 5,
model_kwargs_override: dict[str, Any] | None = None,
model_gen_kwargs: dict[str, Any] | None = None,
batch_size: int = 1,
) -> GenerationStrategy:
Expand All @@ -53,6 +54,8 @@ def get_sobol_mbm_generation_strategy(
name: Name that will be attached to the `GenerationStrategy`.
num_sobol_trials: Number of Sobol trials; can refer to the number of
`BatchTrial`s.
model_kwargs_override: Passed to the MBM BoTorch `GenerationStep` inside
`model_kwargs`.
model_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately
to the BoTorch `Model`.

Expand All @@ -79,6 +82,9 @@ def get_sobol_mbm_generation_strategy(
else:
acqf_name = ""

if model_kwargs_override is not None:
model_kwargs.update(model_kwargs_override)

model_name = model_names_abbrevations.get(model_cls.__name__, model_cls.__name__)
# Historically all benchmarks were sequential, so sequential benchmarks
# don't get anything added to their name, for continuity
Expand Down Expand Up @@ -113,6 +119,7 @@ def get_sobol_botorch_modular_acquisition(
acquisition_cls: type[AcquisitionFunction] | None = None,
name: str | None = None,
num_sobol_trials: int = 5,
model_kwargs_override: dict[str, Any] | None = None,
model_gen_kwargs: dict[str, Any] | None = None,
batch_size: int = 1,
) -> BenchmarkMethod:
Expand All @@ -126,6 +133,8 @@ def get_sobol_botorch_modular_acquisition(
num_sobol_trials: Number of Sobol trials; if the orchestrator_options
specify to use `BatchTrial`s, then this refers to the number of
`BatchTrial`s.
model_kwargs_override: Passed to the MBM BoTorch `GenerationStep` inside
`model_kwargs`.
model_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately
to the BoTorch `Model`.
batch_size: Passed to the created ``BenchmarkMethod``.
Expand Down Expand Up @@ -158,6 +167,7 @@ def get_sobol_botorch_modular_acquisition(
acquisition_cls=acquisition_cls,
name=name,
num_sobol_trials=num_sobol_trials,
model_kwargs_override=model_kwargs_override,
model_gen_kwargs=model_gen_kwargs,
batch_size=batch_size,
)
Expand Down