-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbatch_generator.py
More file actions
288 lines (244 loc) · 14.2 KB
/
batch_generator.py
File metadata and controls
288 lines (244 loc) · 14.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import logging
from time import time

import numpy as np
import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from batchgenerators.augmentations.crop_and_pad_augmentations import crop
from batchgenerators.augmentations.utils import pad_nd_image
from batchgenerators.dataloading import MultiThreadedAugmenter
from batchgenerators.dataloading.data_loader import DataLoader
from batchgenerators.transforms import Compose
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, GammaTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms.spatial_transforms import SpatialTransform_2, MirrorTransform
from batchgenerators.utilities.data_splitting import get_split_deterministic
from batchgenerators.utilities.file_and_folder_operations import *
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import KFold

from losses import dice_loss
from model import Modified3DUNet
from multi_threaded_augmentor import MultiThreadedAugmenter
def get_split_deterministic(all_keys, fold=0, num_splits=5, random_state=12345):
    """
    Splits a list of patient identifiers (or numbers) into num_splits folds and
    returns the (train_keys, test_keys) pair for the requested fold.

    Keys are sorted before splitting so the result is deterministic for a given
    random_state, regardless of the input ordering.

    :param all_keys: iterable of patient identifiers
    :param fold: 0-based index of the fold to return
    :param num_splits: total number of cross-validation folds
    :param random_state: seed for KFold's shuffle (makes the split reproducible)
    :return: (train_keys, test_keys) as numpy arrays
    :raises ValueError: if fold is not in [0, num_splits)
    """
    # BUG FIX: previously an out-of-range fold silently fell through the loop and
    # raised UnboundLocalError on return; fail fast with a clear message instead.
    if not 0 <= fold < num_splits:
        raise ValueError("fold must be in [0, %d), got %d" % (num_splits, fold))
    all_keys_sorted = np.sort(list(all_keys))
    splits = KFold(n_splits=num_splits, shuffle=True, random_state=random_state)
    for i, (train_idx, test_idx) in enumerate(splits.split(all_keys_sorted)):
        if i == fold:
            # np.sort already returned an ndarray; index it directly
            return all_keys_sorted[train_idx], all_keys_sorted[test_idx]
def get_list_of_patients(preprocessed_data_folder):
    """Return patient identifiers: every .npy file in the folder, path kept, extension stripped."""
    npy_files = subfiles(preprocessed_data_folder, suffix=".npy", join=True)
    # drop the ".npy" suffix explicitly rather than with a magic [:-4] slice
    return [f[:-len(".npy")] for f in npy_files]
class VerseDataLoader3D(DataLoader):
    """Patch-based 3D loader: each batch is a random crop of patch_size per patient."""

    def __init__(self, data, batch_size, patch_size, num_threads_in_multithreaded, seed_for_shuffle=1234,
                 return_incomplete=False, shuffle=True, infinite=True):
        """
        data must be a list of patients as returned by get_list_of_patients (and split by
        get_split_deterministic); patch_size is the spatial size the returned batch will have.
        """
        super().__init__(data, batch_size, num_threads_in_multithreaded, seed_for_shuffle, return_incomplete, shuffle,
                         infinite)
        self.patch_size = patch_size
        # a single image modality; the last channel of the stored array holds the segmentation
        self.num_modalities = 1
        self.indices = list(range(len(data)))

    @staticmethod
    def load_patient(patient):
        # NOTE: case 201 was skipped due to an mmap_mode error
        volume = np.load(patient + ".npy", mmap_mode="r")
        meta = load_pickle(patient + ".pkl")
        return volume, meta

    def generate_train_batch(self):
        # DataLoader decides which patients come next; see its documentation
        selected = [self._data[i] for i in self.get_indices()]

        # pre-allocate output arrays for the whole batch
        data = np.zeros((self.batch_size, self.num_modalities, *self.patch_size), dtype=np.float32)
        seg = np.zeros((self.batch_size, 1, *self.patch_size), dtype=np.float32)
        metadata = []
        patient_names = []

        for sample_idx, patient in enumerate(selected):
            volume, patient_metadata = self.load_patient(patient)
            # pad_nd_image only pads when the volume is smaller than patch_size
            padded = pad_nd_image(volume, self.patch_size)
            # crop expects (b, c, x, y, z) but padded is (c, x, y, z), so add a dummy
            # batch axis via [None]; last channel is the segmentation (@Todo, could be improved)
            cropped_data, cropped_seg = crop(padded[:-1][None], padded[-1:][None], self.patch_size,
                                             crop_type="random")
            data[sample_idx] = cropped_data[0]
            seg[sample_idx] = cropped_seg[0]
            metadata.append(patient_metadata)
            patient_names.append(patient)

        return {'data': data, 'seg': seg, 'metadata': metadata, 'names': patient_names}
