File tree Expand file tree Collapse file tree 2 files changed +18
-3
lines changed Expand file tree Collapse file tree 2 files changed +18
-3
lines changed Original file line number Diff line number Diff line change @@ -282,21 +282,32 @@ def __getitem__(self, item: int) -> Batch:
282282 )
283283
284284
285- def numpy_collate_fn (samples : List [Batch ]) -> Batch :
285+ def numpy_collate_fn (
286+ samples : List [Batch ],
287+ auto_float : bool = True
288+ ) -> Batch :
286289 """
287290 Merges a list of ml4ht batch formatted data.
288291 Can be used as 'collate_fn` in torch.utils.data.DataLoader
289292 so that the torch data loader is compatible with tensorflow models
290293 """
291294 # construct correctly-shaped empty arrays for input and output of model
292295 in_batch_keys = list (samples [0 ][0 ])
296+ if auto_float :
297+ in_dtypes = {k : np .float32 for k in in_batch_keys }
298+ else :
299+ in_dtypes = {k : samples [0 ][0 ][k ].dtype for k in in_batch_keys }
293300 in_batch = {
294- k : np .empty ((len (samples ),) + samples [0 ][0 ][k ].shape , dtype = np . float32 )
301+ k : np .empty ((len (samples ),) + samples [0 ][0 ][k ].shape , dtype = in_dtypes [ k ] )
295302 for k in in_batch_keys
296303 }
297304 out_batch_keys = list (samples [0 ][1 ])
305+ if auto_float :
306+ out_dtypes = {k : np .float32 for k in out_batch_keys }
307+ else :
308+ out_dtypes = {k : samples [0 ][1 ][k ].dtype for k in out_batch_keys }
298309 out_batch = {
299- k : np .empty ((len (samples ),) + samples [0 ][1 ][k ].shape , dtype = np . float32 )
310+ k : np .empty ((len (samples ),) + samples [0 ][1 ][k ].shape , dtype = out_dtypes [ k ] )
300311 for k in out_batch_keys
301312 }
302313 # fill in the values of the input and output arrays
Original file line number Diff line number Diff line change @@ -29,10 +29,12 @@ def __init__(
2929 input_data_descriptions : List [DataDescription ],
3030 output_data_descriptions : List [DataDescription ],
3131 option_picker : OptionPicker = None ,
32+ restricted_sample_id_idx = None ,
3233 ):
3334 self .input_data_descriptions = input_data_descriptions
3435 self .output_data_descriptions = output_data_descriptions
3536 self .option_picker = option_picker or self ._default_option_picker
37+ self .restricted_sample_id_idx = restricted_sample_id_idx
3638
3739 @staticmethod
3840 def _default_option_picker (
@@ -67,6 +69,8 @@ def __call__(self, sample_id: SampleID) -> Batch:
6769 sample_id ,
6870 self .input_data_descriptions + self .output_data_descriptions ,
6971 )
72+ if self .restricted_sample_id_idx is not None :
73+ sample_id = sample_id [self .restricted_sample_id_idx ]
7074 tensors_in = self ._half_batch (sample_id , loading_options , True )
7175 tensors_out = self ._half_batch (sample_id , loading_options , False )
7276 return tensors_in , tensors_out
You can’t perform that action at this time.
0 commit comments