 
 import multiprocessing
 import re
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable
 
 import torch
 from datasets import Dataset
 
 def get_processed_dataset(
     dataset_args: DatasetArguments,
-    processor: Optional[Processor] = None,
+    processor: Processor | None = None,
     do_oneshot: bool = False,
     do_train: bool = True,
-) -> Optional[Dict[str, Dataset]]:
+) -> dict[str, Dataset] | None:
     """
     Loads datasets for each flow based on dataset_args, stores a Dataset for each
     enabled flow in datasets
@@ -50,17 +50,22 @@ def get_processed_dataset(
 
     def _get_split_name(inp_str):
         # strip out split name, for ex train[60%:] -> train
-        match = re.match(r"(\w*)\[.*\]", inp_str)
-        if match is not None:
-            return match.group(1)
+        split_name_match = re.match(r"(\w*)\[.*\]", inp_str)
+        if split_name_match is not None:
+            return split_name_match.group(1)
         return inp_str
 
-    if splits is None:
-        splits = {"all": None}
-    elif isinstance(splits, str):
-        splits = {_get_split_name(splits): splits}
-    elif isinstance(splits, List):
-        splits = {_get_split_name(s): s for s in splits}
+    match splits:
+        case None:
+            splits = {"all": None}
+        case str():
+            splits = {_get_split_name(splits): splits}
+        case list():
+            splits = {_get_split_name(s): s for s in splits}
+        case dict():
+            pass
+        case _:
+            raise ValueError(f"Invalid splits type: {type(splits)}")
 
     # default to custom dataset if dataset provided isn't a string
     registry_id = (
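The `match` statement above replaces the old `if`/`elif` chain and adds explicit handling for `dict` inputs plus a loud failure for anything else; the old chain silently passed unsupported types through. A self-contained sketch of the same normalization, runnable outside this file (`normalize_splits` is a hypothetical name):

```python
import re

def _get_split_name(inp_str: str) -> str:
    # strip out split name, for ex train[60%:] -> train
    split_name_match = re.match(r"(\w*)\[.*\]", inp_str)
    if split_name_match is not None:
        return split_name_match.group(1)
    return inp_str

def normalize_splits(splits) -> dict:
    # Every accepted input shape is normalized to {split_name: split_spec};
    # unsupported types now raise instead of passing through untouched.
    match splits:
        case None:
            return {"all": None}
        case str():
            return {_get_split_name(splits): splits}
        case list():
            return {_get_split_name(s): s for s in splits}
        case dict():
            return splits
        case _:
            raise ValueError(f"Invalid splits type: {type(splits)}")

assert normalize_splits("train[60%:]") == {"train": "train[60%:]"}
assert normalize_splits(["train", "test[:128]"]) == {"train": "train", "test": "test[:128]"}
assert normalize_splits(None) == {"all": None}
```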
@@ -121,10 +126,10 @@ def get_calibration_dataloader(
 
 def format_calibration_data(
     tokenized_dataset: Dataset,
-    num_calibration_samples: Optional[int] = None,
+    num_calibration_samples: int | None = None,
     do_shuffle: bool = True,
     collate_fn: Callable = default_data_collator,
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
     """
     Creates a dataloader out of the calibration dataset split, trimming it to
     the desired number of calibration samples
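Per its docstring, `format_calibration_data` builds a dataloader from the calibration split, trimmed to `num_calibration_samples`. The function body is not shown in this diff, so the following is only a sketch of that contract under assumed behavior (shuffle, trim, wrap in a `DataLoader`; the real function is annotated `list[torch.Tensor]`, so its exact return shape may differ):

```python
from datasets import Dataset
from torch.utils.data import DataLoader

def format_calibration_data_sketch(
    tokenized_dataset: Dataset,
    num_calibration_samples: int | None = None,
    do_shuffle: bool = True,
    collate_fn=None,
) -> DataLoader:
    # Assumed behavior: shuffle first (if requested) so the trimmed
    # calibration subset is representative of the full split.
    if num_calibration_samples is not None:
        if do_shuffle:
            tokenized_dataset = tokenized_dataset.shuffle()
        samples = min(len(tokenized_dataset), num_calibration_samples)
        tokenized_dataset = tokenized_dataset.select(range(samples))
    return DataLoader(tokenized_dataset, batch_size=1, collate_fn=collate_fn)
```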
@@ -172,10 +177,10 @@ def format_calibration_data(
 
 
 def make_dataset_splits(
-    tokenized_datasets: Dict[str, Any],
+    tokenized_datasets: dict[str, Any],
     do_oneshot: bool = True,
     do_train: bool = False,
-) -> Dict[str, Dataset]:
+) -> dict[str, Dataset]:
     """
     Restructures the datasets dictionary based on what tasks will be run
     train
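The diff cuts off mid-docstring here, but the signature shows the same typing modernization. A hedged sketch of what "restructures the datasets dictionary based on what tasks will be run" plausibly means; the split-selection logic below is an assumption, not the file's actual body (the `"all"` key is grounded in the earlier hunk, where `splits = {"all": None}` is the default):

```python
from typing import Any

def make_dataset_splits_sketch(
    tokenized_datasets: dict[str, Any],
    do_oneshot: bool = True,
    do_train: bool = False,
) -> dict[str, Any]:
    # Assumed behavior: keep only the splits the enabled flows need,
    # falling back to a catch-all "all" split when one exists.
    datasets: dict[str, Any] = {}
    if do_train:
        datasets["train"] = tokenized_datasets.get(
            "train", tokenized_datasets.get("all")
        )
    if do_oneshot:
        datasets["calibration"] = tokenized_datasets.get(
            "calibration", tokenized_datasets.get("all")
        )
    return datasets
```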