1- from  enum  import  IntEnum 
1+ from  enum  import  Enum 
2+ from  functools  import  partial 
23from  typing  import  Any , Dict , List , Optional , Tuple , Union 
34
45import  numpy  as  np 
1718
1819# Use callback protocol as workaround, since callable with function fields count 'self' as argument 
class CrossValFunc(Protocol):
    """Call signature shared by the cross-validation split functions.

    A callable matching this protocol receives the number of splits, the
    array of sample indices and an optional stratification target, and
    returns a list with one ``(train_indices, val_indices)`` pair per split.

    TODO: This class is not required anymore, because CrossValTypes class
    does not require get_validators().
    """

    def __call__(self,
                 num_splits: int,
                 indices: np.ndarray,
                 stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
        ...
2527
2628
27- class  HoldOutFunc (Protocol ):
class HoldoutValFunc(Protocol):
    """Call signature shared by the hold-out split functions.

    A callable matching this protocol receives the validation share, the
    array of sample indices and an optional stratification target, and
    returns a single ``(train_indices, val_indices)`` pair.
    """

    def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
                 ) -> Tuple[np.ndarray, np.ndarray]:
        ...
3133
3234
33- class  CrossValTypes (IntEnum ):
34-     """The type of cross validation 
35- 
36-     This class is used to specify the cross validation function 
37-     and is not supposed to be instantiated. 
38- 
39-     Examples: This class is supposed to be used as follows 
40-     >>> cv_type = CrossValTypes.k_fold_cross_validation 
41-     >>> print(cv_type.name) 
42- 
43-     k_fold_cross_validation 
44- 
45-     >>> for cross_val_type in CrossValTypes: 
46-             print(cross_val_type.name, cross_val_type.value) 
47- 
48-     stratified_k_fold_cross_validation 1 
49-     k_fold_cross_validation 2 
50-     stratified_shuffle_split_cross_validation 3 
51-     shuffle_split_cross_validation 4 
52-     time_series_cross_validation 5 
53-     """ 
54-     stratified_k_fold_cross_validation  =  1 
55-     k_fold_cross_validation  =  2 
56-     stratified_shuffle_split_cross_validation  =  3 
57-     shuffle_split_cross_validation  =  4 
58-     time_series_cross_validation  =  5 
59- 
60-     def  is_stratified (self ) ->  bool :
61-         stratified  =  [self .stratified_k_fold_cross_validation ,
62-                       self .stratified_shuffle_split_cross_validation ]
63-         return  getattr (self , self .name ) in  stratified 
64- 
65- 
66- class  HoldoutValTypes (IntEnum ):
67-     """TODO: change to enum using functools.partial""" 
68-     """The type of hold out validation (refer to CrossValTypes' doc-string)""" 
69-     holdout_validation  =  6 
70-     stratified_holdout_validation  =  7 
71- 
72-     def  is_stratified (self ) ->  bool :
73-         stratified  =  [self .stratified_holdout_validation ]
74-         return  getattr (self , self .name ) in  stratified 
75- 
76- 
77- """TODO: deprecate soon""" 
78- RESAMPLING_STRATEGIES  =  [CrossValTypes , HoldoutValTypes ]
79- 
80- """TODO: deprecate soon""" 
81- DEFAULT_RESAMPLING_PARAMETERS  =  {
82-     HoldoutValTypes .holdout_validation : {
83-         'val_share' : 0.33 ,
84-     },
85-     HoldoutValTypes .stratified_holdout_validation : {
86-         'val_share' : 0.33 ,
87-     },
88-     CrossValTypes .k_fold_cross_validation : {
89-         'num_splits' : 3 ,
90-     },
91-     CrossValTypes .stratified_k_fold_cross_validation : {
92-         'num_splits' : 3 ,
93-     },
94-     CrossValTypes .shuffle_split_cross_validation : {
95-         'num_splits' : 3 ,
96-     },
97-     CrossValTypes .time_series_cross_validation : {
98-         'num_splits' : 3 ,
99-     },
100- }  # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] 
101- 
102- 
103- class  HoldOutFuncs ():
class HoldoutValFuncs():
    """Namespace of hold-out split functions.

    Every function takes ``(val_share, indices, stratify)`` and returns one
    ``(train_indices, val_indices)`` pair, so all of them can be called
    through the same interface.
    """

    @staticmethod
    def holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \
            -> Tuple[np.ndarray, np.ndarray]:
        """Split ``indices`` into a train part and a validation part without shuffling.

        Args:
            val_share: Passed to ``train_test_split`` as ``test_size``.
            indices: Sample indices to split.
            stratify: Accepted only to keep the signature uniform; ignored here.

        Returns:
            The ``(train, val)`` pair produced by ``train_test_split``.
        """
        train, val = train_test_split(indices, test_size=val_share, shuffle=False)
        return train, val

    @staticmethod
    def stratified_holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \
            -> Tuple[np.ndarray, np.ndarray]:
        """Split ``indices`` into train and validation parts, stratified by ``stratify``.

        Args:
            val_share: Passed to ``train_test_split`` as ``test_size``.
            indices: Sample indices to split.
            stratify: Target passed to ``train_test_split`` as ``stratify``.
                When it is None the call degrades to an unshuffled split.

        Returns:
            The ``(train, val)`` pair produced by ``train_test_split``.
        """
        # scikit-learn refuses ``stratify`` together with ``shuffle=False``
        # (it raises ValueError), so shuffling is enabled exactly when a
        # stratification target is supplied.
        train, val = train_test_split(indices, test_size=val_share,
                                      shuffle=stratify is not None, stratify=stratify)
        return train, val
