12,613 changes: 12,613 additions & 0 deletions docs/source/notebooks/mmm/dev/mmm_example_new.ipynb

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions pymc_marketing/customer_choice/mnl_logit.py
@@ -82,11 +82,13 @@ class MNLogit(RegressionModelBuilder):

Example `utility_equations` list:

>>> utility_equations = [
... "alt_1 ~ X1_alt1 + X2_alt1 | income",
... "alt_2 ~ X1_alt2 + X2_alt2 | income",
... "alt_3 ~ X1_alt3 + X2_alt3 | income",
... ]
.. code-block:: python

utility_equations = [
"alt_1 ~ X1_alt1 + X2_alt1 | income",
"alt_2 ~ X1_alt2 + X2_alt2 | income",
"alt_3 ~ X1_alt3 + X2_alt3 | income",
]

"""

22 changes: 13 additions & 9 deletions pymc_marketing/customer_choice/nested_logit.py
@@ -91,18 +91,22 @@ class NestedLogit(RegressionModelBuilder):

Example `utility_equations` list:

>>> utility_equations = [
... "alt_1 ~ X1_alt1 + X2_alt1 | income",
... "alt_2 ~ X1_alt2 + X2_alt2 | income",
... "alt_3 ~ X1_alt3 + X2_alt3 | income",
... ]
.. code-block:: python

utility_equations = [
"alt_1 ~ X1_alt1 + X2_alt1 | income",
"alt_2 ~ X1_alt2 + X2_alt2 | income",
"alt_3 ~ X1_alt3 + X2_alt3 | income",
]

Example nesting structure:

>>> nesting_structure = {
... "Nest1": ["alt1"],
... "Nest2": {"Nest2_1": ["alt_2", "alt_3"], "Nest_2_2": ["alt_4", "alt_5"]},
... }
.. code-block:: python

nesting_structure = {
"Nest1": ["alt1"],
"Nest2": {"Nest2_1": ["alt_2", "alt_3"], "Nest_2_2": ["alt_4", "alt_5"]},
}

"""

6 changes: 4 additions & 2 deletions pymc_marketing/mmm/base.py
@@ -243,8 +243,10 @@ def preprocess(

Example
-------
>>> data = pd.DataFrame({"x1": [1, 2, 3], "y": [4, 5, 6]})
>>> self.preprocess("X", data)
.. code-block:: python

data = pd.DataFrame({"x1": [1, 2, 3], "y": [4, 5, 6]})
self.preprocess("X", data)

"""
data_cp = data.copy()
36 changes: 17 additions & 19 deletions pymc_marketing/mmm/mmm.py
@@ -2121,28 +2121,26 @@ def format_recovered_transformation_parameters(

Example
-------
>>> self.format_recovered_transformation_parameters(quantile=0.5)
>>> Output:
{
    'x1': {
        'saturation_params': {
            'lam': 2.4761893929757077,
            'beta': 0.360226791880304
        },
        'adstock_params': {
            'alpha': 0.39910387900504796
        }
    },
    'x2': {
        'saturation_params': {
            'lam': 2.6485978655163436,
            'beta': 0.2399381337197204
        },
        'adstock_params': {
            'alpha': 0.18859423763437405
        }
    }
}
.. code-block:: python

    self.format_recovered_transformation_parameters(quantile=0.5)
    # Output:
    {
        "x1": {
            "saturation_params": {
                "lam": 2.4761893929757077,
                "beta": 0.360226791880304,
            },
            "adstock_params": {"alpha": 0.39910387900504796},
        },
        "x2": {
            "saturation_params": {
                "lam": 2.6485978655163436,
                "beta": 0.2399381337197204,
            },
            "adstock_params": {"alpha": 0.18859423763437405},
        },
    }

"""
# Retrieve channel names
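The nested dictionary returned by ``format_recovered_transformation_parameters`` can be flattened into a tidy DataFrame for easier inspection. The sketch below hard-codes (rounded) values from the docstring output above; the helper itself is not part of the MMM API.

.. code-block:: python

    import pandas as pd

    # Values copied (and rounded) from the example output above.
    params = {
        "x1": {
            "saturation_params": {"lam": 2.476, "beta": 0.360},
            "adstock_params": {"alpha": 0.399},
        },
        "x2": {
            "saturation_params": {"lam": 2.649, "beta": 0.240},
            "adstock_params": {"alpha": 0.189},
        },
    }

    rows = [
        {"channel": channel, "transform": transform, "parameter": name, "value": value}
        for channel, transforms in params.items()
        for transform, values in transforms.items()
        for name, value in values.items()
    ]
    print(pd.DataFrame(rows))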
69 changes: 45 additions & 24 deletions pymc_marketing/mmm/multidimensional.py
@@ -1013,11 +1013,14 @@ def forward_pass(

Examples
--------
>>> mmm = MMM(
date_column="date_week",
channel_columns=["channel_1", "channel_2"],
target_column="target",
)
.. code-block:: python

mmm = MMM(
date_column="date_week",
channel_columns=["channel_1", "channel_2"],
target_column="target",
)

"""
first, second = (
(self.adstock, self.saturation)
@@ -1055,13 +1058,16 @@ def get_scales_as_xarray(self) -> dict[str, xr.DataArray]:

Examples
--------
>>> mmm = MMM(
date_column="date_week",
channel_columns=["channel_1", "channel_2"],
target_column="target",
)
>>> mmm.build_model(X, y)
>>> mmm.get_scales_as_xarray()
.. code-block:: python

mmm = MMM(
date_column="date_week",
channel_columns=["channel_1", "channel_2"],
target_column="target",
)
mmm.build_model(X, y)
mmm.get_scales_as_xarray()

"""
if not hasattr(self, "scalers"):
raise ValueError(
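As a usage note for the example above: the returned dictionary maps variable names to ``xarray.DataArray`` objects holding the scaling factors, so it can be inspected directly. The snippet assumes ``mmm`` has been built as in the example.

.. code-block:: python

    # Illustrative only: print each scale returned by get_scales_as_xarray().
    scales = mmm.get_scales_as_xarray()
    for name, scale in scales.items():
        print(name, scale.values)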
@@ -1100,9 +1106,12 @@ def add_original_scale_contribution_variable(self, var: list[str]) -> None:

Examples
--------
>>> model.add_original_scale_contribution_variable(
>>> var=["channel_contribution", "total_media_contribution", "y"]
>>> )
.. code-block:: python

model.add_original_scale_contribution_variable(
var=["channel_contribution", "total_media_contribution", "y"]
)

"""
self._validate_model_was_built()
target_dims = self.scalers._target.dims
@@ -1695,8 +1704,11 @@ def sample_posterior_predictive(
self.idata, **sample_posterior_predictive_kwargs
)

if extend_idata:
self.idata.extend(post_pred, join="right") # type: ignore
if extend_idata and self.idata is not None:
self.idata.add_groups(
posterior_predictive=post_pred.posterior_predictive,
posterior_predictive_constant_data=post_pred.constant_data,
) # type: ignore

group = "posterior_predictive"
posterior_predictive_samples = az.extract(post_pred, group, combined=combined)
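For context on the change above: ``InferenceData.extend(other, join="right")`` copies every group of ``other`` into ``self.idata``, replacing groups that already exist, whereas ``add_groups`` attaches datasets under new group names. The snippet below is a minimal, self-contained sketch of ``add_groups`` with synthetic data; the group name mirrors the one used above, and the arrays are made up.

.. code-block:: python

    import arviz as az
    import numpy as np
    import xarray as xr

    # Synthetic stand-ins: a posterior with one variable and fake
    # posterior-predictive draws (2 chains, 50 draws, 3 observations).
    idata = az.from_dict(posterior={"mu": np.random.randn(2, 50)})
    post_pred = xr.Dataset(
        {"y": (("chain", "draw", "obs"), np.random.randn(2, 50, 3))}
    )

    # Attach the draws as a new group rather than merging into existing groups.
    idata.add_groups(posterior_predictive=post_pred)
    print(idata.groups())  # ['posterior', 'posterior_predictive']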
@@ -1723,11 +1735,14 @@ def sensitivity(self) -> SensitivityAnalysis:

Examples
--------
>>> mmm.sensitivity.run_sweep(
... var_names=["channel_1", "channel_2"],
... sweep_values=np.linspace(0.5, 2.0, 10),
... sweep_type="multiplicative",
... )
.. code-block:: python

mmm.sensitivity.run_sweep(
var_names=["channel_1", "channel_2"],
sweep_values=np.linspace(0.5, 2.0, 10),
sweep_type="multiplicative",
)

"""
# Provide the underlying PyMC model, the model's inference data, and dims
return SensitivityAnalysis(
@@ -2145,7 +2160,10 @@ def create_fit_data(

Examples
--------
>>> ds = mmm.create_fit_data(X, y)
.. code-block:: python

ds = mmm.create_fit_data(X, y)

"""
# --- Coerce X to DataFrame ---
if isinstance(X, xr.Dataset):
@@ -2234,7 +2252,10 @@ def build_from_idata(self, idata: az.InferenceData) -> None:

Examples
--------
>>> mmm.build_from_idata(idata)
.. code-block:: python

mmm.build_from_idata(idata)

"""
dataset = idata.fit_data.to_dataframe()
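A typical round trip for the example above, assuming the inference data was saved earlier with ``InferenceData.to_netcdf`` (the file name is hypothetical) and that ``mmm`` is an existing model instance.

.. code-block:: python

    import arviz as az

    # Hypothetical file produced by an earlier run via idata.to_netcdf(...).
    idata = az.from_netcdf("mmm_results.nc")

    # Rebuild the model's internal state from the stored inference data.
    mmm.build_from_idata(idata)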
