1414import torchvision
1515
1616from autoPyTorch .constants import CLASSIFICATION_OUTPUTS , STRING_TO_OUTPUT_TYPES
17- from autoPyTorch .datasets .resampling_strategy import CrossValTypes , HoldoutTypes
17+ from autoPyTorch .datasets .resampling_strategy import CrossValTypes , HoldoutValTypes
1818from autoPyTorch .utils .common import FitRequirement
1919
2020BaseDatasetInputType = Union [Tuple [np .ndarray , np .ndarray ], Dataset ]
@@ -69,7 +69,7 @@ def __init__(
6969 dataset_name : Optional [str ] = None ,
7070 val_tensors : Optional [BaseDatasetInputType ] = None ,
7171 test_tensors : Optional [BaseDatasetInputType ] = None ,
72- resampling_strategy : Union [CrossValTypes , HoldoutTypes ] = HoldoutTypes . holdout ,
72+ resampling_strategy : Union [CrossValTypes , HoldoutValTypes ] = HoldoutValTypes . holdout_validation ,
7373 resampling_strategy_args : Optional [Dict [str , Any ]] = None ,
7474 seed : Optional [int ] = 42 ,
7575 train_transforms : Optional [torchvision .transforms .Compose ] = None ,
@@ -85,8 +85,8 @@ def __init__(
8585 validation data
8686 test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
8787 test data
88- resampling_strategy (Union[CrossValTypes, HoldoutTypes ]),
89- (default=HoldoutTypes.holdout ):
88+ resampling_strategy (Union[CrossValTypes, HoldoutValTypes ]),
89+ (default=HoldoutValTypes.holdout_validation ):
9090 strategy to split the training data.
9191 resampling_strategy_args (Optional[Dict[str, Any]]):
9292 arguments required for the chosen resampling strategy.
@@ -196,7 +196,7 @@ def _get_indices(self) -> np.ndarray:
196196
197197 def _process_resampling_strategy_args (self ) -> None :
198198 if not any (isinstance (self .resampling_strategy , val_type )
199- for val_type in [HoldoutTypes , CrossValTypes ]):
199+ for val_type in [HoldoutValTypes , CrossValTypes ]):
200200 raise ValueError (f"resampling_strategy { self .resampling_strategy } is not supported." )
201201
202202 if self .resampling_strategy_args is not None and \
@@ -229,7 +229,7 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
229229
230230 labels_to_stratify = self .train_tensors [- 1 ] if self .is_stratify else None
231231
232- if isinstance (self .resampling_strategy , HoldoutTypes ):
232+ if isinstance (self .resampling_strategy , HoldoutValTypes ):
233233 val_share = self .resampling_strategy_args ['val_share' ]
234234
235235 return self .resampling_strategy (
0 commit comments