1- from  enum  import  IntEnum 
2- from  typing  import  Any , Dict , List , Optional , Tuple , Union 
1+ from  enum  import  Enum 
2+ from  functools  import  partial 
3+ from  typing  import  List , NamedTuple , Optional , Tuple , Union 
34
45import  numpy  as  np 
56
1213    train_test_split 
1314)
1415
15- from  typing_extensions  import  Protocol 
16+ from  torch . utils . data  import  Dataset 
1617
1718
18- # Use callback protocol as workaround, since callable with function fields count 'self' as argument 
19- class  CrossValFunc (Protocol ):
20-     def  __call__ (self ,
21-                  random_state : np .random .RandomState ,
22-                  num_splits : int ,
23-                  indices : np .ndarray ,
24-                  stratify : Optional [Any ]) ->  List [Tuple [np .ndarray , np .ndarray ]]:
25-         ...
19+ class  _ResamplingStrategyArgs (NamedTuple ):
20+     val_share : float  =  0.33 
21+     num_splits : int  =  5 
22+     shuffle : bool  =  False 
23+     stratify : bool  =  False 
2624
2725
28- class  HoldOutFunc (Protocol ):
29-     def  __call__ (self , random_state : np .random .RandomState , val_share : float ,
30-                  indices : np .ndarray , stratify : Optional [Any ]
31-                  ) ->  Tuple [np .ndarray , np .ndarray ]:
32-         ...
33- 
34- 
35- class  CrossValTypes (IntEnum ):
36-     """The type of cross validation 
37- 
38-     This class is used to specify the cross validation function 
39-     and is not supposed to be instantiated. 
40- 
41-     Examples: This class is supposed to be used as follows 
42-     >>> cv_type = CrossValTypes.k_fold_cross_validation 
43-     >>> print(cv_type.name) 
44- 
45-     k_fold_cross_validation 
46- 
47-     >>> for cross_val_type in CrossValTypes: 
48-             print(cross_val_type.name, cross_val_type.value) 
49- 
50-     stratified_k_fold_cross_validation 1 
51-     k_fold_cross_validation 2 
52-     stratified_shuffle_split_cross_validation 3 
53-     shuffle_split_cross_validation 4 
54-     time_series_cross_validation 5 
55-     """ 
56-     stratified_k_fold_cross_validation  =  1 
57-     k_fold_cross_validation  =  2 
58-     stratified_shuffle_split_cross_validation  =  3 
59-     shuffle_split_cross_validation  =  4 
60-     time_series_cross_validation  =  5 
61- 
62-     def  is_stratified (self ) ->  bool :
63-         stratified  =  [self .stratified_k_fold_cross_validation ,
64-                       self .stratified_shuffle_split_cross_validation ]
65-         return  getattr (self , self .name ) in  stratified 
66- 
67- 
68- class  HoldoutValTypes (IntEnum ):
69-     """TODO: change to enum using functools.partial""" 
70-     """The type of hold out validation (refer to CrossValTypes' doc-string)""" 
71-     holdout_validation  =  6 
72-     stratified_holdout_validation  =  7 
73- 
74-     def  is_stratified (self ) ->  bool :
75-         stratified  =  [self .stratified_holdout_validation ]
76-         return  getattr (self , self .name ) in  stratified 
77- 
78- 
79- # TODO: replace it with another way 
80- RESAMPLING_STRATEGIES  =  [CrossValTypes , HoldoutValTypes ]
81- 
82- DEFAULT_RESAMPLING_PARAMETERS  =  {
83-     HoldoutValTypes .holdout_validation : {
84-         'val_share' : 0.33 ,
85-     },
86-     HoldoutValTypes .stratified_holdout_validation : {
87-         'val_share' : 0.33 ,
88-     },
89-     CrossValTypes .k_fold_cross_validation : {
90-         'num_splits' : 5 ,
91-     },
92-     CrossValTypes .stratified_k_fold_cross_validation : {
93-         'num_splits' : 5 ,
94-     },
95-     CrossValTypes .shuffle_split_cross_validation : {
96-         'num_splits' : 5 ,
97-     },
98-     CrossValTypes .time_series_cross_validation : {
99-         'num_splits' : 5 ,
100-     },
101- }  # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] 
102- 
103- 
104- class  HoldOutFuncs ():
26+ class  HoldoutFuncs ():
10527    @staticmethod  
106-     def  holdout_validation (random_state : np .random .RandomState ,
107-                            val_share : float ,
108-                            indices : np .ndarray ,
109-                            ** kwargs : Any 
110-                            ) ->  Tuple [np .ndarray , np .ndarray ]:
111-         shuffle  =  kwargs .get ('shuffle' , True )
112-         train , val  =  train_test_split (indices , test_size = val_share ,
113-                                       shuffle = shuffle ,
114-                                       random_state = random_state  if  shuffle  else  None ,
115-                                       )
28+     def  holdout_validation (
29+         random_state : np .random .RandomState ,
30+         val_share : float ,
31+         indices : np .ndarray ,
32+         shuffle : bool  =  False ,
33+         labels_to_stratify : Optional [Union [Tuple [np .ndarray , np .ndarray ], Dataset ]] =  None 
34+     ):
35+ 
36+         train , val  =  train_test_split (
37+             indices , test_size = val_share , shuffle = shuffle ,
38+             random_state = random_state  if  shuffle  else  None ,
39+             stratify = labels_to_stratify 
40+         )
11641        return  train , val 
11742
118-     @staticmethod  
119-     def  stratified_holdout_validation (random_state : np .random .RandomState ,
120-                                       val_share : float ,
121-                                       indices : np .ndarray ,
122-                                       ** kwargs : Any 
123-                                       ) ->  Tuple [np .ndarray , np .ndarray ]:
124-         train , val  =  train_test_split (indices , test_size = val_share , shuffle = True , stratify = kwargs ["stratify" ],
125-                                       random_state = random_state )
126-         return  train , val 
127- 
128-     @classmethod  
129-     def  get_holdout_validators (cls , * holdout_val_types : HoldoutValTypes ) ->  Dict [str , HoldOutFunc ]:
130- 
131-         holdout_validators  =  {
132-             holdout_val_type .name : getattr (cls , holdout_val_type .name )
133-             for  holdout_val_type  in  holdout_val_types 
134-         }
135-         return  holdout_validators 
136- 
13743
13844class  CrossValFuncs ():
139-     @staticmethod  
140-     def  shuffle_split_cross_validation (random_state : np .random .RandomState ,
141-                                        num_splits : int ,
142-                                        indices : np .ndarray ,
143-                                        ** kwargs : Any 
144-                                        ) ->  List [Tuple [np .ndarray , np .ndarray ]]:
145-         cv  =  ShuffleSplit (n_splits = num_splits , random_state = random_state )
146-         splits  =  list (cv .split (indices ))
147-         return  splits 
148- 
149-     @staticmethod  
150-     def  stratified_shuffle_split_cross_validation (random_state : np .random .RandomState ,
151-                                                   num_splits : int ,
152-                                                   indices : np .ndarray ,
153-                                                   ** kwargs : Any 
154-                                                   ) ->  List [Tuple [np .ndarray , np .ndarray ]]:
155-         cv  =  StratifiedShuffleSplit (n_splits = num_splits , random_state = random_state )
156-         splits  =  list (cv .split (indices , kwargs ["stratify" ]))
157-         return  splits 
158- 
159-     @staticmethod  
160-     def  stratified_k_fold_cross_validation (random_state : np .random .RandomState ,
161-                                            num_splits : int ,
162-                                            indices : np .ndarray ,
163-                                            ** kwargs : Any 
164-                                            ) ->  List [Tuple [np .ndarray , np .ndarray ]]:
165-         cv  =  StratifiedKFold (n_splits = num_splits , random_state = random_state )
166-         splits  =  list (cv .split (indices , kwargs ["stratify" ]))
167-         return  splits 
45+     # (shuffle, is_stratify) -> split_fn 
46+     _args2split_fn  =  {
47+         (True , True ): StratifiedShuffleSplit ,
48+         (True , False ): ShuffleSplit ,
49+         (False , True ): StratifiedKFold ,
50+         (False , False ): KFold ,
51+     }
16852
16953    @staticmethod  
170-     def  k_fold_cross_validation (random_state : np .random .RandomState ,
171-                                 num_splits : int ,
172-                                 indices : np .ndarray ,
173-                                 ** kwargs : Any 
174-                                 ) ->  List [Tuple [np .ndarray , np .ndarray ]]:
54+     def  k_fold_cross_validation (
55+         random_state : np .random .RandomState ,
56+         num_splits : int ,
57+         indices : np .ndarray ,
58+         shuffle : bool  =  False ,
59+         labels_to_stratify : Optional [Union [Tuple [np .ndarray , np .ndarray ], Dataset ]] =  None 
60+     ) ->  List [Tuple [np .ndarray , np .ndarray ]]:
17561        """ 
176-         Standard k fold cross validation. 
177- 
178-         Args: 
179-             indices (np.ndarray): array of indices to be split 
180-             num_splits (int): number of cross validation splits 
181- 
18262        Returns: 
18363            splits (List[Tuple[List, List]]): list of tuples of training and validation indices 
18464        """ 
185-         shuffle  =  kwargs .get ('shuffle' , True )
186-         cv  =  KFold (n_splits = num_splits , random_state = random_state  if  shuffle  else  None , shuffle = shuffle )
65+ 
66+         split_fn  =  CrossValFuncs ._args2split_fn [(shuffle , labels_to_stratify  is  not None )]
67+         cv  =  split_fn (n_splits = num_splits , random_state = random_state )
18768        splits  =  list (cv .split (indices ))
18869        return  splits 
18970
19071    @staticmethod  
191-     def  time_series_cross_validation (random_state : np .random .RandomState ,
192-                                      num_splits : int ,
193-                                      indices : np .ndarray ,
194-                                      ** kwargs : Any 
195-                                      ) ->  List [Tuple [np .ndarray , np .ndarray ]]:
72+     def  time_series (
73+         random_state : np .random .RandomState ,
74+         num_splits : int ,
75+         indices : np .ndarray ,
76+         shuffle : bool  =  False ,
77+         labels_to_stratify : Optional [Union [Tuple [np .ndarray , np .ndarray ], Dataset ]] =  None 
78+     ) ->  List [Tuple [np .ndarray , np .ndarray ]]:
19679        """ 
19780        Returns train and validation indices respecting the temporal ordering of the data. 
19881
@@ -215,10 +98,115 @@ def time_series_cross_validation(random_state: np.random.RandomState,
21598        splits  =  list (cv .split (indices ))
21699        return  splits 
217100
218-     @classmethod  
219-     def  get_cross_validators (cls , * cross_val_types : CrossValTypes ) ->  Dict [str , CrossValFunc ]:
220-         cross_validators  =  {
221-             cross_val_type .name : getattr (cls , cross_val_type .name )
222-             for  cross_val_type  in  cross_val_types 
223-         }
224-         return  cross_validators 
101+ 
102+ class  CrossValTypes (Enum ):
103+     """The type of cross validation 
104+ 
105+     This class is used to specify the cross validation function 
106+     and is not supposed to be instantiated. 
107+ 
108+     Examples: This class is supposed to be used as follows 
109+     >>> cv_type = CrossValTypes.k_fold_cross_validation 
110+     >>> print(cv_type.name) 
111+ 
112+     k_fold_cross_validation 
113+ 
114+     >>> for cross_val_type in CrossValTypes: 
115+             print(cross_val_type.name, cross_val_type.value) 
116+ 
117+     k_fold_cross_validation functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>) 
118+     time_series functools.partial(<function CrossValFuncs.time_series at ...>)
119+     """ 
120+     k_fold_cross_validation  =  partial (CrossValFuncs .k_fold_cross_validation )
121+     time_series  =  partial (CrossValFuncs .time_series )
122+ 
123+     def  __call__ (
124+         self ,
125+         random_state : np .random .RandomState ,
126+         indices : np .ndarray ,
127+         num_splits : int  =  5 ,
128+         shuffle : bool  =  False ,
129+         labels_to_stratify : Optional [Union [Tuple [np .ndarray , np .ndarray ], Dataset ]] =  None 
130+     ) ->  List [Tuple [np .ndarray , np .ndarray ]]:
131+         """ 
132+         Call the split function bound to this member with type-checked arguments.
133+ 
134+         Args: 
135+             random_state (np.random.RandomState): random number generator for reproducibility
136+             num_splits (int): The number of splits in cross validation
137+             indices (np.ndarray): The indices of data points in a dataset
138+             shuffle (bool): Whether to shuffle the indices
139+             labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]): 
140+                 The labels of the corresponding data points. It is used for the stratification. 
141+ 
142+         Returns: 
143+             splits (List[Tuple[np.ndarray, np.ndarray]]): 
144+                 splits[a split identifier][0: train, 1: val][a data point identifier] 
145+ 
146+         """ 
147+         return  self .value (
148+             random_state = random_state ,
149+             num_splits = num_splits ,
150+             indices = indices ,
151+             shuffle = shuffle ,
152+             labels_to_stratify = labels_to_stratify 
153+         )
154+ 
155+ 
156+ class  HoldoutValTypes (Enum ):
157+     """The type of holdout validation 
158+ 
159+     This class is used to specify the holdout validation function 
160+     and is not supposed to be instantiated. 
161+ 
162+     Examples: This class is supposed to be used as follows 
163+     >>> holdout_type = HoldoutValTypes.holdout
164+     >>> print(holdout_type.name)
165+ 
166+     holdout
167+ 
168+     >>> print(holdout_type.value)
169+ 
170+     functools.partial(<function HoldoutFuncs.holdout_validation at ...>)
171+ 
172+     >>> for holdout_type in HoldoutValTypes:
173+             print(holdout_type.name)
174+ 
175+     holdout
176+ 
177+     Additionally, HoldoutValTypes.<function> can be called directly. 
178+     """ 
179+ 
180+     holdout  =  partial (HoldoutFuncs .holdout_validation )
181+ 
182+     def  __call__ (
183+         self ,
184+         random_state : np .random .RandomState ,
185+         indices : np .ndarray ,
186+         val_share : float  =  0.33 ,
187+         shuffle : bool  =  False ,
188+         labels_to_stratify : Optional [Union [Tuple [np .ndarray , np .ndarray ], Dataset ]] =  None 
189+     ) ->  List [Tuple [np .ndarray , np .ndarray ]]:
190+         """ 
191+         Call the split function bound to this member with type-checked arguments.
192+ 
193+         Args: 
194+             random_state (np.random.RandomState): random number generator for reproducibility
195+             val_share (float): The ratio of validation dataset vs the given dataset
196+             indices (np.ndarray): The indices of data points in a dataset
197+             shuffle (bool): Whether to shuffle the indices
198+             labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]): 
199+                 The labels of the corresponding data points. It is used for the stratification. 
200+ 
201+         Returns: 
202+             split (Tuple[np.ndarray, np.ndarray]):
203+                 (train, val) index arrays, as returned by HoldoutFuncs.holdout_validation; NOTE(review): the return annotation says List[...] - confirm
204+ 
205+         """ 
206+         return  self .value (
207+             random_state = random_state ,
208+             val_share = val_share ,
209+             indices = indices ,
210+             shuffle = shuffle ,
211+             labels_to_stratify = labels_to_stratify 
212+         )
0 commit comments