2020from autoPyTorch .datasets .base_dataset import BaseDataset
2121from autoPyTorch .datasets .resampling_strategy import (
2222 CrossValTypes ,
23- HoldoutValTypes ,
23+ HoldoutTypes ,
2424)
2525
2626
@@ -44,13 +44,12 @@ class TabularDataset(BaseDataset):
4444 Y (Union[np.ndarray, pd.Series]): training data targets.
4545 X_test (Optional[Union[np.ndarray, pd.DataFrame]]): input testing data.
4646 Y_test (Optional[Union[np.ndarray, pd.DataFrame]]): testing data targets
47- resampling_strategy (Union[CrossValTypes, HoldoutValTypes ]),
48- (default=HoldoutValTypes.holdout_validation ):
47+ resampling_strategy (Union[CrossValTypes, HoldoutTypes ]),
48+ (default=HoldoutTypes.holdout ):
4949 strategy to split the training data.
50- resampling_strategy_args (Optional[Dict[str, Any]]): arguments
51- required for the chosen resampling strategy. If None, uses
52- the default values provided in DEFAULT_RESAMPLING_PARAMETERS
53- in ```datasets/resampling_strategy.py```.
50+ resampling_strategy_args (Optional[Dict[str, Any]]):
51+ arguments required for the chosen resampling strategy.
52+ The details are provided in autoPytorch/datasets/resampling_strategy.py
5453 shuffle: Whether to shuffle the data before performing splits
5554 seed (int), (default=1): seed to be used for reproducibility.
5655 train_transforms (Optional[torchvision.transforms.Compose]):
@@ -67,9 +66,8 @@ def __init__(self,
6766 Y : Union [np .ndarray , pd .Series ],
6867 X_test : Optional [Union [np .ndarray , pd .DataFrame ]] = None ,
6968 Y_test : Optional [Union [np .ndarray , pd .DataFrame ]] = None ,
70- resampling_strategy : Union [CrossValTypes , HoldoutValTypes ] = HoldoutValTypes . holdout_validation ,
69+ resampling_strategy : Union [CrossValTypes , HoldoutTypes ] = HoldoutTypes . holdout ,
7170 resampling_strategy_args : Optional [Dict [str , Any ]] = None ,
72- shuffle : Optional [bool ] = True ,
7371 seed : Optional [int ] = 42 ,
7472 train_transforms : Optional [torchvision .transforms .Compose ] = None ,
7573 val_transforms : Optional [torchvision .transforms .Compose ] = None ,
@@ -92,7 +90,7 @@ def __init__(self,
9290 self .num_features = validator .feature_validator .num_features
9391 self .categories = validator .feature_validator .categories
9492
95- super ().__init__ (train_tensors = (X , Y ), test_tensors = (X_test , Y_test ), shuffle = shuffle ,
93+ super ().__init__ (train_tensors = (X , Y ), test_tensors = (X_test , Y_test ),
9694 resampling_strategy = resampling_strategy ,
9795 resampling_strategy_args = resampling_strategy_args ,
9896 seed = seed , train_transforms = train_transforms ,
0 commit comments