@@ -302,17 +302,19 @@ def plot(self):
302302 fig , ax = plt .subplots ()
303303
304304 # Plot raw data
305- sns .lineplot (
306- self .data ,
307- x = self .time_variable_name ,
308- y = self .outcome_variable_name ,
309- hue = self .group_variable_name ,
310- units = "unit" ,
311- estimator = None ,
312- alpha = 0.25 ,
313- ax = ax ,
314- )
305+ # NOTE: This will not work when there is just ONE unit in each group
306+ # sns.lineplot(
307+ # self.data,
308+ # x=self.time_variable_name,
309+ # y=self.outcome_variable_name,
310+ # hue=self.group_variable_name,
311+ # # units="unit",
312+ # estimator=None,
313+ # alpha=0.25,
314+ # ax=ax,
315+ # )
315316 # Plot model fit to control group
317+ # NOTE: This will not work when there is just ONE unit in each group
316318 parts = ax .violinplot (
317319 az .extract (
318320 self .y_pred_control , group = "posterior_predictive" , var_names = "mu"
@@ -328,6 +330,7 @@ def plot(self):
328330 pc .set_alpha (0.5 )
329331
330332 # Plot model fit to treatment group
333+ # NOTE: This will not work when there is just ONE unit in each group
331334 parts = ax .violinplot (
332335 az .extract (
333336 self .y_pred_treatment , group = "posterior_predictive" , var_names = "mu"
@@ -337,18 +340,19 @@ def plot(self):
337340 showmedians = False ,
338341 widths = 0.2 ,
339342 )
340- # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
341- parts = ax .violinplot (
342- az .extract (
343- self .y_pred_counterfactual ,
344- group = "posterior_predictive" ,
345- var_names = "mu" ,
346- ).values .T ,
347- positions = self .x_pred_counterfactual [self .time_variable_name ].values ,
348- showmeans = False ,
349- showmedians = False ,
350- widths = 0.2 ,
351- )
343+ # # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
344+ # # NOTE: This will not work when there is just ONE unit in each group
345+ # parts = ax.violinplot(
346+ # az.extract(
347+ # self.y_pred_counterfactual,
348+ # group="posterior_predictive",
349+ # var_names="mu",
350+ # ).values.T,
351+ # positions=self.x_pred_counterfactual[self.time_variable_name].values,
352+ # showmeans=False,
353+ # showmedians=False,
354+ # widths=0.2,
355+ # )
352356 # arrow to label the causal impact
353357 y_pred_treatment = (
354358 self .y_pred_treatment ["posterior_predictive" ]
@@ -378,9 +382,9 @@ def plot(self):
378382 )
379383 # formatting
380384 ax .set (
381- xlim = [- 0.15 , 1.25 ],
382- xticks = [ 0 , 1 ] ,
383- xticklabels = ["pre" , "post" ],
385+ # xlim=[-0.15, 1.25],
386+ xticks = self . x_pred_treatment [ self . time_variable_name ]. values ,
387+ # xticklabels=["pre", "post"],
384388 title = self ._causal_impact_summary_stat (),
385389 )
386390 ax .legend (fontsize = LEGEND_FONT_SIZE )
0 commit comments