Skip to content

Commit 571c3bf

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Move relativize functions out of statstools to avoid circular dependency in imports (facebook#4341)
Summary: To resolve a circular dependency between `ax.core` and `statstools`, this diff moves the `relativize`, `unrelativize`, and `relativize_data` functions out of statstools. The `Data` class now provides a relativized method to replace relativize_data, and all relevant imports, dependencies, and unit tests have been updated accordingly. Differential Revision: D83095707
1 parent f6a84cd commit 571c3bf

16 files changed

+651
-457
lines changed

ax/adapter/transforms/relativize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ax.core.search_space import SearchSpace
2828
from ax.exceptions.core import DataRequiredError
2929
from ax.generators.types import TConfig
30-
from ax.utils.stats.statstools import relativize, unrelativize
30+
from ax.utils.stats.math_utils import relativize, unrelativize
3131
from pyre_extensions import none_throws
3232

3333
if TYPE_CHECKING:

ax/adapter/transforms/tests/test_relativize_transform.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from ax.generators.base import Generator
3737
from ax.metrics.branin import BraninMetric
3838
from ax.utils.common.testutils import TestCase
39-
from ax.utils.stats.statstools import relativize_data
4039
from ax.utils.testing.core_stubs import (
4140
get_branin_data_batch,
4241
get_branin_experiment,
@@ -352,8 +351,7 @@ def test_multitask_data(self) -> None:
352351
)
353352
relative_observations = observations_from_data(
354353
experiment=experiment,
355-
data=relativize_data(
356-
data=data,
354+
data=data.relativize(
357355
status_quo_name="status_quo",
358356
as_percent=True,
359357
include_sq=True,
@@ -387,7 +385,7 @@ def test_multitask_data(self) -> None:
387385
)
388386

389387
# not checking RelativizeWithConstantControl here
390-
# because relativize_data uses delta method
388+
# because data.relativize uses delta method
391389
transform = Relativize(search_space=None, adapter=adapter)
392390

393391
relative_obs_t = transform.transform_observations(observations)

ax/adapter/transforms/tests/test_transform_to_new_sq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ax.exceptions.core import DataRequiredError
2323
from ax.generators.base import Generator
2424
from ax.utils.common.testutils import TestCase
25-
from ax.utils.stats.statstools import relativize
25+
from ax.utils.stats.math_utils import relativize
2626
from ax.utils.testing.core_stubs import (
2727
get_branin_data_batch,
2828
get_branin_experiment,

ax/adapter/transforms/transform_to_new_sq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ax.core.utils import get_target_trial_index
2424
from ax.generators.types import TConfig
2525
from ax.utils.common.logger import get_logger
26-
from ax.utils.stats.statstools import relativize, unrelativize
26+
from ax.utils.stats.math_utils import relativize, unrelativize
2727
from pyre_extensions import assert_is_instance, none_throws
2828

2929
if TYPE_CHECKING:

ax/analysis/tests/test_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from ax.core.trial_status import TrialStatus # noqa
2121
from ax.exceptions.core import UserInputError
2222
from ax.utils.common.testutils import TestCase
23-
from ax.utils.stats.statstools import relativize_data
2423
from ax.utils.testing.core_stubs import get_offline_experiments, get_online_experiments
2524
from ax.utils.testing.mock import mock_botorch_optimize
2625
from ax.utils.testing.modeling_stubs import get_default_generation_strategy_at_MBM_node
@@ -676,7 +675,7 @@ def test_online(self) -> None:
676675
# resemble those we see in an online setting.
677676
for experiment in get_online_experiments():
678677
data = experiment.lookup_data()
679-
rel_df = relativize_data(data).df
678+
rel_df = data.relativize().df
680679
raw_df = data.df
681680
metric_name = next(iter(experiment.metrics.keys()))
682681
raw_arm_value = raw_df[

ax/analysis/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ax.generation_strategy.generation_strategy import GenerationStrategy
2828
from ax.utils.common.constants import Keys
2929
from ax.utils.common.logger import get_logger
30-
from ax.utils.stats.statstools import relativize
30+
from ax.utils.stats.math_utils import relativize
3131
from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds
3232
from pyre_extensions import none_throws
3333

ax/core/data.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
TClassDecoderRegistry,
2828
TDecoderRegistry,
2929
)
30+
from ax.utils.stats.math_utils import relativize
3031
from pyre_extensions import assert_is_instance
3132

3233
logger: Logger = get_logger(__name__)
@@ -392,6 +393,91 @@ def from_multiple_data(cls, data: Iterable[Data]) -> Data:
392393
"""
393394
return cls.from_multiple(data=data)
394395

396+
def relativize(
397+
self,
398+
status_quo_name: str = "status_quo",
399+
as_percent: bool = False,
400+
include_sq: bool = False,
401+
bias_correction: bool = True,
402+
control_as_constant: bool = False,
403+
) -> "Data":
404+
"""Relativize this data object w.r.t. a status_quo arm.
405+
406+
Args:
407+
status_quo_name: The name of the status_quo arm.
408+
as_percent: If True, return results as percentage change.
409+
include_sq: Include status quo in final df.
410+
bias_correction: Whether to apply bias correction when computing relativized
411+
metric values. Uses a second-order Taylor expansion for approximating
412+
the means and standard errors or the ratios, see
413+
ax.utils.stats.statstools.relativize for more details.
414+
control_as_constant: If true, control is treated as a constant.
415+
bias_correction is ignored when this is true.
416+
417+
Returns:
418+
The new data object with the relativized metrics (excluding the
419+
status_quo arm)
420+
421+
"""
422+
423+
df = self.df.copy()
424+
grp_cols = list(
425+
{"trial_index", "metric_name", "random_split"}.intersection(
426+
df.columns.values
427+
)
428+
)
429+
430+
grouped_df = df.groupby(grp_cols)
431+
dfs = []
432+
for grp in grouped_df.groups.keys():
433+
subgroup_df = grouped_df.get_group(grp)
434+
is_sq = subgroup_df["arm_name"] == status_quo_name
435+
436+
# Check if status quo exists in this subgroup
437+
sq_data = (
438+
subgroup_df[is_sq][["mean", "sem"]].drop_duplicates().values.flatten()
439+
)
440+
if len(sq_data) == 0:
441+
# No status quo in this subgroup, skip relativization
442+
continue
443+
elif len(sq_data) != 2:
444+
raise ValueError(
445+
f"Expected exactly 2 values (mean, sem) for status quo, "
446+
f"got {len(sq_data)}"
447+
)
448+
449+
sq_mean, sq_sem = sq_data
450+
451+
# rm status quo from final df to relativize
452+
if not include_sq:
453+
subgroup_df = subgroup_df[~is_sq]
454+
means_rel, sems_rel = relativize(
455+
means_t=subgroup_df["mean"].values,
456+
sems_t=subgroup_df["sem"].values,
457+
mean_c=sq_mean,
458+
sem_c=sq_sem,
459+
as_percent=as_percent,
460+
bias_correction=bias_correction,
461+
control_as_constant=control_as_constant,
462+
)
463+
dfs.append(
464+
pd.concat(
465+
[
466+
subgroup_df.drop(["mean", "sem"], axis=1),
467+
pd.DataFrame(
468+
np.array([means_rel, sems_rel]).T,
469+
columns=["mean", "sem"],
470+
index=subgroup_df.index,
471+
),
472+
],
473+
axis=1,
474+
)
475+
)
476+
df_rel = pd.concat(dfs, axis=0)
477+
if include_sq:
478+
df_rel.loc[df_rel["arm_name"] == status_quo_name, "sem"] = 0.0
479+
return Data(df_rel)
480+
395481
def clone(self) -> Data:
396482
"""Returns a new Data object with the same underlying dataframe."""
397483
return Data(df=deepcopy(self.df))

ax/core/tests/test_data.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from unittest.mock import patch
1010

11+
import numpy as np
12+
1113
import pandas as pd
1214
from ax.core.data import Data
1315
from ax.core.map_data import MAP_KEY, MapData
@@ -366,3 +368,141 @@ def test_filter(self) -> None:
366368
filtered = data.filter(metric_names=["a"])
367369
self.assertEqual(len(filtered.df), 3)
368370
self.assertEqual(set(filtered.df["metric_name"]), {"a"})
371+
372+
373+
class RelativizeDataTest(TestCase):
374+
def setUp(self) -> None:
375+
self.df = pd.DataFrame(
376+
[
377+
{
378+
"trial_index": 0,
379+
"mean": 2,
380+
"sem": 0,
381+
"metric_name": "foobar",
382+
"metric_signature": "foobar",
383+
"arm_name": "status_quo",
384+
},
385+
{
386+
"trial_index": 0,
387+
"mean": 5,
388+
"sem": 0,
389+
"metric_name": "foobaz",
390+
"metric_signature": "foobaz",
391+
"arm_name": "status_quo",
392+
},
393+
{
394+
"trial_index": 0,
395+
"mean": 1,
396+
"sem": 0,
397+
"metric_name": "foobar",
398+
"metric_signature": "foobar",
399+
"arm_name": "0_0",
400+
},
401+
{
402+
"trial_index": 0,
403+
"mean": 10,
404+
"sem": 0,
405+
"metric_name": "foobaz",
406+
"metric_signature": "foobaz",
407+
"arm_name": "0_0",
408+
},
409+
]
410+
)
411+
412+
self.expected_relativized_df = pd.DataFrame(
413+
[
414+
{
415+
"trial_index": 0,
416+
"mean": -0.5,
417+
"sem": 0,
418+
"metric_name": "foobar",
419+
"metric_signature": "foobar",
420+
"arm_name": "0_0",
421+
},
422+
{
423+
"trial_index": 0,
424+
"mean": 1,
425+
"sem": 0,
426+
"metric_name": "foobaz",
427+
"metric_signature": "foobaz",
428+
"arm_name": "0_0",
429+
},
430+
]
431+
)
432+
self.expected_relativized_df_with_sq = pd.DataFrame(
433+
[
434+
{
435+
"trial_index": 0,
436+
"mean": 0,
437+
"sem": 0,
438+
"metric_name": "foobar",
439+
"metric_signature": "foobar",
440+
"arm_name": "status_quo",
441+
},
442+
{
443+
"trial_index": 0,
444+
"mean": -0.5,
445+
"sem": 0,
446+
"metric_name": "foobar",
447+
"metric_signature": "foobar",
448+
"arm_name": "0_0",
449+
},
450+
{
451+
"trial_index": 0,
452+
"mean": 0,
453+
"sem": 0,
454+
"metric_name": "foobaz",
455+
"metric_signature": "foobaz",
456+
"arm_name": "status_quo",
457+
},
458+
{
459+
"trial_index": 0,
460+
"mean": 1,
461+
"sem": 0,
462+
"metric_name": "foobaz",
463+
"metric_signature": "foobaz",
464+
"arm_name": "0_0",
465+
},
466+
]
467+
)
468+
469+
def test_relativize_data(self) -> None:
470+
data = Data(
471+
df=self.df,
472+
)
473+
expected_relativized_data = Data(df=self.expected_relativized_df)
474+
475+
expected_relativized_data_with_sq = Data(
476+
df=self.expected_relativized_df_with_sq
477+
)
478+
479+
actual_relativized_data = data.relativize()
480+
self.assertEqual(expected_relativized_data, actual_relativized_data)
481+
482+
actual_relativized_data_with_sq = data.relativize(include_sq=True)
483+
self.assertEqual(
484+
expected_relativized_data_with_sq, actual_relativized_data_with_sq
485+
)
486+
487+
def test_relativize_data_no_sem(self) -> None:
488+
df = self.df.copy()
489+
df["sem"] = np.nan
490+
data = Data(df=df)
491+
492+
expected_relativized_df = self.expected_relativized_df.copy()
493+
expected_relativized_df["sem"] = np.nan
494+
expected_relativized_data = Data(df=expected_relativized_df)
495+
496+
expected_relativized_df_with_sq = self.expected_relativized_df_with_sq.copy()
497+
expected_relativized_df_with_sq.loc[
498+
expected_relativized_df_with_sq["arm_name"] != "status_quo", "sem"
499+
] = np.nan
500+
expected_relativized_data_with_sq = Data(df=expected_relativized_df_with_sq)
501+
502+
actual_relativized_data = data.relativize()
503+
self.assertEqual(expected_relativized_data, actual_relativized_data)
504+
505+
actual_relativized_data_with_sq = data.relativize(include_sq=True)
506+
self.assertEqual(
507+
expected_relativized_data_with_sq, actual_relativized_data_with_sq
508+
)

ax/plot/pareto_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from ax.exceptions.core import AxError, UnsupportedError, UserInputError
4242
from ax.generators.torch_base import TorchGenerator
4343
from ax.utils.common.logger import get_logger
44-
from ax.utils.stats.statstools import relativize
44+
from ax.utils.stats.math_utils import relativize
4545
from botorch.acquisition.monte_carlo import qSimpleRegret
4646
from botorch.utils.multi_objective import is_non_dominated
4747
from botorch.utils.multi_objective.hypervolume import infer_reference_point

ax/plot/scatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
)
4646
from ax.utils.common.logger import get_logger
4747
from ax.utils.common.typeutils import assert_is_instance_optional
48-
from ax.utils.stats.statstools import relativize
48+
from ax.utils.stats.math_utils import relativize
4949
from plotly import subplots
5050

5151
logger: Logger = get_logger(__name__)

0 commit comments

Comments
 (0)