@@ -176,44 +176,3 @@ def policy_gen(Tpre, X, period):
176176        return  self ._gen_data_with_policy (n_units , policy_gen , random_seed = random_seed )
177177
178178
179- # Auxiliary function for adding xticks and vertical lines when plotting results 
180- # for dynamic dml vs ground truth parameters. 
181- def  add_vlines (n_periods , n_treatments , hetero_inds ):
182-     locs , labels  =  plt .xticks ([], [])
183-     locs  +=  [-  .501  +  (len (hetero_inds ) +  1 ) /  2 ]
184-     labels  +=  ["\n \n $\\ tau_{{{}}}$" .format (0 )]
185-     locs  +=  [qx  for  qx  in  np .arange (len (hetero_inds ) +  1 )]
186-     labels  +=  ["$1$" ] +  ["$x_{{{}}}$" .format (qx ) for  qx  in  hetero_inds ]
187-     for  q  in  np .arange (1 , n_treatments ):
188-         plt .axvline (x = q  *  (len (hetero_inds ) +  1 ) -  .5 ,
189-                     linestyle = '--' , color = 'red' , alpha = .2 )
190-         locs  +=  [q  *  (len (hetero_inds ) +  1 ) -  .501  +  (len (hetero_inds ) +  1 ) /  2 ]
191-         labels  +=  ["\n \n $\\ tau_{{{}}}$" .format (q )]
192-         locs  +=  [(q  *  (len (hetero_inds ) +  1 ) +  qx )
193-                  for  qx  in  np .arange (len (hetero_inds ) +  1 )]
194-         labels  +=  ["$1$" ] +  ["$x_{{{}}}$" .format (qx ) for  qx  in  hetero_inds ]
195-     locs  +=  [-  .501  +  (len (hetero_inds ) +  1 ) *  n_treatments  /  2 ]
196-     labels  +=  ["\n \n \n \n $\\ theta_{{{}}}$" .format (0 )]
197-     for  t  in  np .arange (1 , n_periods ):
198-         plt .axvline (x = t  *  (len (hetero_inds ) +  1 ) * 
199-                     n_treatments  -  .5 , linestyle = '-' , alpha = .6 )
200-         locs  +=  [t  *  (len (hetero_inds ) +  1 ) *  n_treatments  -  .501  + 
201-                  (len (hetero_inds ) +  1 ) *  n_treatments  /  2 ]
202-         labels  +=  ["\n \n \n \n $\\ theta_{{{}}}$" .format (t )]
203-         locs  +=  [t  *  (len (hetero_inds ) +  1 ) * 
204-                  n_treatments  -  .501  +  (len (hetero_inds ) +  1 ) /  2 ]
205-         labels  +=  ["\n \n $\\ tau_{{{}}}$" .format (0 )]
206-         locs  +=  [t  *  (len (hetero_inds ) +  1 ) *  n_treatments  + 
207-                  qx  for  qx  in  np .arange (len (hetero_inds ) +  1 )]
208-         labels  +=  ["$1$" ] +  ["$x_{{{}}}$" .format (qx ) for  qx  in  hetero_inds ]
209-         for  q  in  np .arange (1 , n_treatments ):
210-             plt .axvline (x = t  *  (len (hetero_inds ) +  1 ) *  n_treatments  +  q  *  (len (hetero_inds ) +  1 ) -  .5 ,
211-                         linestyle = '--' , color = 'red' , alpha = .2 )
212-             locs  +=  [t  *  (len (hetero_inds ) +  1 ) *  n_treatments  +  q  * 
213-                      (len (hetero_inds ) +  1 ) -  .501  +  (len (hetero_inds ) +  1 ) /  2 ]
214-             labels  +=  ["\n \n $\\ tau_{{{}}}$" .format (q )]
215-             locs  +=  [t  *  (len (hetero_inds ) +  1 ) *  n_treatments  +  (q  *  (len (hetero_inds ) +  1 ) +  qx )
216-                      for  qx  in  np .arange (len (hetero_inds ) +  1 )]
217-             labels  +=  ["$1$" ] +  ["$x_{{{}}}$" .format (qx ) for  qx  in  hetero_inds ]
218-     plt .xticks (locs , labels )
219-     plt .tight_layout ()
0 commit comments