Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ax/adapter/transforms/relativize.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ax.core.search_space import SearchSpace
from ax.exceptions.core import DataRequiredError
from ax.generators.types import TConfig
from ax.utils.stats.statstools import relativize, unrelativize
from ax.utils.stats.math_utils import relativize, unrelativize
from pyre_extensions import none_throws

if TYPE_CHECKING:
Expand Down
6 changes: 2 additions & 4 deletions ax/adapter/transforms/tests/test_relativize_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from ax.generators.base import Generator
from ax.metrics.branin import BraninMetric
from ax.utils.common.testutils import TestCase
from ax.utils.stats.statstools import relativize_data
from ax.utils.testing.core_stubs import (
get_branin_data_batch,
get_branin_experiment,
Expand Down Expand Up @@ -341,8 +340,7 @@ def test_multitask_data(self) -> None:
)
relative_observations = observations_from_data(
experiment=experiment,
data=relativize_data(
data=data,
data=data.relativize(
status_quo_name="status_quo",
as_percent=True,
include_sq=True,
Expand Down Expand Up @@ -376,7 +374,7 @@ def test_multitask_data(self) -> None:
)

# not checking RelativizeWithConstantControl here
# because relativize_data uses delta method
# because data.relativize uses delta method
transform = Relativize(search_space=None, adapter=adapter)

relative_obs_t = transform.transform_observations(observations)
Expand Down
2 changes: 1 addition & 1 deletion ax/adapter/transforms/tests/test_transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ax.exceptions.core import DataRequiredError
from ax.generators.base import Generator
from ax.utils.common.testutils import TestCase
from ax.utils.stats.statstools import relativize
from ax.utils.stats.math_utils import relativize
from ax.utils.testing.core_stubs import (
get_branin_data_batch,
get_branin_experiment,
Expand Down
2 changes: 1 addition & 1 deletion ax/adapter/transforms/transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ax.core.utils import get_target_trial_index
from ax.generators.types import TConfig
from ax.utils.common.logger import get_logger
from ax.utils.stats.statstools import relativize, unrelativize
from ax.utils.stats.math_utils import relativize, unrelativize
from pyre_extensions import assert_is_instance, none_throws

if TYPE_CHECKING:
Expand Down
3 changes: 1 addition & 2 deletions ax/analysis/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from ax.core.trial_status import TrialStatus # noqa
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase
from ax.utils.stats.statstools import relativize_data
from ax.utils.testing.core_stubs import get_offline_experiments, get_online_experiments
from ax.utils.testing.mock import mock_botorch_optimize
from ax.utils.testing.modeling_stubs import get_default_generation_strategy_at_MBM_node
Expand Down Expand Up @@ -676,7 +675,7 @@ def test_online(self) -> None:
# resemble those we see in an online setting.
for experiment in get_online_experiments():
data = experiment.lookup_data()
rel_df = relativize_data(data).df
rel_df = data.relativize().df
raw_df = data.df
metric_name = next(iter(experiment.metrics.keys()))
raw_arm_value = raw_df[
Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.stats.statstools import relativize
from ax.utils.stats.math_utils import relativize
from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds
from pyre_extensions import none_throws

Expand Down
75 changes: 75 additions & 0 deletions ax/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
TClassDecoderRegistry,
TDecoderRegistry,
)
from ax.utils.stats.math_utils import relativize as relativize_func
from pyre_extensions import assert_is_instance

logger: Logger = get_logger(__name__)
Expand Down Expand Up @@ -306,6 +307,80 @@ def from_multiple_data(cls, data: Iterable[Data]) -> Data:
"""
return cls.from_multiple(data=data)

def relativize(
    self,
    status_quo_name: str = "status_quo",
    as_percent: bool = False,
    include_sq: bool = False,
    bias_correction: bool = True,
    control_as_constant: bool = False,
) -> "Data":
    """Relativize this data object w.r.t. a status_quo arm.

    Rows are grouped by whichever of ``trial_index``, ``metric_name`` and
    ``random_split`` are present as columns, and every group is relativized
    against the status quo observation found in that same group.

    Args:
        status_quo_name: The name of the status_quo arm.
        as_percent: If True, return results as percentage change.
        include_sq: Include status quo in final df.
        bias_correction: Whether to apply bias correction when computing
            relativized metric values. Uses a second-order Taylor expansion
            for approximating the means and standard errors of the ratios,
            see ax.utils.stats.math_utils.relativize for more details.
        control_as_constant: If true, control is treated as a constant.
            bias_correction is ignored when this is true.

    Returns:
        A new data object with the relativized metrics. The status_quo arm
        is excluded unless ``include_sq`` is True, in which case its rows
        are kept and their ``sem`` is set to 0.0.

    Raises:
        ValueError: If a group does not contain exactly one distinct
            (mean, sem) observation for the status quo arm.
    """
    df = self.df.copy()
    # Filter a fixed tuple instead of intersecting sets: set iteration order
    # depends on string hashing, which would make the grouping (and hence
    # the row order of the result) vary from one process to the next.
    grp_cols = [
        col
        for col in ("trial_index", "metric_name", "random_split")
        if col in df.columns
    ]

    dfs = []
    for grp, subgroup_df in df.groupby(grp_cols):
        is_sq = subgroup_df["arm_name"] == status_quo_name

        # Repeated identical status quo rows collapse into one; anything
        # else (none at all, or conflicting values) cannot be relativized
        # against, so fail with a message that names the offending group.
        sq_stats = subgroup_df.loc[is_sq, ["mean", "sem"]].drop_duplicates()
        if len(sq_stats) != 1:
            raise ValueError(
                "Expected exactly one distinct (mean, sem) observation for "
                f"status quo arm '{status_quo_name}' in group {grp} "
                f"(grouped by {grp_cols}), but found {len(sq_stats)}."
            )
        sq_mean, sq_sem = sq_stats.values.flatten()

        # rm status quo from final df to relativize
        if not include_sq:
            subgroup_df = subgroup_df[~is_sq]
        means_rel, sems_rel = relativize_func(
            means_t=subgroup_df["mean"].values,
            sems_t=subgroup_df["sem"].values,
            mean_c=sq_mean,
            sem_c=sq_sem,
            as_percent=as_percent,
            bias_correction=bias_correction,
            control_as_constant=control_as_constant,
        )
        # Swap the raw mean/sem columns for the relativized ones while
        # keeping the original row index so rows stay aligned.
        dfs.append(
            pd.concat(
                [
                    subgroup_df.drop(["mean", "sem"], axis=1),
                    pd.DataFrame(
                        np.array([means_rel, sems_rel]).T,
                        columns=["mean", "sem"],
                        index=subgroup_df.index,
                    ),
                ],
                axis=1,
            )
        )
    df_rel = pd.concat(dfs, axis=0)
    if include_sq:
        # The status quo relativized against itself carries no uncertainty.
        df_rel.loc[df_rel["arm_name"] == status_quo_name, "sem"] = 0.0
    return Data(df_rel)

def clone(self: TData) -> TData:
"""Returns a new Data object with the same underlying dataframe."""
return self.__class__(df=deepcopy(self.full_df))
Expand Down
140 changes: 140 additions & 0 deletions ax/core/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from unittest.mock import patch

import numpy as np

import pandas as pd
from ax.core.data import Data
from ax.core.map_data import MAP_KEY, MapData
Expand Down Expand Up @@ -343,3 +345,141 @@ def test_safecast_df(self) -> None:
safecast_df = Data._safecast_df(df=df)
self.assertEqual(safecast_df.index.get_level_values(0).to_list(), [0])
self.assertEqual(df["trial_index"].dtype, int)


class RelativizeDataTest(TestCase):
    """Checks ``Data.relativize`` on a single trial with two metrics.

    The fixture holds a ``status_quo`` arm and one treatment arm (``0_0``)
    for the metrics ``foobar`` and ``foobaz``, all with zero ``sem``.
    """

    def setUp(self) -> None:
        # Raw observations; both status_quo rows come first, then arm 0_0.
        self.df = pd.DataFrame(
            {
                "trial_index": [0, 0, 0, 0],
                "mean": [2, 5, 1, 10],
                "sem": [0, 0, 0, 0],
                "metric_name": ["foobar", "foobaz", "foobar", "foobaz"],
                "metric_signature": ["foobar", "foobaz", "foobar", "foobaz"],
                "arm_name": ["status_quo", "status_quo", "0_0", "0_0"],
            }
        )

        # Expected result when the status quo rows are dropped.
        self.expected_relativized_df = pd.DataFrame(
            {
                "trial_index": [0, 0],
                "mean": [-0.5, 1],
                "sem": [0, 0],
                "metric_name": ["foobar", "foobaz"],
                "metric_signature": ["foobar", "foobaz"],
                "arm_name": ["0_0", "0_0"],
            }
        )

        # Expected result when the status quo rows are kept (mean 0, sem 0).
        self.expected_relativized_df_with_sq = pd.DataFrame(
            {
                "trial_index": [0, 0, 0, 0],
                "mean": [0, -0.5, 0, 1],
                "sem": [0, 0, 0, 0],
                "metric_name": ["foobar", "foobar", "foobaz", "foobaz"],
                "metric_signature": ["foobar", "foobar", "foobaz", "foobaz"],
                "arm_name": ["status_quo", "0_0", "status_quo", "0_0"],
            }
        )

    def test_relativize_data(self) -> None:
        """Relativizing with and without the status quo matches the fixtures."""
        data = Data(df=self.df)
        want = Data(df=self.expected_relativized_df)
        want_with_sq = Data(df=self.expected_relativized_df_with_sq)

        got = data.relativize()
        self.assertEqual(want, got)

        got_with_sq = data.relativize(include_sq=True)
        self.assertEqual(want_with_sq, got_with_sq)

    def test_relativize_data_no_sem(self) -> None:
        """With NaN sems, only the retained status quo rows get a sem of 0."""
        raw_df = self.df.copy()
        raw_df["sem"] = np.nan
        data = Data(df=raw_df)

        want_df = self.expected_relativized_df.copy()
        want_df["sem"] = np.nan
        want = Data(df=want_df)

        want_with_sq_df = self.expected_relativized_df_with_sq.copy()
        not_status_quo = want_with_sq_df["arm_name"] != "status_quo"
        want_with_sq_df.loc[not_status_quo, "sem"] = np.nan
        want_with_sq = Data(df=want_with_sq_df)

        got = data.relativize()
        self.assertEqual(want, got)

        got_with_sq = data.relativize(include_sq=True)
        self.assertEqual(want_with_sq, got_with_sq)
2 changes: 1 addition & 1 deletion ax/plot/pareto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from ax.exceptions.core import AxError, UnsupportedError, UserInputError
from ax.generators.torch_base import TorchGenerator
from ax.utils.common.logger import get_logger
from ax.utils.stats.statstools import relativize
from ax.utils.stats.math_utils import relativize
from botorch.acquisition.monte_carlo import qSimpleRegret
from botorch.utils.multi_objective import is_non_dominated
from botorch.utils.multi_objective.hypervolume import infer_reference_point
Expand Down
2 changes: 1 addition & 1 deletion ax/plot/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import assert_is_instance_optional
from ax.utils.stats.statstools import relativize
from ax.utils.stats.math_utils import relativize
from plotly import subplots

logger: Logger = get_logger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion ax/plot/tests/test_pareto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)

from ax.utils.common.testutils import TestCase
from ax.utils.stats.statstools import relativize
from ax.utils.stats.math_utils import relativize
from ax.utils.testing.core_stubs import (
get_branin_experiment,
get_branin_experiment_with_multi_objective,
Expand Down
Loading