11447
115-     @classmethod  
116-     def  get_holdout_validators (cls , * holdout_val_types : Tuple [HoldoutValTypes ]) ->  Dict [str , HoldOutFunc ]:
117- 
118-         holdout_validators  =  {
119-             holdout_val_type .name : getattr (cls , holdout_val_type .name )
120-             for  holdout_val_type  in  holdout_val_types 
121-         }
122-         return  holdout_validators 
123- 
12448
12549class  CrossValFuncs ():
12650    @staticmethod  
127-     def  shuffle_split_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs :  Any ) \
51+     def  shuffle_split_cross_validation (num_splits : int , indices : np .ndarray , stratify :  Optional [ Any ]  =   None ) \
12852            ->  List [Tuple [np .ndarray , np .ndarray ]]:
12953        cv  =  ShuffleSplit (n_splits = num_splits )
13054        splits  =  list (cv .split (indices ))
13155        return  splits 
13256
13357    @staticmethod  
134-     def  stratified_shuffle_split_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
58+     def  stratified_shuffle_split_cross_validation (num_splits : int , indices : np .ndarray ,
59+                                                   stratify : Optional [Any ] =  None ) \
13560            ->  List [Tuple [np .ndarray , np .ndarray ]]:
13661        cv  =  StratifiedShuffleSplit (n_splits = num_splits )
137-         splits  =  list (cv .split (indices , kwargs [ " stratify" ] ))
62+         splits  =  list (cv .split (indices , stratify ))
13863        return  splits 
13964
14065    @staticmethod  
141-     def  stratified_k_fold_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs :  Any ) \
66+     def  stratified_k_fold_cross_validation (num_splits : int , indices : np .ndarray , stratify :  Optional [ Any ]  =   None ) \
14267            ->  List [Tuple [np .ndarray , np .ndarray ]]:
14368        cv  =  StratifiedKFold (n_splits = num_splits )
144-         splits  =  list (cv .split (indices , kwargs [ " stratify" ] ))
69+         splits  =  list (cv .split (indices , stratify ))
14570        return  splits 
14671
14772    @staticmethod  
148-     def  k_fold_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs :  Any ) \
73+     def  k_fold_cross_validation (num_splits : int , indices : np .ndarray , stratify :  Optional [ Any ]  =   None ) \
14974            ->  List [Tuple [np .ndarray , np .ndarray ]]:
15075        """ 
15176        Standard k fold cross validation. 
@@ -159,7 +84,7 @@ def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any)
15984        return  splits 
16085
16186    @staticmethod  
162-     def  time_series_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs :  Any ) \
87+     def  time_series_cross_validation (num_splits : int , indices : np .ndarray , stratify :  Optional [ Any ]  =   None ) \
16388            ->  List [Tuple [np .ndarray , np .ndarray ]]:
16489        """ 
16590        Returns train and validation indices respecting the temporal ordering of the data. 
@@ -176,10 +101,96 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs:
176101        splits  =  list (cv .split (indices ))
177102        return  splits 
178103
179-     @classmethod  
180-     def  get_cross_validators (cls , * cross_val_types : CrossValTypes ) ->  Dict [str , CrossValFunc ]:
181-         cross_validators  =  {
182-             cross_val_type .name : getattr (cls , cross_val_type .name )
183-             for  cross_val_type  in  cross_val_types 
184-         }
185-         return  cross_validators 
104+ 
class CrossValTypes(Enum):
    """The type of cross validation

    This class is used to specify the cross validation function
    and is not supposed to be instantiated.

    Each member's value is the matching ``CrossValFuncs`` static method
    wrapped in ``functools.partial``, and every member is itself callable
    (see ``__call__``).

    Examples: This class is supposed to be used as follows
    >>> cv_type = CrossValTypes.k_fold_cross_validation
    >>> print(cv_type.name)

    k_fold_cross_validation

    >>> print(cv_type.value)

    functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>)

    >>> for cross_val_type in CrossValTypes:
            print(cross_val_type.name)

    stratified_k_fold_cross_validation
    k_fold_cross_validation
    stratified_shuffle_split_cross_validation
    shuffle_split_cross_validation
    time_series_cross_validation

    Additionally, CrossValTypes.<function> can be called directly.
    """
    # The functions are wrapped in ``partial`` because a plain function
    # assigned in an Enum body is treated as a method, not as a member.
    # NOTE(review): newer Python versions warn that ``partial`` will become a
    # method descriptor; ``enum.member`` may be needed there -- confirm
    # against the supported Python range.
    stratified_k_fold_cross_validation = partial(CrossValFuncs.stratified_k_fold_cross_validation)
    k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation)
    stratified_shuffle_split_cross_validation = partial(CrossValFuncs.stratified_shuffle_split_cross_validation)
    shuffle_split_cross_validation = partial(CrossValFuncs.shuffle_split_cross_validation)
    time_series_cross_validation = partial(CrossValFuncs.time_series_cross_validation)

    def is_stratified(self) -> bool:
        """Return True when this member is one of the stratified strategies."""
        # Members are looked up on the class rather than on ``self``:
        # member-from-member access is not available on every Python version.
        return self in (CrossValTypes.stratified_k_fold_cross_validation,
                        CrossValTypes.stratified_shuffle_split_cross_validation)

    def __call__(self, num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None
                 ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Run the wrapped cross-validation function.

        Args:
            num_splits: Forwarded as ``num_splits``.
            indices: Forwarded as ``indices``.
            stratify: Forwarded as ``stratify``.

        Returns:
            Whatever the wrapped function returns: a list of
            ``(train, val)`` index pairs.

        TODO: test files
        """
        # The result must be returned; without ``return`` every call yields None.
        return self.value(num_splits=num_splits, indices=indices, stratify=stratify)

    @staticmethod
    def get_validators(*choices: 'CrossValTypes') -> Dict[str, CrossValFunc]:
        """Map each given member's name to its wrapped function.

        TODO: to be compatible, it is here now, but will be deprecated soon.
        """
        return {choice.name: choice.value for choice in choices}
152+ 
153+ 
class HoldoutValTypes(Enum):
    """The type of hold out validation

    Each member's value is the matching ``HoldoutValFuncs`` static method
    wrapped in ``functools.partial``, and every member is itself callable
    (see ``__call__``). Members expose ``name``, ``value`` and
    ``is_stratified()``.
    """
    # Wrapped in ``partial`` because a plain function assigned in an Enum
    # body is treated as a method, not as a member.
    holdout_validation = partial(HoldoutValFuncs.holdout_validation)
    stratified_holdout_validation = partial(HoldoutValFuncs.stratified_holdout_validation)

    def is_stratified(self) -> bool:
        """Return True when this member is the stratified hold-out strategy."""
        # Looked up on the class rather than on ``self``: member-from-member
        # access is not available on every Python version.
        return self is HoldoutValTypes.stratified_holdout_validation

    def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any] = None
                 ) -> Tuple[np.ndarray, np.ndarray]:
        """Run the wrapped hold-out function.

        Args:
            val_share: Forwarded as ``val_share``.
            indices: Forwarded as ``indices``.
            stratify: Forwarded as ``stratify``.

        Returns:
            Whatever the wrapped function returns: one ``(train, val)``
            index pair.
        """
        # The result must be returned; without ``return`` every call yields None.
        return self.value(val_share=val_share, indices=indices, stratify=stratify)

    @staticmethod
    def get_validators(*choices: 'HoldoutValTypes') -> Dict[str, HoldoutValFunc]:
        """Map each given member's name to its wrapped function.

        TODO: to be compatible, it is here now, but will be deprecated soon.
        """
        return {choice.name: choice.value for choice in choices}
171+ 
172+ 
# TODO: deprecate soon (Will rename CrossValTypes -> CrossValFunc)
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]

# TODO: deprecate soon
# Default keyword arguments for each resampling strategy, keyed by enum member.
# NOTE(review): CrossValTypes.stratified_shuffle_split_cross_validation has no
# entry here, so looking it up raises KeyError -- confirm whether that is
# intended before adding one.
DEFAULT_RESAMPLING_PARAMETERS = {
    HoldoutValTypes.holdout_validation: {
        'val_share': 0.33,
    },
    HoldoutValTypes.stratified_holdout_validation: {
        'val_share': 0.33,
    },
    CrossValTypes.k_fold_cross_validation: {
        'num_splits': 3,
    },
    CrossValTypes.stratified_k_fold_cross_validation: {
        'num_splits': 3,
    },
    CrossValTypes.shuffle_split_cross_validation: {
        'num_splits': 3,
    },
    CrossValTypes.time_series_cross_validation: {
        'num_splits': 3,
    },
}  # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
0 commit comments