Skip to content

Commit a4274f2

Browse files
committed
ENH: Add options to consolidate with ecg2x notebooks
1 parent 347d8ed commit a4274f2

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

ml4ht/data/data_loader.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,21 +282,32 @@ def __getitem__(self, item: int) -> Batch:
282282
)
283283

284284

285-
def numpy_collate_fn(samples: List[Batch]) -> Batch:
285+
def numpy_collate_fn(
286+
samples: List[Batch],
287+
auto_float: bool = True
288+
) -> Batch:
286289
"""
287290
Merges a list of ml4ht batch formatted data.
288291
Can be used as 'collate_fn` in torch.utils.data.DataLoader
289292
so that the torch data loader is compatible with tensorflow models
290293
"""
291294
# construct correctly-shaped empty arrays for input and output of model
292295
in_batch_keys = list(samples[0][0])
296+
if auto_float:
297+
in_dtypes = {k: np.float32 for k in in_batch_keys}
298+
else:
299+
in_dtypes = {k: samples[0][0][k].dtype for k in in_batch_keys}
293300
in_batch = {
294-
k: np.empty((len(samples),) + samples[0][0][k].shape, dtype=np.float32)
301+
k: np.empty((len(samples),) + samples[0][0][k].shape, dtype=in_dtypes[k])
295302
for k in in_batch_keys
296303
}
297304
out_batch_keys = list(samples[0][1])
305+
if auto_float:
306+
out_dtypes = {k: np.float32 for k in out_batch_keys}
307+
else:
308+
out_dtypes = {k: samples[0][1][k].dtype for k in out_batch_keys}
298309
out_batch = {
299-
k: np.empty((len(samples),) + samples[0][1][k].shape, dtype=np.float32)
310+
k: np.empty((len(samples),) + samples[0][1][k].shape, dtype=out_dtypes[k])
300311
for k in out_batch_keys
301312
}
302313
# fill in the values of the input and output arrays

ml4ht/data/sample_getter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ def __init__(
2929
input_data_descriptions: List[DataDescription],
3030
output_data_descriptions: List[DataDescription],
3131
option_picker: OptionPicker = None,
32+
restricted_sample_id_idx = None,
3233
):
3334
self.input_data_descriptions = input_data_descriptions
3435
self.output_data_descriptions = output_data_descriptions
3536
self.option_picker = option_picker or self._default_option_picker
37+
self.restricted_sample_id_idx = restricted_sample_id_idx
3638

3739
@staticmethod
3840
def _default_option_picker(
@@ -67,6 +69,8 @@ def __call__(self, sample_id: SampleID) -> Batch:
6769
sample_id,
6870
self.input_data_descriptions + self.output_data_descriptions,
6971
)
72+
if self.restricted_sample_id_idx is not None:
73+
sample_id = sample_id[self.restricted_sample_id_idx]
7074
tensors_in = self._half_batch(sample_id, loading_options, True)
7175
tensors_out = self._half_batch(sample_id, loading_options, False)
7276
return tensors_in, tensors_out

0 commit comments

Comments
 (0)