def get_train_transform(patch_size):
    """
    Build the training augmentation pipeline for the given patch size.

    These are not necessarily the best transforms for this dataset -- they mainly
    showcase what is available.
    """
    transforms = [
        # SpatialTransform_2 runs first: it reduces the data to patch_size, which lowers the
        # cost of every later transform. The later transforms neither change the shape nor
        # move voxels spatially, so no border artifacts are introduced. SpatialTransform_2
        # uses a new parameterization of elastic_deform. Each spatial augmentation fires
        # with p=0.1 per sample, so 1 - (1 - 0.1) ** 3 = 27% of samples get at least one
        # spatial augmentation; the rest are just cropped.
        SpatialTransform_2(
            patch_size, [i // 2 for i in patch_size],
            do_elastic_deform=True, deformation_scale=(0, 0.25),
            do_rotation=True,
            angle_x=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_y=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_z=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            do_scale=True, scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, order_data=3,
            random_crop=True,
            p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1
        ),
        # mirror along all three axes
        MirrorTransform(axes=(0, 1, 2)),
        # brightness change for 15% of samples
        BrightnessMultiplicativeTransform((0.7, 1.5), per_channel=True, p_per_sample=0.15),
        # gamma: nonlinear intensity transform (https://en.wikipedia.org/wiki/Gamma_correction);
        # applied once normally and once on the inverted image
        GammaTransform(gamma_range=(0.5, 2), invert_image=False, per_channel=True, p_per_sample=0.15),
        GammaTransform(gamma_range=(0.5, 2), invert_image=True, per_channel=True, p_per_sample=0.15),
        # additive Gaussian noise
        GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15),
        # blurring: simulates blurry modalities so the model becomes more robust to them
        GaussianBlurTransform(blur_sigma=(0.5, 1.5), different_sigma_per_channel=True,
                              p_per_channel=0.5, p_per_sample=0.15),
    ]
    return Compose(transforms)
