Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ax/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,14 @@ def test_attach_data(self) -> None:
),
)

# With NaN / Inf values.
for value in [float("nan"), float("inf"), float("-inf")]:
with self.assertRaisesRegex(ValueError, "null or inf values"):
client.attach_data(
trial_index=trial_index,
raw_data={"foo": (value, 0.0), "bar": (0.5, 0.0)},
)

def test_complete_trial(self) -> None:
client = Client()

Expand Down
10 changes: 9 additions & 1 deletion ax/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
raise ValueError(f"Columns {list(extra_columns)} are not supported.")
df = df.dropna(axis=0, how="all", ignore_index=True)
df = self._safecast_df(df=df)

self._check_for_nan_inf(df=df)
# Reorder the columns for easier viewing
col_order = [c for c in self.column_data_types() if c in df.columns]
self._df = df.reindex(columns=col_order, copy=False)
Expand Down Expand Up @@ -149,6 +149,14 @@ def _safecast_df(
df[col] = df[col].astype(dtype)
return df

def _check_for_nan_inf(self, df: pd.DataFrame) -> None:
"""Check for NaNs or infs in the "mean" column of the dataframe."""
if not (mask := np.isfinite(df["mean"])).all():
raise ValueError(
"Data contains null or inf values for the mean. "
f"Invalid rows: {df[~mask]}"
)

def required_columns(self) -> set[str]:
    """Return the column names the underlying ``DataFrame`` must contain.

    The value is read from the ``REQUIRED_COLUMNS`` attribute.
    """
    columns = self.REQUIRED_COLUMNS
    return columns
Expand Down
2 changes: 1 addition & 1 deletion ax/core/map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __init__(
self._map_df = self._safecast_df(
df=df, extra_column_types=self.map_key_to_type
)

self._check_for_nan_inf(df=self._map_df)
col_order = [
c
for c in self.column_data_types(extra_column_types=self.map_key_to_type)
Expand Down
18 changes: 17 additions & 1 deletion ax/core/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,27 @@ def test_clone(self) -> None:
self.assertIsNot(data.df, data_clone.df)
self.assertIsNone(data_clone._db_id)

def test_BadData(self) -> None:
def test_bad_data(self) -> None:
    """Constructing ``Data`` from malformed frames must raise ``ValueError``."""
    # A frame holding only the ``bad_field`` / ``bad_field_2`` columns.
    unrecognized = pd.DataFrame(
        [{"bad_field": "0_0", "bad_field_2": {"x": 0, "y": "a"}}]
    )
    with self.assertRaises(ValueError):
        Data(df=unrecognized)

    # Invalid mean values.
    invalid_means = [None, float("nan"), float("inf"), float("-inf")]
    for bad_mean in invalid_means:
        row = {
            "trial_index": 0,
            "arm_name": "arm",
            "metric_name": "metric",
            "mean": bad_mean,
            "sem": 0.0,
        }
        frame = pd.DataFrame([row])
        with self.assertRaisesRegex(ValueError, "null or inf values"):
            Data(df=frame)

def test_EmptyData(self) -> None:
df = Data().df
self.assertTrue(df.empty)
Expand Down
21 changes: 21 additions & 0 deletions ax/core/tests/test_map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,27 @@ def test_properties(self) -> None:
self.assertEqual(self.mmd.map_keys, ["epoch"])
self.assertEqual(self.mmd.map_key_to_type, {"epoch": int})

def test_bad_data(self) -> None:
    """Constructing ``MapData`` with a non-finite mean must raise ``ValueError``."""
    # Invalid mean values.
    invalid_means = [None, float("nan"), float("inf"), float("-inf")]
    for bad_mean in invalid_means:
        row = {
            "trial_index": 0,
            "arm_name": "arm",
            "metric_name": "metric",
            "mean": bad_mean,
            "sem": 0.0,
            "epoch": 0,
        }
        frame = pd.DataFrame([row])
        with self.assertRaisesRegex(ValueError, "null or inf values"):
            MapData(
                df=frame,
                map_key_infos=[MapKeyInfo(key="epoch", default_value=0.0)],
            )

def test_clone(self) -> None:
self.mmd._db_id = 1234
clone = self.mmd.clone()
Expand Down
24 changes: 3 additions & 21 deletions ax/core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,6 @@ def setUp(self) -> None:
"start_time": "2018-01-01",
"end_time": "2018-01-02",
},
{
"arm_name": "0_1",
"mean": float("nan"),
"sem": float("nan"),
"trial_index": 1,
"metric_name": "a",
"start_time": "2018-01-01",
"end_time": "2018-01-02",
},
{
"arm_name": "0_1",
"mean": 3.7,
Expand All @@ -143,17 +134,8 @@ def setUp(self) -> None:
},
{
"arm_name": "0_2",
"mean": float("nan"),
"sem": float("nan"),
"trial_index": 1,
"metric_name": "b",
"start_time": "2018-01-01",
"end_time": "2018-01-02",
},
{
"arm_name": "0_2",
"mean": float("nan"),
"sem": float("nan"),
"mean": 0.2,
"sem": None,
"trial_index": 1,
"metric_name": "c",
"start_time": "2018-01-01",
Expand Down Expand Up @@ -185,7 +167,7 @@ def test_get_missing_metrics(self) -> None:
expected = MissingMetrics(
{"a": {("0_1", 1)}},
{"b": {("0_2", 1)}},
{"c": {("0_0", 1), ("0_1", 1), ("0_2", 1)}},
{"c": {("0_0", 1), ("0_1", 1)}},
)
actual = get_missing_metrics(self.data, self.optimization_config)
self.assertEqual(actual, expected)
Expand Down
4 changes: 1 addition & 3 deletions ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,16 +652,14 @@ def extract_map_keys_from_opt_config(
# -------------------- Context manager and decorator utils. ---------------------


# pyre-ignore[3]: Allowing `Any` in this case
def batch_trial_only(msg: str | None = None) -> Callable[..., Any]:
"""A decorator to verify that the value passed to the `trial`
argument to `func` is a `BatchTrial`.
"""

# pyre-ignore[2,3]: Allowing `Any` in this case
def batch_trial_only_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def _batch_trial_only(*args: Any, **kwargs: Any) -> Any: # pyre-ignore[3]
def _batch_trial_only(*args: Any, **kwargs: Any) -> Any:
if "trial" not in kwargs:
raise AxError(
f"Expected a keyword argument `trial` to `{func.__name__}`."
Expand Down
10 changes: 5 additions & 5 deletions ax/service/tests/test_best_point_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def test_is_row_feasible(self) -> None:
[3, 3, -1],
[2, 4, 1],
[2, 0, 1],
# adding this to an otherwise feasible observation to test nan handling
# An otherwise feasible observation with a missing metric.
[2, 0, np.nan],
],
constrained=True,
Expand All @@ -640,10 +640,10 @@ def test_is_row_feasible(self) -> None:
df=exp.lookup_data().df,
optimization_config=none_throws(exp.optimization_config),
)
expected_per_arm = [False, True, False, True, True, False]
expected_per_arm = [False, True, False, True, True, True]
expected_series = _repeat_elements(
list_to_replicate=expected_per_arm, n_repeats=3
)
)[:-1] # Remove the last missing entry.
pd.testing.assert_series_equal(
feasible_series, expected_series, check_names=False
)
Expand All @@ -669,7 +669,7 @@ def test_is_row_feasible(self) -> None:
expected_per_arm = [False, True, True, True, True, True]
expected_series = _repeat_elements(
list_to_replicate=expected_per_arm, n_repeats=3
)
)[:-1]
pd.testing.assert_series_equal(
feasible_series, expected_series, check_names=False
)
Expand All @@ -694,7 +694,7 @@ def test_is_row_feasible(self) -> None:
expected_per_arm = [True, True, False, True, False, False]
expected_series = _repeat_elements(
list_to_replicate=expected_per_arm, n_repeats=3
)
)[:-1]
pd.testing.assert_series_equal(
feasible_series, expected_series, check_names=False
)
Expand Down
3 changes: 2 additions & 1 deletion ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from datetime import datetime, timedelta
from functools import partial
from logging import Logger
from math import prod
from math import isfinite, prod
from pathlib import Path
from typing import Any, cast, Sequence

Expand Down Expand Up @@ -1113,6 +1113,7 @@ def get_experiment_with_observations(
"trial_index": trial.index,
}
for m, o, s in zip(metrics, obs_i, sems_i, strict=True)
if isfinite(o)
]
)
)
Expand Down