17
17
18
18
import numpy as np
19
19
import torch
20
- from ax .adapter .factory import get_sobol
21
20
from ax .adapter .registry import Generators
22
21
from ax .benchmark .benchmark import (
23
22
_get_oracle_value_of_params ,
38
37
from ax .benchmark .benchmark_problem import (
39
38
BenchmarkProblem ,
40
39
create_problem_from_botorch ,
41
- get_continuous_search_space ,
42
40
get_moo_opt_config ,
43
41
get_soo_opt_config ,
44
42
)
82
80
from ax .utils .common .mock import mock_patch_method_original
83
81
from ax .utils .common .testutils import TestCase
84
82
85
- from ax .utils .testing .core_stubs import get_experiment_with_observations
83
+ from ax .utils .testing .core_stubs import (
84
+ get_branin_experiment ,
85
+ get_branin_experiment_with_multi_objective ,
86
+ )
86
87
from ax .utils .testing .mock import mock_botorch_optimize
87
88
from botorch .acquisition .knowledge_gradient import qKnowledgeGradient
88
89
from botorch .acquisition .logei import qLogNoisyExpectedImprovement
@@ -261,7 +262,9 @@ def test_replication_sobol_surrogate(self) -> None:
261
262
self .assertTrue (np .isfinite (res .score_trace ).all ())
262
263
self .assertTrue (np .all (res .score_trace <= 100 ))
263
264
264
- def _test_replication_async (self , map_data : bool ) -> None :
265
+ def _test_replication_async (
266
+ self , map_data : bool , report_inference_value_as_trace : bool
267
+ ) -> None :
265
268
"""
266
269
The test function is the identity function, higher is better, observed
267
270
to be noiseless, and the same at every point on the trajectory. And the
@@ -349,6 +352,7 @@ def _test_replication_async(self, map_data: bool) -> None:
349
352
problem = get_async_benchmark_problem (
350
353
map_data = map_data ,
351
354
step_runtime_fn = step_runtime_fn ,
355
+ report_inference_value_as_trace = report_inference_value_as_trace ,
352
356
)
353
357
354
358
with mock_patch_method_original (
@@ -417,12 +421,15 @@ def _test_replication_async(self, map_data: bool) -> None:
417
421
},
418
422
f"Failure for trial { trial_index } with { case_name } " ,
419
423
)
420
- self .assertFalse (np .isnan (result .inference_trace ).any ())
421
- self .assertEqual (
422
- result .inference_trace .tolist (),
423
- expected_traces [case_name ],
424
- msg = case_name ,
425
- )
424
+ if report_inference_value_as_trace :
425
+ self .assertFalse (np .isnan (result .inference_trace ).any ())
426
+ self .assertEqual (
427
+ result .inference_trace .tolist (),
428
+ expected_traces [case_name ],
429
+ msg = case_name ,
430
+ )
431
+ else :
432
+ self .assertIsNone (result .inference_trace )
426
433
self .assertEqual (
427
434
result .oracle_trace .tolist (),
428
435
expected_traces [case_name ],
@@ -466,8 +473,15 @@ def _test_replication_async(self, map_data: bool) -> None:
466
473
self .assertEqual (completed_times , expected_completed_times )
467
474
468
475
def test_replication_async (self ) -> None :
469
- self ._test_replication_async (map_data = False )
470
- self ._test_replication_async (map_data = True )
476
+ self ._test_replication_async (
477
+ map_data = False , report_inference_value_as_trace = False
478
+ )
479
+ self ._test_replication_async (
480
+ map_data = True , report_inference_value_as_trace = False
481
+ )
482
+ self ._test_replication_async (
483
+ map_data = False , report_inference_value_as_trace = True
484
+ )
471
485
472
486
def test_run_optimization_with_orchestrator (self ) -> None :
473
487
method = get_async_benchmark_method ()
@@ -491,6 +505,7 @@ def test_run_optimization_with_orchestrator(self) -> None:
491
505
none_throws (runner .simulated_backend_runner ).simulator ._verbose_logging
492
506
)
493
507
508
+ method .generation_strategy = method .generation_strategy .clone_reset ()
494
509
with self .subTest ("Logs not produced by default" ), self .assertNoLogs (
495
510
level = logging .INFO , logger = logger
496
511
), self .assertNoLogs (logger = logger ):
@@ -618,9 +633,9 @@ def test_early_stopping(self) -> None:
618
633
self .assertEqual (max_run , {0 : 4 , 1 : 2 , 2 : 2 , 3 : 2 })
619
634
620
635
def test_replication_variable_runtime (self ) -> None :
621
- method = get_async_benchmark_method (max_pending_trials = 1 )
622
636
for map_data in [False , True ]:
623
637
with self .subTest (map_data = map_data ):
638
+ method = get_async_benchmark_method (max_pending_trials = 1 )
624
639
problem = get_async_benchmark_problem (
625
640
map_data = map_data ,
626
641
step_runtime_fn = lambda params : params ["x0" ] + 1 ,
@@ -652,9 +667,7 @@ def test_replication_variable_runtime(self) -> None:
652
667
self .assertEqual (start_times , expected_start_times )
653
668
654
669
@mock_botorch_optimize
655
- def _test_replication_with_inference_value (
656
- self , batch_size : int , report_inference_value_as_trace : bool
657
- ) -> None :
670
+ def _test_replication_with_inference_value (self , batch_size : int ) -> None :
658
671
seed = 1
659
672
method = get_sobol_botorch_modular_acquisition (
660
673
model_cls = SingleTaskGP ,
@@ -667,35 +680,29 @@ def _test_replication_with_inference_value(
667
680
num_trials = 4
668
681
problem = get_single_objective_benchmark_problem (
669
682
num_trials = num_trials ,
670
- report_inference_value_as_trace = report_inference_value_as_trace ,
683
+ report_inference_value_as_trace = True ,
671
684
noise_std = 100.0 ,
672
685
)
673
686
res = self .benchmark_replication (problem = problem , method = method , seed = seed )
674
687
# The inference trace could coincide with the oracle trace, but it won't
675
688
# happen in this example with high noise and a seed
676
- self .assertEqual (
677
- np .equal (res .inference_trace , res .optimization_trace ).all (),
678
- report_inference_value_as_trace ,
689
+ self .assertTrue (
690
+ np .equal (none_throws (res .inference_trace ), res .optimization_trace ).all (),
679
691
)
680
- self .assertEqual (
692
+ self .assertFalse (
681
693
np .equal (res .oracle_trace , res .optimization_trace ).all (),
682
- not report_inference_value_as_trace ,
683
694
)
684
695
685
696
self .assertEqual (res .optimization_trace .shape , (problem .num_trials ,))
686
- self .assertTrue ((res .inference_trace >= res .oracle_trace ).all ())
697
+ self .assertTrue ((none_throws ( res .inference_trace ) >= res .oracle_trace ).all ())
687
698
688
699
def test_replication_with_inference_value (self ) -> None :
689
- for batch_size , report_inference_value_as_trace in product (
690
- [1 , 2 ], [False , True ]
691
- ):
700
+ for batch_size in [1 , 2 ]:
692
701
with self .subTest (
693
702
batch_size = batch_size ,
694
- report_inference_value_as_trace = report_inference_value_as_trace ,
695
703
):
696
704
self ._test_replication_with_inference_value (
697
705
batch_size = batch_size ,
698
- report_inference_value_as_trace = report_inference_value_as_trace ,
699
706
)
700
707
701
708
with self .assertRaisesRegex (
@@ -793,7 +800,11 @@ def test_replication_mbm(self) -> None:
793
800
acquisition_cls = qLogNoisyExpectedImprovement ,
794
801
distribute_replications = False ,
795
802
),
796
- get_augmented_branin_problem (fidelity_or_task = "fidelity" ),
803
+ get_single_objective_benchmark_problem (
804
+ observe_noise_sd = False ,
805
+ num_trials = 6 ,
806
+ report_inference_value_as_trace = True ,
807
+ ),
797
808
"MBM::SingleTaskGP_qLogNEI" ,
798
809
),
799
810
]:
@@ -827,9 +838,7 @@ def test_replication_moo_sobol(self) -> None:
827
838
828
839
self .assertTrue (np .all (res .score_trace <= 100 ))
829
840
self .assertEqual (len (res .cost_trace ), problem .num_trials )
830
- self .assertEqual (len (res .inference_trace ), problem .num_trials )
831
- # since inference trace is not supported for MOO, it should be all NaN
832
- self .assertTrue (np .isnan (res .inference_trace ).all ())
841
+ self .assertIsNone (res .inference_trace )
833
842
834
843
def test_benchmark_one_method_problem (self ) -> None :
835
844
problem = get_single_objective_benchmark_problem ()
@@ -1196,6 +1205,7 @@ def test_get_opt_trace_by_cumulative_epochs(self) -> None:
1196
1205
):
1197
1206
get_opt_trace_by_steps (experiment = experiment )
1198
1207
1208
+ method .generation_strategy = method .generation_strategy .clone_reset ()
1199
1209
with self .subTest ("Constrained" ):
1200
1210
problem = get_benchmark_problem ("constrained_gramacy_observed_noise" )
1201
1211
experiment = self .run_optimization_with_orchestrator (
@@ -1237,72 +1247,28 @@ def test_get_benchmark_result_with_cumulative_steps(self) -> None:
1237
1247
self .assertLessEqual (transformed .score_trace .min (), result .score_trace .min ())
1238
1248
1239
1249
def test_get_best_parameters (self ) -> None :
1240
- """
1241
- Whether this produces the correct values is tested more thoroughly in
1242
- other tests such as `test_replication_with_inference_value` and
1243
- `test_get_inference_trace_from_params`. Setting up an experiment with
1244
- data and trials without just running a benchmark is a pain, so in those
1245
- tests, we just run a benchmark.
1246
- """
1247
- gs = get_sobol_generation_strategy ()
1248
-
1249
- search_space = get_continuous_search_space (bounds = [(0 , 1 )])
1250
- moo_config = get_moo_opt_config (outcome_names = ["a" , "b" ], ref_point = [0 , 0 ])
1251
- experiment = Experiment (
1252
- name = "test" ,
1253
- is_test = True ,
1254
- search_space = search_space ,
1255
- optimization_config = moo_config ,
1250
+ experiment = get_branin_experiment ()
1251
+ generation_strategy = get_sobol_generation_strategy ()
1252
+ mock_function = (
1253
+ "ax.service.utils.best_point."
1254
+ "get_best_parameters_from_model_predictions_with_trial_index"
1256
1255
)
1257
1256
1258
- with self .subTest ("MOO not supported" ), self .assertRaisesRegex (
1259
- NotImplementedError , "Please use `get_pareto_optimal_parameters`"
1260
- ):
1261
- get_best_parameters (experiment = experiment , generation_strategy = gs )
1262
-
1263
- soo_config = get_soo_opt_config (outcome_names = ["a" ])
1264
- with self .subTest ("Empty experiment" ):
1265
- result = get_best_parameters (
1266
- experiment = experiment .clone_with (optimization_config = soo_config ),
1267
- generation_strategy = gs ,
1268
- )
1257
+ with patch (mock_function , return_value = None ):
1258
+ result = get_best_parameters (experiment , generation_strategy )
1269
1259
self .assertIsNone (result )
1270
1260
1271
- with self .subTest ("All constraints violated" ):
1272
- experiment = get_experiment_with_observations (
1273
- observations = [[1 , - 1 ], [2 , - 1 ]],
1274
- constrained = True ,
1275
- )
1276
- best_point = get_best_parameters (
1277
- experiment = experiment , generation_strategy = gs
1278
- )
1279
- self .assertIsNone (best_point )
1280
-
1281
- with self .subTest ("No completed trials" ):
1282
- experiment = get_experiment_with_observations (observations = [])
1283
- sobol_generator = get_sobol (search_space = experiment .search_space )
1284
- for _ in range (3 ):
1285
- trial = experiment .new_trial (generator_run = sobol_generator .gen (n = 1 ))
1286
- trial .run ()
1287
- best_point = get_best_parameters (
1288
- experiment = experiment , generation_strategy = gs
1289
- )
1290
- self .assertIsNone (best_point )
1261
+ with patch (mock_function , return_value = (0 , {"x" : 1.0 }, None )):
1262
+ result = get_best_parameters (experiment , generation_strategy )
1263
+ self .assertEqual (result , {"x" : 1.0 })
1291
1264
1292
- experiment = get_experiment_with_observations (
1293
- observations = [[1 ], [2 ]], constrained = False
1294
- )
1295
- with self .subTest ("Working case" ):
1296
- best_point = get_best_parameters (
1297
- experiment = experiment , generation_strategy = gs
1298
- )
1299
- self .assertEqual (best_point , experiment .trials [1 ].arms [0 ].parameters )
1300
-
1301
- with self .subTest ("Trial indices" ):
1302
- best_point = get_best_parameters (
1303
- experiment = experiment , generation_strategy = gs , trial_indices = [0 ]
1265
+ with self .subTest ("MOO not supported" ), self .assertRaisesRegex (
1266
+ NotImplementedError , "Please use `get_pareto_optimal_parameters`"
1267
+ ):
1268
+ experiment = get_branin_experiment_with_multi_objective ()
1269
+ get_best_parameters (
1270
+ experiment = experiment , generation_strategy = generation_strategy
1304
1271
)
1305
- self .assertEqual (best_point , experiment .trials [0 ].arms [0 ].parameters )
1306
1272
1307
1273
def test_get_benchmark_result_from_experiment_and_gs (self ) -> None :
1308
1274
problem = get_single_objective_benchmark_problem ()
0 commit comments