99
1010import multiprocessing
1111import re
12- from typing import Any , Callable , Dict , List , Optional
12+ from typing import Any , Callable
1313
1414import torch
1515from datasets import Dataset
2424
2525def get_processed_dataset (
2626 dataset_args : DatasetArguments ,
27- processor : Optional [ Processor ] = None ,
27+ processor : Processor | None = None ,
2828 do_oneshot : bool = False ,
2929 do_train : bool = True ,
30- ) -> Optional [ Dict [ str , Dataset ]] :
30+ ) -> dict [ str , Dataset ] | None :
3131 """
3232 Loads datasets for each flow based on dataset_args, stores a Dataset for each
3333 enabled flow in datasets
@@ -50,17 +50,22 @@ def get_processed_dataset(
5050
5151 def _get_split_name (inp_str ):
5252 # strip out split name, for ex train[60%:] -> train
53- match = re .match (r"(\w*)\[.*\]" , inp_str )
54- if match is not None :
55- return match .group (1 )
53+ split_name_match = re .match (r"(\w*)\[.*\]" , inp_str )
54+ if split_name_match is not None :
55+ return split_name_match .group (1 )
5656 return inp_str
5757
58- if splits is None :
59- splits = {"all" : None }
60- elif isinstance (splits , str ):
61- splits = {_get_split_name (splits ): splits }
62- elif isinstance (splits , List ):
63- splits = {_get_split_name (s ): s for s in splits }
58+ match splits :
59+ case None :
60+ splits = {"all" : None }
61+ case str ():
62+ splits = {_get_split_name (splits ): splits }
63+ case list ():
64+ splits = {_get_split_name (s ): s for s in splits }
65+ case dict ():
66+ pass
67+ case _:
68+ raise ValueError (f"Invalid splits type: { type (splits )} " )
6469
6570 # default to custom dataset if dataset provided isn't a string
6671 registry_id = (
@@ -121,10 +126,10 @@ def get_calibration_dataloader(
121126
122127def format_calibration_data (
123128 tokenized_dataset : Dataset ,
124- num_calibration_samples : Optional [ int ] = None ,
129+ num_calibration_samples : int | None = None ,
125130 do_shuffle : bool = True ,
126131 collate_fn : Callable = default_data_collator ,
127- ) -> List [torch .Tensor ]:
132+ ) -> list [torch .Tensor ]:
128133 """
129134 Creates a dataloader out of the calibration dataset split, trimming it to
130135 the desired number of calibration samples
@@ -172,10 +177,10 @@ def format_calibration_data(
172177
173178
174179def make_dataset_splits (
175- tokenized_datasets : Dict [str , Any ],
180+ tokenized_datasets : dict [str , Any ],
176181 do_oneshot : bool = True ,
177182 do_train : bool = False ,
178- ) -> Dict [str , Dataset ]:
183+ ) -> dict [str , Dataset ]:
179184 """
180185 Restructures the datasets dictionary based on what tasks will be run
181186 train
0 commit comments