
Commit e54ec2e

Carl Hvarfner authored and facebook-github-bot committed
Inference trace and Best Point Recommendation (BPR) bugfix (#4128)
Summary:
This diff addresses two issues in the computation of the inference trace:

1. The generation strategy is copied inside `run_optimization_with_orchestrator` --> the traces are retrieved from an unused generation strategy --> `get_best_point` defaults to the best raw observation over ALL observations.
2. Relevant data is not filtered in the fallback option for `get_best_parameters_from_model_predictions_with_trial_index`.

Each of these on its own makes the inference trace incorrect: the first reports the best raw value across ALL trials, the second the best predicted value across ALL trials.

Changes:
- Moved the copying of the generation strategy up to the level of `benchmark_replication`, since results need to be computed on the used `generation_strategy` and not an empty copy. This means that `run_optimization_with_orchestrator` no longer `clone_reset`'s the GS.
- Clearer sequencing in `get_best_parameters_from_model_predictions_with_trial_index`.
- Removed the model fit quality check as part of BPR.

Previous, now-redacted changes:
- Added argument `use_model_only_if_good` to force model-based BPR even if the model fit is bad.

Differential Revision: D80019803
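The mechanics of the first bug are easiest to see in miniature. Below is a hypothetical, simplified sketch (toy classes and function names, not the actual Ax API) of why cloning the generation strategy inside the runner broke the inference trace, and how resetting once at the replication level fixes it:

# Hypothetical, minimal sketch of the bug and fix; names are
# illustrative stand-ins, not the actual Ax implementation.

class GenerationStrategy:
    def __init__(self) -> None:
        self.generated_trials: list[int] = []

    def clone_reset(self) -> "GenerationStrategy":
        # Returns a fresh, unused copy with no accumulated state.
        return GenerationStrategy()


def run_optimization(gs: GenerationStrategy, num_trials: int) -> None:
    # The buggy version called gs.clone_reset() here, so the caller's
    # `gs` never accumulated state, and trace computation on it fell
    # back to the best raw observation over ALL trials.
    for i in range(num_trials):
        gs.generated_trials.append(i)


def benchmark_replication(gs: GenerationStrategy) -> GenerationStrategy:
    # Fixed flow: reset once up front, then thread the *same* object
    # through optimization and trace computation.
    gs = gs.clone_reset()
    run_optimization(gs, num_trials=3)
    return gs


used_gs = benchmark_replication(GenerationStrategy())
assert used_gs.generated_trials == [0, 1, 2]  # state visible downstream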
1 parent 2da0640 commit e54ec2e

File tree

6 files changed (+357, -283 lines)


ax/benchmark/benchmark.py

Lines changed: 16 additions & 14 deletions
@@ -326,9 +326,8 @@ def get_best_parameters(
     trial_indices: Iterable[int] | None = None,
 ) -> TParameterization | None:
     """
-    Get the most promising point.
-
-    Only SOO is supported. It will return None if no best point can be found.
+    Get the most promising point. Returns None if no point is predicted to
+    satisfy all outcome constraints.
 
     Args:
         experiment: The experiment to get the data from. This should contain
@@ -421,16 +420,17 @@ def get_benchmark_result_from_experiment_and_gs(
             optimization_config=problem.optimization_config,
         )
     )
-    inference_trace = get_inference_trace(
-        trial_completion_order=trial_completion_order,
-        experiment=experiment,
-        problem=problem,
-        generation_strategy=generation_strategy,
-    )
-
-    optimization_trace = (
-        inference_trace if problem.report_inference_value_as_trace else oracle_trace
-    )
+    if problem.report_inference_value_as_trace:
+        inference_trace = get_inference_trace(
+            trial_completion_order=trial_completion_order,
+            experiment=experiment,
+            problem=problem,
+            generation_strategy=generation_strategy,
+        )
+        optimization_trace = inference_trace
+    else:
+        optimization_trace = oracle_trace
+        inference_trace = None
 
     score_trace = compute_score_trace(
         optimization_trace=optimization_trace,
@@ -507,7 +507,7 @@ def run_optimization_with_orchestrator(
 
     orchestrator = Orchestrator(
         experiment=experiment,
-        generation_strategy=method.generation_strategy.clone_reset(),
+        generation_strategy=method.generation_strategy,
         options=orchestrator_options,
     )
 
@@ -562,6 +562,8 @@ def benchmark_replication(
     Return:
         ``BenchmarkResult`` object.
     """
+    # Reset the generation strategy to ensure that it is in an unused state.
+    method.generation_strategy = method.generation_strategy.clone_reset()
     experiment = run_optimization_with_orchestrator(
         problem=problem,
         method=method,

ax/benchmark/benchmark_result.py

Lines changed: 3 additions & 2 deletions
@@ -52,7 +52,8 @@ class BenchmarkResult(Base):
         points based only on data that would be observable in realistic
         settings, as specified by `BenchmarkMethod.get_best_parameters`, and
         then evaluating the oracle objective value of that point according
-        to the problem's `OptimizationConfig`.
+        to the problem's `OptimizationConfig`. Only reported if
+        report_inference_value_as_trace is enabled in the BenchmarkProblem.
 
         By default, if it is not overridden,
         `BenchmarkMethod.get_best_parameters` uses the empirical best point
@@ -96,7 +97,7 @@ class BenchmarkResult(Base):
     seed: int
 
     oracle_trace: npt.NDArray
-    inference_trace: npt.NDArray
+    inference_trace: npt.NDArray | None
     optimization_trace: npt.NDArray
     score_trace: npt.NDArray
     cost_trace: npt.NDArray
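Since `inference_trace` is now optional, downstream code reading a `BenchmarkResult` needs a `None` guard. A minimal consumer sketch; the helper name is hypothetical and not part of this diff:

import numpy as np
import numpy.typing as npt


def summarize_inference_trace(inference_trace: npt.NDArray | None) -> str:
    # inference_trace is None unless the problem was created with
    # report_inference_value_as_trace=True.
    if inference_trace is None:
        return "inference trace not reported"
    return f"final inference value: {inference_trace[-1]:.3f}"


print(summarize_inference_trace(None))
print(summarize_inference_trace(np.array([1.0, 2.0, 3.0])))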

ax/benchmark/testing/benchmark_stubs.py

Lines changed: 2 additions & 0 deletions
@@ -352,6 +352,7 @@ def get_async_benchmark_problem(
     step_runtime_fn: TBenchmarkStepRuntimeFunction | None = None,
     n_steps: int = 1,
     lower_is_better: bool = False,
+    report_inference_value_as_trace: bool = False,
 ) -> BenchmarkProblem:
     search_space = get_discrete_search_space()
     test_function = IdentityTestFunction(n_steps=n_steps)
@@ -371,6 +372,7 @@ def get_async_benchmark_problem(
         baseline_value=19 if lower_is_better else 0,
         optimal_value=0 if lower_is_better else 19,
         step_runtime_function=step_runtime_fn,
+        report_inference_value_as_trace=report_inference_value_as_trace,
     )
 
ax/benchmark/tests/test_benchmark.py

Lines changed: 58 additions & 92 deletions
@@ -17,7 +17,6 @@
 
 import numpy as np
 import torch
-from ax.adapter.factory import get_sobol
 from ax.adapter.registry import Generators
 from ax.benchmark.benchmark import (
     _get_oracle_value_of_params,
@@ -38,7 +37,6 @@
 from ax.benchmark.benchmark_problem import (
     BenchmarkProblem,
     create_problem_from_botorch,
-    get_continuous_search_space,
     get_moo_opt_config,
     get_soo_opt_config,
 )
@@ -82,7 +80,10 @@
 from ax.utils.common.mock import mock_patch_method_original
 from ax.utils.common.testutils import TestCase
 
-from ax.utils.testing.core_stubs import get_experiment_with_observations
+from ax.utils.testing.core_stubs import (
+    get_branin_experiment,
+    get_branin_experiment_with_multi_objective,
+)
 from ax.utils.testing.mock import mock_botorch_optimize
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.acquisition.logei import qLogNoisyExpectedImprovement
@@ -261,7 +262,9 @@ def test_replication_sobol_surrogate(self) -> None:
         self.assertTrue(np.isfinite(res.score_trace).all())
         self.assertTrue(np.all(res.score_trace <= 100))
 
-    def _test_replication_async(self, map_data: bool) -> None:
+    def _test_replication_async(
+        self, map_data: bool, report_inference_value_as_trace: bool
+    ) -> None:
         """
         The test function is the identity function, higher is better, observed
         to be noiseless, and the same at every point on the trajectory. And the
@@ -349,6 +352,7 @@ def _test_replication_async(self, map_data: bool) -> None:
         problem = get_async_benchmark_problem(
             map_data=map_data,
             step_runtime_fn=step_runtime_fn,
+            report_inference_value_as_trace=report_inference_value_as_trace,
         )
 
         with mock_patch_method_original(
@@ -417,12 +421,15 @@ def _test_replication_async(self, map_data: bool) -> None:
                 },
                 f"Failure for trial {trial_index} with {case_name}",
             )
-            self.assertFalse(np.isnan(result.inference_trace).any())
-            self.assertEqual(
-                result.inference_trace.tolist(),
-                expected_traces[case_name],
-                msg=case_name,
-            )
+            if report_inference_value_as_trace:
+                self.assertFalse(np.isnan(result.inference_trace).any())
+                self.assertEqual(
+                    result.inference_trace.tolist(),
+                    expected_traces[case_name],
+                    msg=case_name,
+                )
+            else:
+                self.assertIsNone(result.inference_trace)
             self.assertEqual(
                 result.oracle_trace.tolist(),
                 expected_traces[case_name],
@@ -466,8 +473,15 @@ def _test_replication_async(self, map_data: bool) -> None:
         self.assertEqual(completed_times, expected_completed_times)
 
     def test_replication_async(self) -> None:
-        self._test_replication_async(map_data=False)
-        self._test_replication_async(map_data=True)
+        self._test_replication_async(
+            map_data=False, report_inference_value_as_trace=False
+        )
+        self._test_replication_async(
+            map_data=True, report_inference_value_as_trace=False
+        )
+        self._test_replication_async(
+            map_data=False, report_inference_value_as_trace=True
+        )
 
     def test_run_optimization_with_orchestrator(self) -> None:
         method = get_async_benchmark_method()
@@ -491,6 +505,7 @@ def test_run_optimization_with_orchestrator(self) -> None:
             none_throws(runner.simulated_backend_runner).simulator._verbose_logging
         )
 
+        method.generation_strategy = method.generation_strategy.clone_reset()
         with self.subTest("Logs not produced by default"), self.assertNoLogs(
             level=logging.INFO, logger=logger
         ), self.assertNoLogs(logger=logger):
@@ -618,9 +633,9 @@ def test_early_stopping(self) -> None:
         self.assertEqual(max_run, {0: 4, 1: 2, 2: 2, 3: 2})
 
     def test_replication_variable_runtime(self) -> None:
-        method = get_async_benchmark_method(max_pending_trials=1)
         for map_data in [False, True]:
             with self.subTest(map_data=map_data):
+                method = get_async_benchmark_method(max_pending_trials=1)
                 problem = get_async_benchmark_problem(
                     map_data=map_data,
                     step_runtime_fn=lambda params: params["x0"] + 1,
@@ -652,9 +667,7 @@ def test_replication_variable_runtime(self) -> None:
         self.assertEqual(start_times, expected_start_times)
 
     @mock_botorch_optimize
-    def _test_replication_with_inference_value(
-        self, batch_size: int, report_inference_value_as_trace: bool
-    ) -> None:
+    def _test_replication_with_inference_value(self, batch_size: int) -> None:
         seed = 1
         method = get_sobol_botorch_modular_acquisition(
             model_cls=SingleTaskGP,
@@ -667,35 +680,29 @@ def _test_replication_with_inference_value(
         num_trials = 4
         problem = get_single_objective_benchmark_problem(
             num_trials=num_trials,
-            report_inference_value_as_trace=report_inference_value_as_trace,
+            report_inference_value_as_trace=True,
             noise_std=100.0,
         )
         res = self.benchmark_replication(problem=problem, method=method, seed=seed)
         # The inference trace could coincide with the oracle trace, but it won't
         # happen in this example with high noise and a seed
-        self.assertEqual(
-            np.equal(res.inference_trace, res.optimization_trace).all(),
-            report_inference_value_as_trace,
+        self.assertTrue(
+            np.equal(none_throws(res.inference_trace), res.optimization_trace).all(),
         )
-        self.assertEqual(
+        self.assertFalse(
             np.equal(res.oracle_trace, res.optimization_trace).all(),
-            not report_inference_value_as_trace,
         )
 
         self.assertEqual(res.optimization_trace.shape, (problem.num_trials,))
-        self.assertTrue((res.inference_trace >= res.oracle_trace).all())
+        self.assertTrue((none_throws(res.inference_trace) >= res.oracle_trace).all())
 
     def test_replication_with_inference_value(self) -> None:
-        for batch_size, report_inference_value_as_trace in product(
-            [1, 2], [False, True]
-        ):
+        for batch_size in [1, 2]:
             with self.subTest(
                 batch_size=batch_size,
-                report_inference_value_as_trace=report_inference_value_as_trace,
             ):
                 self._test_replication_with_inference_value(
                     batch_size=batch_size,
-                    report_inference_value_as_trace=report_inference_value_as_trace,
                 )
 
         with self.assertRaisesRegex(
@@ -793,7 +800,11 @@ def test_replication_mbm(self) -> None:
                 acquisition_cls=qLogNoisyExpectedImprovement,
                 distribute_replications=False,
             ),
-            get_augmented_branin_problem(fidelity_or_task="fidelity"),
+            get_single_objective_benchmark_problem(
+                observe_noise_sd=False,
+                num_trials=6,
+                report_inference_value_as_trace=True,
+            ),
             "MBM::SingleTaskGP_qLogNEI",
         ),
     ]:
@@ -827,9 +838,7 @@ def test_replication_moo_sobol(self) -> None:
 
         self.assertTrue(np.all(res.score_trace <= 100))
         self.assertEqual(len(res.cost_trace), problem.num_trials)
-        self.assertEqual(len(res.inference_trace), problem.num_trials)
-        # since inference trace is not supported for MOO, it should be all NaN
-        self.assertTrue(np.isnan(res.inference_trace).all())
+        self.assertIsNone(res.inference_trace)
 
     def test_benchmark_one_method_problem(self) -> None:
         problem = get_single_objective_benchmark_problem()
@@ -1196,6 +1205,7 @@ def test_get_opt_trace_by_cumulative_epochs(self) -> None:
         ):
             get_opt_trace_by_steps(experiment=experiment)
 
+        method.generation_strategy = method.generation_strategy.clone_reset()
         with self.subTest("Constrained"):
             problem = get_benchmark_problem("constrained_gramacy_observed_noise")
             experiment = self.run_optimization_with_orchestrator(
@@ -1237,72 +1247,28 @@ def test_get_benchmark_result_with_cumulative_steps(self) -> None:
         self.assertLessEqual(transformed.score_trace.min(), result.score_trace.min())
 
     def test_get_best_parameters(self) -> None:
-        """
-        Whether this produces the correct values is tested more thoroughly in
-        other tests such as `test_replication_with_inference_value` and
-        `test_get_inference_trace_from_params`. Setting up an experiment with
-        data and trials without just running a benchmark is a pain, so in those
-        tests, we just run a benchmark.
-        """
-        gs = get_sobol_generation_strategy()
-
-        search_space = get_continuous_search_space(bounds=[(0, 1)])
-        moo_config = get_moo_opt_config(outcome_names=["a", "b"], ref_point=[0, 0])
-        experiment = Experiment(
-            name="test",
-            is_test=True,
-            search_space=search_space,
-            optimization_config=moo_config,
+        experiment = get_branin_experiment()
+        generation_strategy = get_sobol_generation_strategy()
+        mock_function = (
+            "ax.service.utils.best_point."
+            "get_best_parameters_from_model_predictions_with_trial_index"
         )
 
-        with self.subTest("MOO not supported"), self.assertRaisesRegex(
-            NotImplementedError, "Please use `get_pareto_optimal_parameters`"
-        ):
-            get_best_parameters(experiment=experiment, generation_strategy=gs)
-
-        soo_config = get_soo_opt_config(outcome_names=["a"])
-        with self.subTest("Empty experiment"):
-            result = get_best_parameters(
-                experiment=experiment.clone_with(optimization_config=soo_config),
-                generation_strategy=gs,
-            )
+        with patch(mock_function, return_value=None):
+            result = get_best_parameters(experiment, generation_strategy)
             self.assertIsNone(result)
 
-        with self.subTest("All constraints violated"):
-            experiment = get_experiment_with_observations(
-                observations=[[1, -1], [2, -1]],
-                constrained=True,
-            )
-            best_point = get_best_parameters(
-                experiment=experiment, generation_strategy=gs
-            )
-            self.assertIsNone(best_point)
-
-        with self.subTest("No completed trials"):
-            experiment = get_experiment_with_observations(observations=[])
-            sobol_generator = get_sobol(search_space=experiment.search_space)
-            for _ in range(3):
-                trial = experiment.new_trial(generator_run=sobol_generator.gen(n=1))
-                trial.run()
-            best_point = get_best_parameters(
-                experiment=experiment, generation_strategy=gs
-            )
-            self.assertIsNone(best_point)
+        with patch(mock_function, return_value=(0, {"x": 1.0}, None)):
+            result = get_best_parameters(experiment, generation_strategy)
            self.assertEqual(result, {"x": 1.0})
 
-        experiment = get_experiment_with_observations(
-            observations=[[1], [2]], constrained=False
-        )
-        with self.subTest("Working case"):
-            best_point = get_best_parameters(
-                experiment=experiment, generation_strategy=gs
-            )
-            self.assertEqual(best_point, experiment.trials[1].arms[0].parameters)
-
-        with self.subTest("Trial indices"):
-            best_point = get_best_parameters(
-                experiment=experiment, generation_strategy=gs, trial_indices=[0]
+        with self.subTest("MOO not supported"), self.assertRaisesRegex(
+            NotImplementedError, "Please use `get_pareto_optimal_parameters`"
+        ):
+            experiment = get_branin_experiment_with_multi_objective()
+            get_best_parameters(
+                experiment=experiment, generation_strategy=generation_strategy
             )
-            self.assertEqual(best_point, experiment.trials[0].arms[0].parameters)
 
     def test_get_benchmark_result_from_experiment_and_gs(self) -> None:
         problem = get_single_objective_benchmark_problem()