if __name__ == "__main__":
    num_threads_for_brats_example = 6
    brats_preprocessed_folder = '../data/images/preprocessed/'
    patients = get_list_of_patients(brats_preprocessed_folder)
    train, val = get_split_deterministic(patients, fold=0, num_splits=5, random_state=12345)
    patch_size = (128, 128, 128)
    batch_size = 1

    # Patch-based training has no well-defined "iteration over all training data" epoch;
    # with default arguments each batch contains randomly selected patients, so
    # num_threads_in_multithreaded can be anything here.
    dataloader = VerseDataLoader3D(train, batch_size, patch_size, 1)
    batch = next(dataloader)
    try:
        from batchviewer import view_batch
        # batchviewer shows up to 4d tensors; one sample is enough for a sanity check
        view_batch(batch['data'][0], batch['seg'][0])
    except ImportError:
        view_batch = None
        # you can get batchviewer here: https://github.com/FabianIsensee/BatchViewer

    # Collect all patient shapes so the training loader can return whole volumes of
    # max_shape; cropping/padding is then done by SpatialTransform, which avoids border
    # artifacts (everything outside the body is zero, matching the transform's fill).
    # NOTE(review): for very large volumes (e.g. 500^3 CT) this max_shape approach is
    # not viable -- estimate the needed extraction shape instead.
    shapes = [VerseDataLoader3D.load_patient(i)[0].shape[1:] for i in patients]
    max_shape = np.max(shapes, 0)
    max_shape = np.max((max_shape, patch_size), 0)

    dataloader_train = VerseDataLoader3D(train, batch_size, max_shape, 1)
    # validation is patch-based too (not a correct validation, but good enough to track
    # progress); no augmentation, so patch_size is used as the target shape directly
    dataloader_validation = VerseDataLoader3D(val, batch_size, patch_size, 1)

    tr_transforms = get_train_transform(patch_size)
    # multithreaded pipelines we actually train from; pin_memory stays off (pytorch-specific)
    tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms, num_processes=num_threads_for_brats_example,
                                    num_cached_per_queue=3,
                                    seeds=None, pin_memory=False)
    # fewer processes for validation because no transformations are applied
    val_gen = MultiThreadedAugmenter(dataloader_validation, None,
                                     num_processes=max(1, num_threads_for_brats_example // 2),
                                     num_cached_per_queue=1, seeds=None, pin_memory=False)
    # start the workers now so batches are produced while the main thread sets up the model
    tr_gen.restart()
    val_gen.restart()

    # tr_gen and val_gen generate infinite examples -- never "for batch in tr_gen:"!
    num_batches_per_epoch = 10
    num_validation_batches_per_epoch = 3
    num_epochs = 5
    time_per_epoch = []
    start = time()

    in_channels = 1
    n_classes = 26
    base_n_filter = 16
    model = Modified3DUNet(in_channels, n_classes, base_n_filter).cuda()
    # BUG FIX: the optimizer used to be re-created inside the training loop, which reset
    # Adam's moment estimates every step; build it once so its state persists.
    optimizer = optim.Adam(model.parameters())

    # BUG FIX: logger, losses and mean_accuracy were referenced but never defined
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    losses = []
    mean_accuracy = []

    def get_metrics(conf_matrix):
        """Per-class TPR, TNR, PPV, FPR, FNR and ACC from a multiclass confusion matrix."""
        conf = conf_matrix.astype(np.float64)
        tp = np.diag(conf)
        fp = conf.sum(axis=0) - tp
        fn = conf.sum(axis=1) - tp
        tn = conf.sum() - (tp + fp + fn)
        eps = 1e-12  # avoid division by zero for classes absent from the batch
        tpr = tp / (tp + fn + eps)
        tnr = tn / (tn + fp + eps)
        ppv = tp / (tp + fp + eps)
        fpr = fp / (fp + tn + eps)
        fnr = fn / (fn + tp + eps)
        acc = (tp + tn) / (tp + fp + fn + tn + eps)
        return tpr, tnr, ppv, fpr, fnr, acc

    for epoch in range(num_epochs):
        start_epoch = time()
        for b in range(num_batches_per_epoch):
            # BUG FIX: the loop previously reused the single un-augmented batch fetched
            # at the top of the script, so the model trained on the same data every step.
            batch = next(tr_gen)
            image = torch.from_numpy(batch['data']).float().cuda()
            label = torch.from_numpy(batch['seg']).cuda().long()
            output_1, output_2 = model(image)
            # BUG FIX: seg is (b, 1, x, y, z); drop the channel axis before one_hot so the
            # result is (b, x, y, z, C) and the 5-index permute below is valid, yielding
            # the (b, C, x, y, z) layout dice_loss is fed with.
            one_hot_encode_labels = F.one_hot(label.squeeze(1), n_classes)
            one_hot_encode_labels = one_hot_encode_labels.permute(0, 4, 1, 2, 3).contiguous()
            loss = dice_loss(output_2, one_hot_encode_labels)

            # batch['seg'] is a numpy array: flatten with reshape (torch's .view/.cpu don't exist on it)
            targets = batch['seg'].reshape(-1).astype(np.int64)
            preds = torch.argmax(output_2, 1).view(-1).cpu().detach().numpy()
            # BUG FIX: sklearn's signature is (y_true, y_pred); the arguments were swapped.
            # labels= keeps the matrix n_classes x n_classes even when a class is absent.
            conf_matrix = confusion_matrix(targets, preds, labels=np.arange(n_classes))
            TPR, TNR, PPV, FPR, FNR, ACC = get_metrics(conf_matrix)
            accuracy = accuracy_score(targets, preds)
            mean_accuracy.append(accuracy)
            logger.info('TPR == {} | \nTNR == {} | \nPRCSN == {} | \nFPR == {}\n | \nFNR == {} | \nACC == {}.'.format(TPR, TNR, PPV, FPR, FNR, ACC))
            logger.info('Accuracy = {}'.format(accuracy))

            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        for b in range(num_validation_batches_per_epoch):
            batch = next(val_gen)
            # run validation here
        end_epoch = time()
        time_per_epoch.append(end_epoch - start_epoch)
    end = time()
    total_time = end - start
    print("Running %d epochs took a total of %.2f seconds with time per epoch being %s" %
          (num_epochs, total_time, str(time_per_epoch)))
    # If CPU usage becomes an issue, lower the spatial-transform probabilities in
    # get_train_transform (down to 0.1, say); SpatialTransform is the most expensive one.
    # To inspect augmented examples, install batchviewer:
    if view_batch is not None:
        for _ in range(4):
            batch = next(tr_gen)
            view_batch(batch['data'][0], batch['seg'][0])
    else:
        print("Cannot visualize batches, install batchviewer first")