diff --git a/ax/api/tests/test_client.py b/ax/api/tests/test_client.py index 935081ed718..664765eedd2 100644 --- a/ax/api/tests/test_client.py +++ b/ax/api/tests/test_client.py @@ -489,6 +489,14 @@ def test_attach_data(self) -> None: ), ) + # With NaN / Inf values. + for value in [float("nan"), float("inf"), float("-inf")]: + with self.assertRaisesRegex(ValueError, "null or inf values"): + client.attach_data( + trial_index=trial_index, + raw_data={"foo": (value, 0.0), "bar": (0.5, 0.0)}, + ) + def test_complete_trial(self) -> None: client = Client() diff --git a/ax/core/data.py b/ax/core/data.py index f645c90b0e7..ee04c29f3df 100644 --- a/ax/core/data.py +++ b/ax/core/data.py @@ -107,7 +107,7 @@ def __init__( raise ValueError(f"Columns {list(extra_columns)} are not supported.") df = df.dropna(axis=0, how="all", ignore_index=True) df = self._safecast_df(df=df) - + self._check_for_nan_inf(df=df) # Reorder the columns for easier viewing col_order = [c for c in self.column_data_types() if c in df.columns] self._df = df.reindex(columns=col_order, copy=False) @@ -149,6 +149,14 @@ def _safecast_df( df[col] = df[col].astype(dtype) return df + def _check_for_nan_inf(self, df: pd.DataFrame) -> None: + """Check for NaNs or infs in the "mean" column of the dataframe.""" + if not (mask := np.isfinite(df["mean"])).all(): + raise ValueError( + "Data contains null or inf values for the mean. " + f"Invalid rows: {df[~mask]}" + ) + def required_columns(self) -> set[str]: """Names of columns that must be present in the underlying ``DataFrame``.""" return self.REQUIRED_COLUMNS diff --git a/ax/core/map_data.py b/ax/core/map_data.py index 9d3ee5b0998..2a5bd80a3ad 100644 --- a/ax/core/map_data.py +++ b/ax/core/map_data.py @@ -165,7 +165,7 @@ def __init__( self._map_df = self._safecast_df( df=df, extra_column_types=self.map_key_to_type ) - + self._check_for_nan_inf(df=self._map_df) col_order = [ c for c in self.column_data_types(extra_column_types=self.map_key_to_type) diff --git a/ax/core/tests/test_data.py b/ax/core/tests/test_data.py index c523a5f6eb4..6e243c6f657 100644 --- a/ax/core/tests/test_data.py +++ b/ax/core/tests/test_data.py @@ -144,11 +144,27 @@ def test_clone(self) -> None: self.assertIsNot(data.df, data_clone.df) self.assertIsNone(data_clone._db_id) - def test_BadData(self) -> None: + def test_bad_data(self) -> None: df = pd.DataFrame([{"bad_field": "0_0", "bad_field_2": {"x": 0, "y": "a"}}]) with self.assertRaises(ValueError): Data(df=df) + # Invalid mean values. + for value in [None, float("nan"), float("inf"), float("-inf")]: + df = pd.DataFrame( + [ + { + "trial_index": 0, + "arm_name": "arm", + "metric_name": "metric", + "mean": value, + "sem": 0.0, + } + ] + ) + with self.assertRaisesRegex(ValueError, "null or inf values"): + Data(df=df) + def test_EmptyData(self) -> None: df = Data().df self.assertTrue(df.empty) diff --git a/ax/core/tests/test_map_data.py b/ax/core/tests/test_map_data.py index 278bcb6ce06..f50aa92b448 100644 --- a/ax/core/tests/test_map_data.py +++ b/ax/core/tests/test_map_data.py @@ -122,6 +122,27 @@ def test_properties(self) -> None: self.assertEqual(self.mmd.map_keys, ["epoch"]) self.assertEqual(self.mmd.map_key_to_type, {"epoch": int}) + def test_bad_data(self) -> None: + # Invalid mean values. + for value in [None, float("nan"), float("inf"), float("-inf")]: + df = pd.DataFrame( + [ + { + "trial_index": 0, + "arm_name": "arm", + "metric_name": "metric", + "mean": value, + "sem": 0.0, + "epoch": 0, + } + ] + ) + with self.assertRaisesRegex(ValueError, "null or inf values"): + MapData( + df=df, + map_key_infos=[MapKeyInfo(key="epoch", default_value=0.0)], + ) + def test_clone(self) -> None: self.mmd._db_id = 1234 clone = self.mmd.clone() diff --git a/ax/core/tests/test_utils.py b/ax/core/tests/test_utils.py index 2ee703d7ec8..5097347c4d7 100644 --- a/ax/core/tests/test_utils.py +++ b/ax/core/tests/test_utils.py @@ -114,15 +114,6 @@ def setUp(self) -> None: "start_time": "2018-01-01", "end_time": "2018-01-02", }, - { - "arm_name": "0_1", - "mean": float("nan"), - "sem": float("nan"), - "trial_index": 1, - "metric_name": "a", - "start_time": "2018-01-01", - "end_time": "2018-01-02", - }, { "arm_name": "0_1", "mean": 3.7, @@ -143,17 +134,8 @@ def setUp(self) -> None: }, { "arm_name": "0_2", - "mean": float("nan"), - "sem": float("nan"), - "trial_index": 1, - "metric_name": "b", - "start_time": "2018-01-01", - "end_time": "2018-01-02", - }, - { - "arm_name": "0_2", - "mean": float("nan"), - "sem": float("nan"), + "mean": 0.2, + "sem": None, "trial_index": 1, "metric_name": "c", "start_time": "2018-01-01", @@ -185,7 +167,7 @@ def test_get_missing_metrics(self) -> None: expected = MissingMetrics( {"a": {("0_1", 1)}}, {"b": {("0_2", 1)}}, - {"c": {("0_0", 1), ("0_1", 1), ("0_2", 1)}}, + {"c": {("0_0", 1), ("0_1", 1)}}, ) actual = get_missing_metrics(self.data, self.optimization_config) self.assertEqual(actual, expected) diff --git a/ax/core/utils.py b/ax/core/utils.py index 47a954e62b8..51c98015a9e 100644 --- a/ax/core/utils.py +++ b/ax/core/utils.py @@ -652,16 +652,14 @@ def extract_map_keys_from_opt_config( # -------------------- Context manager and decorator utils. --------------------- -# pyre-ignore[3]: Allowing `Any` in this case def batch_trial_only(msg: str | None = None) -> Callable[..., Any]: """A decorator to verify that the value passed to the `trial` argument to `func` is a `BatchTrial`. """ - # pyre-ignore[2,3]: Allowing `Any` in this case def batch_trial_only_decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) - def _batch_trial_only(*args: Any, **kwargs: Any) -> Any: # pyre-ignore[3] + def _batch_trial_only(*args: Any, **kwargs: Any) -> Any: if "trial" not in kwargs: raise AxError( f"Expected a keyword argument `trial` to `{func.__name__}`." diff --git a/ax/service/tests/test_best_point_utils.py b/ax/service/tests/test_best_point_utils.py index be60c46d3fe..f9a05db9d45 100644 --- a/ax/service/tests/test_best_point_utils.py +++ b/ax/service/tests/test_best_point_utils.py @@ -631,7 +631,7 @@ def test_is_row_feasible(self) -> None: [3, 3, -1], [2, 4, 1], [2, 0, 1], - # adding this to an otherwise feasible observation to test nan handling + # An otherwise feasible observation with missing metric. [2, 0, np.nan], ], constrained=True, @@ -640,10 +640,10 @@ def test_is_row_feasible(self) -> None: df=exp.lookup_data().df, optimization_config=none_throws(exp.optimization_config), ) - expected_per_arm = [False, True, False, True, True, False] + expected_per_arm = [False, True, False, True, True, True] expected_series = _repeat_elements( list_to_replicate=expected_per_arm, n_repeats=3 - ) + )[:-1] # Remove the last missing entry. pd.testing.assert_series_equal( feasible_series, expected_series, check_names=False ) @@ -669,7 +669,7 @@ def test_is_row_feasible(self) -> None: expected_per_arm = [False, True, True, True, True, True] expected_series = _repeat_elements( list_to_replicate=expected_per_arm, n_repeats=3 - ) + )[:-1] pd.testing.assert_series_equal( feasible_series, expected_series, check_names=False ) @@ -694,7 +694,7 @@ def test_is_row_feasible(self) -> None: expected_per_arm = [True, True, False, True, False, False] expected_series = _repeat_elements( list_to_replicate=expected_per_arm, n_repeats=3 - ) + )[:-1] pd.testing.assert_series_equal( feasible_series, expected_series, check_names=False ) diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 7fe09f5acc9..026a81ca016 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -15,7 +15,7 @@ from datetime import datetime, timedelta from functools import partial from logging import Logger -from math import prod +from math import isfinite, prod from pathlib import Path from typing import Any, cast, Sequence @@ -1113,6 +1113,7 @@ def get_experiment_with_observations( "trial_index": trial.index, } for m, o, s in zip(metrics, obs_i, sems_i, strict=True) + if isfinite(o) ] ) )