@@ -541,48 +541,79 @@ class InterventionTimeEstimator(PyMCModel):
541541        ...     t, 
542542        ...     y, 
543543        ...     coords, 
544-         ...     effect=[ "impulse"]  
544+         ...     priors={ "impulse":[]}  
545545        ... ) 
546546        Inference data... 
547547    """ 
548548
549-     def  build_model (self , t , y , coords , effect ,  span ,  grain_season ):
549+     def  build_model (self , t , y , coords , time_range ,  grain_season ,  priors ):
550550        """ 
551551        Defines the PyMC model 
552552
553553        :param t: An array of values representing the time over which y is spread 
554554        :param y: An array of values representing our outcome y 
555-         :param coords: A dictionary with the coordinate names for our instruments 
555+         :param coords: An optional dictionary with the coordinate names for our instruments. 
556+             In particular, used to determine the number of seasons. 
557+         :param time_range: An optional tuple providing a specific time_range where the 
558+             intervention effect should have taken place. 
559+         :param priors: An optional dictionary of priors for the parameters of the 
560+             different distributions. 
561+             :code:`priors = {"alpha":[0, 5], "beta":[0,2], "level":[5, 5], "impulse":[1, 2 ,3]}` 
556562        """ 
557563
558564        with  self :
559565            self .add_coords (coords )
560566
561-             if  span  is  None :
562-                 span  =  (t .min (), t .max ())
567+             if  time_range  is  None :
568+                 time_range  =  (t .min (), t .max ())
563569
564570            # --- Priors --- 
565-             switchpoint  =  pm .Uniform ("switchpoint" , lower = span [0 ], upper = span [1 ])
566-             alpha  =  pm .Normal (name = "alpha" , mu = 0 , sigma = 10 )
567-             beta  =  pm .Normal (name = "beta" , mu = 0 , sigma = 10 )
571+             switchpoint  =  pm .Uniform (
572+                 "switchpoint" , lower = time_range [0 ], upper = time_range [1 ]
573+             )
574+             alpha  =  pm .Normal (name = "alpha" , mu = 0 , sigma = 50 )
575+             beta  =  pm .Normal (name = "beta" , mu = 0 , sigma = 50 )
568576            seasons  =  0 
569577            if  "seasons"  in  coords  and  len (coords ["seasons" ]) >  0 :
570578                season_idx  =  np .arange (len (y )) //  grain_season  %  len (coords ["seasons" ])
571-                 season_effect  =  pm .Normal ("season" , mu = 0 , sigma = 1 , dims = "seasons" )
572-                 seasons  =  season_effect [season_idx ]
579+                 seasons_effect  =  pm .Normal (
580+                     "seasons_effect" , mu = 0 , sigma = 50 , dims = "seasons" 
581+                 )
582+                 seasons  =  seasons_effect [season_idx ]
573583
574584            # --- Intervention effect --- 
575585            level  =  trend  =  impulse  =  0 
576586
577-             if  "level"  in  effect :
578-                 level  =  pm .Normal ("level" , mu = 0 , sigma = 10 )
579- 
580-             if  "trend"  in  effect :
581-                 trend  =  pm .Normal ("trend" , mu = 0 , sigma = 10 )
582- 
583-             if  "impulse"  in  effect :
584-                 impulse_amplitude  =  pm .Normal ("impulse_amplitude" , mu = 0 , sigma = 1 )
585-                 decay_rate  =  pm .HalfNormal ("decay_rate" , sigma = 1 )
587+             if  "level"  in  priors :
588+                 mu , sigma  =  (
589+                     (0 , 50 )
590+                     if  len (priors ["level" ]) !=  2 
591+                     else  (priors ["level" ][0 ], priors ["level" ][1 ])
592+                 )
593+                 level  =  pm .Normal (
594+                     "level" ,
595+                     mu = mu ,
596+                     sigma = sigma ,
597+                 )
598+             if  "trend"  in  priors :
599+                 mu , sigma  =  (
600+                     (0 , 50 )
601+                     if  len (priors ["trend" ]) !=  2 
602+                     else  (priors ["trend" ][0 ], priors ["trend" ][1 ])
603+                 )
604+                 trend  =  pm .Normal ("trend" , mu = mu , sigma = sigma )
605+             if  "impulse"  in  priors :
606+                 mu , sigma1 , sigma2  =  (
607+                     (0 , 50 , 50 )
608+                     if  len (priors ["impulse" ]) !=  3 
609+                     else  (
610+                         priors ["impulse" ][0 ],
611+                         priors ["impulse" ][1 ],
612+                         priors ["impulse" ][2 ],
613+                     )
614+                 )
615+                 impulse_amplitude  =  pm .Normal ("impulse_amplitude" , mu = mu , sigma = sigma1 )
616+                 decay_rate  =  pm .HalfNormal ("decay_rate" , sigma = sigma2 )
586617                impulse  =  impulse_amplitude  *  pm .math .exp (
587618                    - decay_rate  *  abs (t  -  switchpoint )
588619                )
@@ -597,16 +628,16 @@ def build_model(self, t, y, coords, effect, span, grain_season):
597628            )
598629            # Compute and store the the sum of the intervention and the time series 
599630            mu  =  pm .Deterministic ("mu" , mu_ts  +  weight  *  mu_in )
631+             sigma  =  pm .HalfNormal ("sigma" , 1 )
600632
601633            # --- Likelihood --- 
602-             pm .Normal ("y_hat" , mu = mu , sigma = 2 , observed = y )
634+             pm .Normal ("y_hat" , mu = mu , sigma = sigma , observed = y )
603635
604-     def  fit (self , t , y , coords , effect = [],  span = None , grain_season = 1 , n = 1000 ):
636+     def  fit (self , t , y , coords , time_range = None , grain_season = 1 ,  priors = {} , n = 1000 ):
605637        """ 
606638        Draw samples from posterior distribution 
607639        """ 
608-         self .sample_kwargs ["progressbar" ] =  False 
609-         self .build_model (t , y , coords , effect , span , grain_season )
640+         self .build_model (t , y , coords , time_range , grain_season , priors )
610641        with  self :
611-             self .idata  =  pm .sample (n , ** self .sample_kwargs )
642+             self .idata  =  pm .sample (n , progressbar = False ,  ** self .sample_kwargs )
612643        return  self .idata 
0 commit comments