
from smac.tae import StatusType

- from autoPyTorch.automl_common.common.utils.backend import Backend
- from autoPyTorch.constants import (
-     CLASSIFICATION_TASKS,
-     MULTICLASSMULTIOUTPUT,
+ from autoPyTorch.datasets.resampling_strategy import (
+     CrossValTypes,
+     NoResamplingStrategyTypes,
+     check_resampling_strategy
)
- from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
from autoPyTorch.evaluation.abstract_evaluator import (
    AbstractEvaluator,
    EvaluationResults,
from autoPyTorch.evaluation.abstract_evaluator import EvaluatorParams, FixedPipelineParams
from autoPyTorch.utils.common import dict_repr, subsampler

- __all__ = ['TrainEvaluator', 'eval_train_function']
+ __all__ = ['Evaluator', 'eval_fn']
+

class _CrossValidationResultsManager:
    def __init__(self, num_folds: int):
@@ -83,15 +83,13 @@ def get_result_dict(self) -> Dict[str, Any]:
        )


- class TrainEvaluator(AbstractEvaluator):
+ class Evaluator(AbstractEvaluator):
    """
    This class builds a pipeline using the provided configuration.
    A pipeline implementing the provided configuration is fitted
    using the datamanager object retrieved from disc, via the backend.
    After the pipeline is fitted, it is saved to disc and the performance estimate
-     is communicated to the main process via a Queue. It is only compatible
-     with `CrossValTypes`, `HoldoutValTypes`, i.e., when the training data
-     is split and the validation set is used for SMBO optimisation.
+     is communicated to the main process via a Queue.

    Args:
        queue (Queue):
@@ -101,54 +99,27 @@ class TrainEvaluator(AbstractEvaluator):
            Fixed parameters for a pipeline
        evaluator_params (EvaluatorParams):
            The parameters for an evaluator.
+
+     Attributes:
+         train (bool):
+             Whether the training data is split and the validation set is used for SMBO optimisation.
+         cross_validation (bool):
+             Whether we use cross validation or not.
    """
-     def __init__(self, backend: Backend, queue: Queue,
-                  metric: autoPyTorchMetric,
-                  budget: float,
-                  configuration: Union[int, str, Configuration],
-                  budget_type: str = None,
-                  pipeline_config: Optional[Dict[str, Any]] = None,
-                  seed: int = 1,
-                  output_y_hat_optimization: bool = True,
-                  num_run: Optional[int] = None,
-                  include: Optional[Dict[str, Any]] = None,
-                  exclude: Optional[Dict[str, Any]] = None,
-                  disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-                  init_params: Optional[Dict[str, Any]] = None,
-                  logger_port: Optional[int] = None,
-                  keep_models: Optional[bool] = None,
-                  all_supported_metrics: bool = True,
-                  search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None) -> None:
-         super().__init__(
-             backend=backend,
-             queue=queue,
-             configuration=configuration,
-             metric=metric,
-             seed=seed,
-             output_y_hat_optimization=output_y_hat_optimization,
-             num_run=num_run,
-             include=include,
-             exclude=exclude,
-             disable_file_output=disable_file_output,
-             init_params=init_params,
-             budget=budget,
-             budget_type=budget_type,
-             logger_port=logger_port,
-             all_supported_metrics=all_supported_metrics,
-             pipeline_config=pipeline_config,
-             search_space_updates=search_space_updates
-         )
+     def __init__(self, queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams):
+         resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+         self.train = not isinstance(resampling_strategy, NoResamplingStrategyTypes)
+         self.cross_validation = isinstance(resampling_strategy, CrossValTypes)

-         if not isinstance(self.datamanager.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
-             resampling_strategy = self.datamanager.resampling_strategy
-             raise ValueError(
-                 f'resampling_strategy for TrainEvaluator must be in '
-                 f'(CrossValTypes, HoldoutValTypes), but got {resampling_strategy}'
-             )
+         if not self.train and fixed_pipeline_params.save_y_opt:
+             # TODO: Add a test to cover this case
+             # No resampling cannot be used for building ensembles; save_y_opt=False ensures it
+             fixed_pipeline_params = fixed_pipeline_params._replace(save_y_opt=False)
+
+         super().__init__(queue=queue, fixed_pipeline_params=fixed_pipeline_params, evaluator_params=evaluator_params)

-         self.splits = self.datamanager.splits
-         self.num_folds: int = len(self.splits)
-         self.logger.debug("Search space updates :{}".format(self.search_space_updates))
+         if self.train:
+             self.logger.debug("Search space updates: {}".format(self.fixed_pipeline_params.search_space_updates))

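The two attributes set at the top of `__init__` drive every branch below: `train` toggles between optimising on a held-out split and scoring on the test set, and `cross_validation` toggles fold aggregation. As a quick orientation, a minimal sketch of the mapping, using the strategy enums imported above (member names as in the development branch):

```python
from typing import Tuple

from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
    NoResamplingStrategyTypes,
)


def infer_modes(resampling_strategy) -> Tuple[bool, bool]:
    # Mirrors the flag logic in Evaluator.__init__ (illustrative only).
    train = not isinstance(resampling_strategy, NoResamplingStrategyTypes)
    cross_validation = isinstance(resampling_strategy, CrossValTypes)
    return train, cross_validation


# Holdout: one train/opt split, the opt fold feeds SMBO
assert infer_modes(HoldoutValTypes.holdout_validation) == (True, False)
# Cross-validation: several splits, results aggregated over folds
assert infer_modes(CrossValTypes.k_fold_cross_validation) == (True, True)
# No resampling: fit on all data, opt predictions come from the test set
assert infer_modes(NoResamplingStrategyTypes.no_resampling) == (False, False)
```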
    def _evaluate_on_split(self, split_id: int) -> EvaluationResults:
        """
@@ -177,7 +148,7 @@ def _evaluate_on_split(self, split_id: int) -> EvaluationResults:

        return EvaluationResults(
            pipeline=pipeline,
-             opt_loss=self._loss(labels=self.y_train[opt_split], preds=opt_pred),
+             opt_loss=self._loss(labels=self.y_train[opt_split] if self.train else self.y_test, preds=opt_pred),
            train_loss=self._loss(labels=self.y_train[train_split], preds=train_pred),
            opt_pred=opt_pred,
            valid_pred=valid_pred,
@@ -203,6 +174,7 @@ def _cross_validation(self) -> EvaluationResults:
            results = self._evaluate_on_split(split_id)

            self.pipelines[split_id] = results.pipeline
+             assert opt_split is not None  # mypy redefinition
            cv_results.update(split_id, results, len(train_split), len(opt_split))

        self.y_opt = np.concatenate([y_opt for y_opt in Y_opt if y_opt is not None])
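`cv_results.update()` receives the train and opt split sizes alongside each fold's results, which suggests per-fold losses are combined as a size-weighted average. A hypothetical `aggregate_fold_losses` helper sketching that idea (not the actual `_CrossValidationResultsManager` internals):

```python
from typing import Dict, List


def aggregate_fold_losses(fold_losses: List[Dict[str, float]],
                          fold_sizes: List[int]) -> Dict[str, float]:
    # Weight each fold's loss by the fraction of samples it covers.
    total = sum(fold_sizes)
    agg: Dict[str, float] = {}
    for losses, size in zip(fold_losses, fold_sizes):
        for name, value in losses.items():
            agg[name] = agg.get(name, 0.0) + value * size / total
    return agg


# Two folds covering 60 and 40 samples:
print(aggregate_fold_losses([{'log_loss': 0.10}, {'log_loss': 0.20}], [60, 40]))
# {'log_loss': 0.14}
```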
@@ -214,15 +186,16 @@ def evaluate_loss(self) -> None:
        if self.splits is None:
            raise ValueError(f"cannot fit pipeline {self.__class__.__name__} with datamanager.splits None")

-         if self.num_folds == 1:
+         if self.cross_validation:
+             results = self._cross_validation()
+         else:
            _, opt_split = self.splits[0]
            results = self._evaluate_on_split(split_id=0)
-             self.y_opt, self.pipelines[0] = self.y_train[opt_split], results.pipeline
-         else:
-             results = self._cross_validation()
+             self.pipelines[0] = results.pipeline
+             self.y_opt = self.y_train[opt_split] if self.train else self.y_test

        self.logger.debug(
-             f"In train evaluator.evaluate_loss, num_run: {self.num_run}, loss: {results.opt_loss},"
+             f"In evaluate_loss, num_run: {self.num_run}, loss: {results.opt_loss},"
            f" status: {results.status}\nadditional run info:\n{dict_repr(results.additional_run_info)}"
        )
        self.record_evaluation(results=results)
@@ -242,41 +215,23 @@ def _fit_and_evaluate_loss(

        kwargs = {'pipeline': pipeline, 'unique_train_labels': self.unique_train_labels[split_id]}
        train_pred = self.predict(subsampler(self.X_train, train_indices), **kwargs)
-         opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs)
-         valid_pred = self.predict(self.X_valid, **kwargs)
        test_pred = self.predict(self.X_test, **kwargs)
+         valid_pred = self.predict(self.X_valid, **kwargs)
+
+         # No resampling ==> evaluate on the test dataset
+         opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs) if self.train else test_pred

        assert train_pred is not None and opt_pred is not None  # mypy check
        return train_pred, opt_pred, valid_pred, test_pred

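`subsampler` (imported from `autoPyTorch.utils.common`) lets the same slicing call work on both ndarrays and DataFrames. A plausible stand-in, assuming it simply dispatches on the container type:

```python
from typing import Sequence, Union

import numpy as np
import pandas as pd


def subsampler_sketch(data: Union[np.ndarray, pd.DataFrame],
                      indices: Sequence[int]) -> Union[np.ndarray, pd.DataFrame]:
    # DataFrames need positional selection via .iloc; ndarrays index directly.
    return data.iloc[indices] if isinstance(data, pd.DataFrame) else data[indices]


X = pd.DataFrame({'a': [1, 2, 3, 4]})
print(subsampler_sketch(X, [0, 2]))             # rows 0 and 2 of the DataFrame
print(subsampler_sketch(X.to_numpy(), [0, 2]))  # same rows of the ndarray
```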
- # create closure for evaluating an algorithm
- def eval_train_function(
-     backend: Backend,
-     queue: Queue,
-     metric: autoPyTorchMetric,
-     budget: float,
-     config: Optional[Configuration],
-     seed: int,
-     output_y_hat_optimization: bool,
-     num_run: int,
-     include: Optional[Dict[str, Any]],
-     exclude: Optional[Dict[str, Any]],
-     disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-     pipeline_config: Optional[Dict[str, Any]] = None,
-     budget_type: str = None,
-     init_params: Optional[Dict[str, Any]] = None,
-     logger_port: Optional[int] = None,
-     all_supported_metrics: bool = True,
-     search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
-     instance: str = None,
- ) -> None:
+ def eval_fn(queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams) -> None:
    """
    This closure allows the communication between the TargetAlgorithmQuery and the
-     pipeline trainer (TrainEvaluator).
+     pipeline trainer (Evaluator).

    Fundamentally, smac calls the TargetAlgorithmQuery.run() method, which internally
-     builds a TrainEvaluator. The TrainEvaluator builds a pipeline, stores the output files
+     builds an Evaluator. The Evaluator builds a pipeline, stores the output files
    to disc via the backend, and puts the performance result of the run in the queue.

    Args:
@@ -288,7 +243,11 @@ def eval_train_function(
        evaluator_params (EvaluatorParams):
            The parameters for an evaluator.
    """
-     evaluator = TrainEvaluator(
+     resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+     check_resampling_strategy(resampling_strategy)
+
+     # NoResamplingStrategyTypes ==> test evaluator, otherwise ==> train evaluator
+     evaluator = Evaluator(
        queue=queue,
        evaluator_params=evaluator_params,
        fixed_pipeline_params=fixed_pipeline_params
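To make the queue hand-off concrete: `eval_fn` runs in a worker process, fits the pipeline, and puts the run's result on the queue for `TargetAlgorithmQuery.run()` to read back. A self-contained mock of that round trip (the result keys here are illustrative, not the exact record autoPyTorch emits):

```python
from multiprocessing import Process, Queue


def mock_eval_fn(queue: Queue) -> None:
    # Stand-in for eval_fn: the fit would happen here, then the result is queued.
    queue.put({'loss': 0.42, 'status': 'SUCCESS', 'additional_run_info': {}})


if __name__ == '__main__':
    queue = Queue()
    worker = Process(target=mock_eval_fn, args=(queue,))
    worker.start()
    result = queue.get(timeout=10)  # what the caller reads after the run
    worker.join()
    print(result['loss'], result['status'])
```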