diff --git a/recognition/2d_oasis-46412384/.gitignore b/recognition/2d_oasis-46412384/.gitignore
new file mode 100644
index 000000000..f58af6896
--- /dev/null
+++ b/recognition/2d_oasis-46412384/.gitignore
@@ -0,0 +1,3 @@
+data
+__pycache__
+_tdmodel*
diff --git a/recognition/2d_oasis-46412384/README.md b/recognition/2d_oasis-46412384/README.md
new file mode 100644
index 000000000..70142c893
--- /dev/null
+++ b/recognition/2d_oasis-46412384/README.md
@@ -0,0 +1,133 @@
+# 2D OASIS brain data segmentation with Improved UNet
+
+## Project Synopsis
+
+This project segments raw 2D brain MRI slices into distinct anatomical classes.
+Using an improved UNet model, the success criterion was a minimum Dice similarity coefficient of 0.9 across the final predictions.\
+Each 2D data slice is segmented into 4 classes, distinguished by equally distributed greyscale shades in the segmentation mask:
+
+1. Background (black)
+2. Cerebro-spinal fluid (dark grey)
+3. Grey matter (light grey)
+4. White matter (white)
+
+For example:
+
+![example batch prediction visualisation](figs/ex-batch_preds.png)
+
+
+## Model Analysis
+
+### UNet Model
+
+The Improved UNet algorithm proves to be a reliable model architecture for 2D image segmentation tasks, particularly those of medical application.
+In summary, its design revolves around applying numerous fine filters over the image to discern local patterns.
+These filters are applied repeatedly over several layers, between which the data is compressed, yielding a smaller result each time.
+This act of compression - called pooling - allows the algorithm to capture both fine detail and abstract patterns across the data.
+Finally, the data is upscaled back through each layer, and the predictions from each layer are combined into a final mask that matches the resolution of the input image.
+This compaction to a bottleneck, followed by a mirrored reconstruction, forms the U-shaped architecture that gives the UNet its name.
+
+### Dataset
+
+This model is trained on the pre-split OASIS dataset. Training comprised 9664 slices, validation 1120 slices, and final evaluation 544 test slices.
+This approximates an 85 / 10 / 5 distribution, on the train-heavier side of typical split ratios.
+
+The dataset consists of 256x256 greyscale (single-channel) images. Loading each dataset consisted of:
+
+- Sequentially pairing each slice with its respective mask
+- Applying z-score normalisation to each slice, minimising instability in training values (see the sketch below)
+- Decoding the masks from greyscale values to class-label identifiers (integers 0 - 3)
+
+This processing was computed on the fly, as data was requested by the model.
+This increases training time, as data is re-processed on each epoch, in favour of reducing memory demands.
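+
+As a minimal sketch of that normalisation step (illustrative only; the actual implementation lives in `dataset.py`), each slice is standardised independently:
+
+```
+import numpy as np
+
+def zscore(img: np.ndarray, eps: float = 1e-6) -> np.ndarray:
+    # Centre each slice on zero mean and scale to unit variance;
+    # eps guards against division by zero on blank slices
+    return (img - img.mean()) / (img.std() + eps)
+```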
+
+### Model Operation
+
+Training consisted of 26 epochs with batches of 4 images, and a validation pass after each epoch.
+Model evaluation utilised a combination of Dice loss and cross-entropy loss, calculated at each epoch; testing, however, was governed by cross-entropy loss.
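+
+For reference, the Dice loss reported in the training output below follows the standard soft multi-class formulation implemented in `MCDiceLoss` (`modules.py`), where $p$ are the softmax probabilities, $g$ the one-hot ground truth, $\epsilon$ a small smoothing term, and the mean is taken over the $C$ classes (and the batch). The corresponding Dice *score* is 1 minus this loss.
+
+$$ L_{Dice} = 1 - \frac{1}{C} \sum_{c=1}^{C} \frac{2 \sum_{i} p_{c,i} \, g_{c,i} + \epsilon}{\sum_{i} p_{c,i} + \sum_{i} g_{c,i} + \epsilon} $$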
+
+Train the model using:
+```
+python3 train.py
+```
+The loss for each epoch will be displayed, and the model state and training losses are saved to disk (model: `_tdmodel/model.dat`; losses: `_tdmodel/loss.dat`).\
+Snapshots of the predictions throughout training are also saved locally (`_tdmodel/disp/`)
+
+### Model Evaluation
+
+After training, evaluate the model using:
+```
+python3 predict.py
+```
+This will load the saved model and test it against the as-yet unseen test dataset.
+The final result is displayed in the output, and a visualisation of the predictions in the first batch is saved to disk.\
+For example:
+
+![example evaluation batch visualisation](figs/ex-eval_preds.png)
+
+This module will also plot the loss results from training and validation (`_tdmodel/plot/`).\
+For example:
+
+![example cross-entropy loss plot](figs/ex-celoss.png)
+
+
+## Model Results
+
+The final result of this model achieved a loss score of **0.29355** and a Dice score of **0.59782** (the Dice score being 1 minus the Dice loss of 0.40218 reported below).
+
+> Training loss scores continued to decrease with each epoch, reaching a minimum of **0.20137**, while validation scores converged towards approximately **0.29**.
+
+> Similarly, training Dice losses declined across epochs, reaching **0.33267** (a Dice score of **0.66733**), while validation Dice losses converged around **0.4** (a Dice score of about **0.6**).
+
+While the training losses would likely have continued converging given more epochs, this would only result in model overfitting, as the validation scores had already reached their convergence points.\
+The resulting output is:
+
+```
+Starting model training on cuda
+ [ Epoch 1 / 26 ] Train: L=0.98436 D=0.71914 Validate: L=0.86103 D=0.67615
+ [ Epoch 2 / 26 ] Train: L=0.62303 D=0.62290 Validate: L=0.53432 D=0.58800
+ [ Epoch 3 / 26 ] Train: L=0.45187 D=0.55552 Validate: L=0.40059 D=0.52977
+ [ Epoch 4 / 26 ] Train: L=0.37325 D=0.51251 Validate: L=0.36206 D=0.51142
+ [ Epoch 5 / 26 ] Train: L=0.33565 D=0.48701 Validate: L=0.35009 D=0.48460
+ [ Epoch 6 / 26 ] Train: L=0.31543 D=0.47188 Validate: L=0.35551 D=0.48481
+ [ Epoch 7 / 26 ] Train: L=0.29792 D=0.45573 Validate: L=0.31159 D=0.46544
+ [ Epoch 8 / 26 ] Train: L=0.28726 D=0.44615 Validate: L=0.30409 D=0.45165
+ [ Epoch 9 / 26 ] Train: L=0.27587 D=0.43496 Validate: L=0.29401 D=0.43792
+ [ Epoch 10 / 26 ] Train: L=0.26944 D=0.42707 Validate: L=0.28949 D=0.43535
+ [ Epoch 11 / 26 ] Train: L=0.26269 D=0.41961 Validate: L=0.28887 D=0.43687
+ [ Epoch 12 / 26 ] Train: L=0.25914 D=0.41463 Validate: L=0.27945 D=0.42804
+ [ Epoch 13 / 26 ] Train: L=0.25293 D=0.40722 Validate: L=0.29024 D=0.42730
+ [ Epoch 14 / 26 ] Train: L=0.24776 D=0.40157 Validate: L=0.28522 D=0.43205
+ [ Epoch 15 / 26 ] Train: L=0.24375 D=0.39455 Validate: L=0.28957 D=0.43101
+ [ Epoch 16 / 26 ] Train: L=0.24061 D=0.38900 Validate: L=0.28890 D=0.42225
+ [ Epoch 17 / 26 ] Train: L=0.23789 D=0.38711 Validate: L=0.30815 D=0.42445
+ [ Epoch 18 / 26 ] Train: L=0.23599 D=0.38272 Validate: L=0.29255 D=0.41271
+ [ Epoch 19 / 26 ] Train: L=0.22989 D=0.37462 Validate: L=0.28706 D=0.41631
+ [ Epoch 20 / 26 ] Train: L=0.22484 D=0.36830 Validate: L=0.28170 D=0.40779
+ [ Epoch 21 / 26 ] Train: L=0.21971 D=0.35926 Validate: L=0.28826 D=0.41555
+ [ Epoch 22 / 26 ] Train: L=0.21698 D=0.35655 Validate: L=0.30154 D=0.41290
+ [ Epoch 23 / 26 ] Train: L=0.21315 D=0.34834 Validate: L=0.29658 D=0.40792
+ [ Epoch 24 / 26 ] Train: L=0.20954 D=0.34459 Validate: L=0.29707 D=0.40281
+ [ Epoch 25 / 26 ] Train: L=0.20679 D=0.33885 Validate: L=0.30125 D=0.41382
+ [ Epoch 26 / 26 ] Train: L=0.20137 D=0.33267 Validate: L=0.29598 D=0.40516
+Training complete!
+```
+```
+Starting model evaluation
+ [ Eval ] Validate: L=0.29355 D=0.40218
+Evaluation complete!
+```
+
+Yielding the loss plots:
+
+![final model result dice loss scores](figs/fin-diceloss.png)
+
+![final model result CE loss scores](figs/fin-celoss.png)
+
+
+## Dependencies
+
+- `python : 3.9.23`
+- `pytorch : 2.5.1`
+- `numpy : 2.0.1`
+- `pillow : 11.3.0`
+- `matplotlib : 3.9.2`
\ No newline at end of file
diff --git a/recognition/2d_oasis-46412384/dataset.py b/recognition/2d_oasis-46412384/dataset.py
new file mode 100644
index 000000000..c138ad581
--- /dev/null
+++ b/recognition/2d_oasis-46412384/dataset.py
@@ -0,0 +1,112 @@
+import torch
+import numpy as np
+from torch.utils.data import DataLoader, Dataset
+from os import listdir
+from PIL import Image
+
+
+class DataError(Exception):
+    """
+    Raised when a problem regarding the dataset is encountered
+    """
+    pass
+
+
+class OasisDataset(Dataset):
+    """
+    Dataset for OASIS greyscale images and segmentations
+    """
+
+    def __init__(self, data_dir, mask_dir, subset_size=-1):
+        # Find all data files, sorted so each slice lines up with its mask.
+        # Only trim the listings when a non-negative subset size is requested
+        # (a plain [:subset_size] would silently drop the last file for -1)
+        img_files = sorted(listdir(data_dir))
+        msk_files = sorted(listdir(mask_dir))
+        if (subset_size >= 0):
+            img_files = img_files[:subset_size]
+            msk_files = msk_files[:subset_size]
+
+        self.data_paths = [data_dir + "/" + file for file in img_files]
+        self.mask_paths = [mask_dir + "/" + file for file in msk_files]
+
+        # Establish dataset size consistency between images, masks and subset size
+        if (subset_size < 0 or len(self.data_paths) < subset_size):
+            self.subset_size = len(self.data_paths)
+        else:
+            self.subset_size = subset_size
+
+        if (len(self.mask_paths) != self.subset_size):
+            raise DataError("Non-matching subset sizes between data (length %d) and mask (length %d)." % (len(self.data_paths), len(self.mask_paths)))
+
+    def __len__(self):
+        return self.subset_size
+
+    def __getitem__(self, idx):
+        if (idx < 0 or idx >= self.subset_size):
+            raise IndexError("Index %d out of range for subset size %d" % (idx, self.subset_size))
+
+        # Read and load image data as a (height, width) float array
+        img = Image.open(self.data_paths[idx])
+        img_data = np.asarray(img, dtype=np.float32)
+
+        # Apply z-score normalisation to image data
+        # This stabilises all images to a common scale, maintaining consistency in model training
+        img_data = (img_data - img_data.mean()) / (img_data.std() + 1e-6)
+        # Add a leading channel dimension to suit the model's (C, H, W) input
+        img_data = torch.from_numpy(img_data).unsqueeze(0).float()
+
+        # Load and decode mask
+        msk = Image.open(self.mask_paths[idx])
+        msk_data = torch.from_numpy(OasisDataset.decode_mask(np.asarray(msk, dtype=np.int64)))
+
+        return img_data, msk_data
+
+    @staticmethod
+    def decode_mask(enc_msk):
+        """
+        Decode greyscale mask values from numpy array to segment classification integers
+        """
+        # Collect the sorted unique greyscale values in the mask
+        # (this maps values to labels in ascending order, assuming every class appears in the slice)
+        vals = np.unique(enc_msk)
+
+        # Replace each value with a corresponding identifying index
+        dec_msk = np.zeros_like(enc_msk)
+        for idx, val in enumerate(vals):
+            dec_msk[enc_msk == val] = idx
+
+        return dec_msk
+
+    @staticmethod
+    def encode_mask(dec_msk):
+        """
+        Encode segment classification integers from numpy array into greyscale values of equal distribution
+        """
+        # Collect the sorted unique class-label indices
+        vals = np.unique(dec_msk)
+
+        # Generate an equidistant greyscale value for each label
+        # (the guard avoids division by zero on single-class masks)
+        enc_msk = np.array(dec_msk) * 255 / max(len(vals) - 1, 1)
+
+        return enc_msk
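+
+
+# A quick round-trip sanity check for the mask codecs (illustrative only, not
+# part of the pipeline) - the values below assume the 4-class OASIS palette:
+#
+#   >>> enc = np.array([[0, 85], [170, 255]])
+#   >>> OasisDataset.decode_mask(enc)
+#   array([[0, 1],
+#          [2, 3]])
+#   >>> OasisDataset.encode_mask(np.array([[0, 1], [2, 3]]))
+#   array([[  0.,  85.],
+#          [170., 255.]])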
+
+
+def get_oasis_dataloaders(data_dir, batch_size, subset_size=-1):
+    IMG_TRAIN_PATH = "keras_png_slices_train"
+    MSK_TRAIN_PATH = "keras_png_slices_seg_train"
+    IMG_VAL_PATH = "keras_png_slices_validate"
+    MSK_VAL_PATH = "keras_png_slices_seg_validate"
+    IMG_TEST_PATH = "keras_png_slices_test"
+    MSK_TEST_PATH = "keras_png_slices_seg_test"
+
+    # Gather data files, using half-sized validation and test subsets when trimming is requested
+    # (a bare truthiness test would mishandle the -1 "no trimming" sentinel)
+    sub_size = subset_size // 2 if subset_size > 0 else -1
+    ds_train = OasisDataset(data_dir + IMG_TRAIN_PATH, data_dir + MSK_TRAIN_PATH, subset_size=subset_size)
+    ds_validate = OasisDataset(data_dir + IMG_VAL_PATH, data_dir + MSK_VAL_PATH, subset_size=sub_size)
+    ds_test = OasisDataset(data_dir + IMG_TEST_PATH, data_dir + MSK_TEST_PATH, subset_size=sub_size)
+
+    # Generate data loaders, shuffling only the training data
+    dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
+    dl_validate = DataLoader(ds_validate, batch_size=batch_size, shuffle=False)
+    dl_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False)
+
+    return dl_train, dl_validate, dl_test
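+
+
+if (__name__ == "__main__"):
+    # Minimal smoke test (illustrative sketch; assumes the OASIS PNG folders
+    # exist under data/ with the names defined above)
+    dl_train, dl_validate, dl_test = get_oasis_dataloaders("data/", batch_size=4)
+    imgs, msks = next(iter(dl_train))
+    # Expect torch.Size([4, 1, 256, 256]) and torch.Size([4, 256, 256])
+    print(imgs.shape, msks.shape)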
\ No newline at end of file
diff --git a/recognition/2d_oasis-46412384/figs/ex-batch_preds.png b/recognition/2d_oasis-46412384/figs/ex-batch_preds.png
new file mode 100644
index 000000000..61cae7f4a
Binary files /dev/null and b/recognition/2d_oasis-46412384/figs/ex-batch_preds.png differ
diff --git a/recognition/2d_oasis-46412384/figs/ex-celoss.png b/recognition/2d_oasis-46412384/figs/ex-celoss.png
new file mode 100644
index 000000000..f2e922f67
Binary files /dev/null and b/recognition/2d_oasis-46412384/figs/ex-celoss.png differ
diff --git a/recognition/2d_oasis-46412384/figs/ex-eval_preds.png b/recognition/2d_oasis-46412384/figs/ex-eval_preds.png
new file mode 100644
index 000000000..fccd297e7
Binary files /dev/null and b/recognition/2d_oasis-46412384/figs/ex-eval_preds.png differ
diff --git a/recognition/2d_oasis-46412384/figs/fin-celoss.png b/recognition/2d_oasis-46412384/figs/fin-celoss.png
new file mode 100644
index 000000000..abb2dbf3d
Binary files /dev/null and b/recognition/2d_oasis-46412384/figs/fin-celoss.png differ
diff --git a/recognition/2d_oasis-46412384/figs/fin-diceloss.png b/recognition/2d_oasis-46412384/figs/fin-diceloss.png
new file mode 100644
index 000000000..e4415426a
Binary files /dev/null and b/recognition/2d_oasis-46412384/figs/fin-diceloss.png differ
diff --git a/recognition/2d_oasis-46412384/modules.py b/recognition/2d_oasis-46412384/modules.py
new file mode 100644
index 000000000..7a0a5aa28
--- /dev/null
+++ b/recognition/2d_oasis-46412384/modules.py
@@ -0,0 +1,131 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class InConv(nn.Module):
+    def __init__(self, in_chls, out_chls):
+        super().__init__()
+
+        # Use convolutions with a 3x3 kernel, padding width of 1 and no dilation.
+        # This suits the MRI data, as its patterns relate closely to near
+        # neighbours rather than to the entire image
+
+        # Use the ReLU activation function as a suitable compromise between computational
+        # efficiency and model complexity while approximating the sigmoid function
+        self.net = nn.Sequential(
+            nn.Conv2d(in_chls, out_chls, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_chls),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_chls, out_chls, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_chls),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class DownLayer(nn.Module):
+    def __init__(self, in_chls, out_chls):
+        super().__init__()
+        # Use 2x2 max pooling for data downscaling
+        self.pool = nn.MaxPool2d(kernel_size=2)
+        self.conv = InConv(in_chls, out_chls)
+
+    def forward(self, x):
+        return self.conv(self.pool(x))
+
+
+class UpLayer(nn.Module):
+    def __init__(self, in_chls, out_chls):
+        super().__init__()
+        self.up = nn.ConvTranspose2d(in_chls, in_chls // 2, kernel_size=2, stride=2)
+        self.conv = InConv(in_chls, out_chls)
+
+    def forward(self, x, skip):
+        x = self.up(x)
+
+        # Pad the upscaled output to match the skip connection's spatial size
+        dx = skip.size(3) - x.size(3)
+        dy = skip.size(2) - x.size(2)
+
+        if (dx or dy):
+            x = F.pad(x, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
+        # Concatenate the skip connection along the channel dimension
+        x = torch.cat([skip, x], dim=1)
+        return self.conv(x)
+
+
+class OutConv(nn.Module):
+    def __init__(self, in_chls, out_chls):
+        super().__init__()
+        # Final 1x1 convolution to establish logits
+        self.conv = nn.Conv2d(in_chls, out_chls, kernel_size=1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class UNet(nn.Module):
+
+    def __init__(self, in_chls=1, num_classes=4, base_chls=64):
+        super().__init__()
+
+        # Channel widths double at each of the 5 resolution levels
+        c = [base_chls * 2**i for i in range(5)]
+
+        # Network architecture
+
+        self.input = InConv(in_chls, base_chls)
+        self.down1 = DownLayer(c[0], c[1])
+        self.down2 = DownLayer(c[1], c[2])
+        self.down3 = DownLayer(c[2], c[3])
+        self.down4 = DownLayer(c[3], c[4])
+
+        self.up1 = UpLayer(c[4], c[3])
+        self.up2 = UpLayer(c[3], c[2])
+        self.up3 = UpLayer(c[2], c[1])
+        self.up4 = UpLayer(c[1], c[0])
+        self.output = OutConv(c[0], num_classes)
+
+    def forward(self, x):
+        x1 = self.input(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+
+        # Upsampling with skips from the downsampling train
+        return self.output(
+            self.up4(
+                self.up3(
+                    self.up2(
+                        self.up1(x5, x4),
+                        x3),
+                    x2),
+                x1)
+        )
+
+
+class MCDiceLoss(nn.Module):
+    def __init__(self, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, logits, target):
+        # Convert raw output to probabilities using the softmax function, ideal for predictions over several classes
+        probs = logits.softmax(dim=1)
+        # Batch size, num classes, height, width
+        B, C, H, W = probs.shape
+
+        # One-hot encode the target, followed by dimension adjustments to (B, C, H, W)
+        target_oh = F.one_hot(target, num_classes=C).permute(0, 3, 1, 2).float()
+
+        probs_flat = probs.reshape(B, C, -1)
+        target_flat = target_oh.reshape(B, C, -1)
+
+        # The element-wise product, summed per class, yields the (soft) intersection
+        intersect = (probs_flat * target_flat).sum(dim=-1)
+        total = probs_flat.sum(dim=-1) + target_flat.sum(dim=-1)
+
+        dice = (2 * intersect + self.eps) / (total + self.eps)
+
+        return 1 - dice.mean()
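+
+
+# Quick sanity check (illustrative sketch): a near-one-hot prediction should
+# yield a loss close to 0, e.g.
+#
+#   >>> t = torch.randint(0, 4, (2, 64, 64))
+#   >>> logits = F.one_hot(t, num_classes=4).permute(0, 3, 1, 2).float() * 100
+#   >>> round(MCDiceLoss()(logits, t).item(), 3)
+#   0.0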
+
diff --git a/recognition/2d_oasis-46412384/predict.py b/recognition/2d_oasis-46412384/predict.py
new file mode 100644
index 000000000..b316cafd1
--- /dev/null
+++ b/recognition/2d_oasis-46412384/predict.py
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+from train import Config, evaluate
+from dataset import get_oasis_dataloaders
+from modules import UNet, MCDiceLoss
+import matplotlib.pyplot as plt
+
+
+def plot_loss(train_loss, validate_loss, save_path, title="Loss over Training Epochs", legend=["Training loss", "Validation loss"]):
+    plt.xlabel("Epoch")
+    plt.ylabel("Loss")
+    plt.grid(True)
+
+    plt.title(title)
+
+    plt.plot(train_loss, "r")
+    plt.plot(validate_loss, "b")
+
+    plt.legend(legend)
+
+    plt.savefig(save_path)
+    plt.close()
+
+
+if (__name__ == "__main__"):
+    _, _, dl_test = get_oasis_dataloaders(Config.DATA_DIR, Config.BATCH_SIZE, Config.SUBSET_SIZE)
+    model = UNet(Config.IN_CHANNELS, Config.NUM_CLASSES).to(Config.DEVICE)
+
+    # Read model data from save file
+    model.load_state_dict(torch.load(Config.MODEL_SAVE_FILE, map_location=Config.DEVICE))
+    train_loss, train_dice, validate_loss, validate_dice = torch.load(Config.LOSS_SAVE_FILE)
+
+    print("\nStarting model evaluation")
+
+    crit = nn.CrossEntropyLoss()
+    dice_fn = MCDiceLoss()
+
+    print("\t[ Eval ]", end="", flush=True)
+
+    test_loss, test_dice = evaluate(
+        model=model,
+        dl=dl_test,
+        crit=crit,
+        dice_fn=dice_fn,
+        dev=Config.DEVICE,
+        display=True
+    )
+
+    plot_loss(
+        train_loss,
+        validate_loss,
+        save_path=Config.PLOT_SAVE_PATH + "celoss.png",
+        title="Cross-Entropy Loss over Training Epochs",
+        legend=["Training CE Loss", "Validation CE Loss"]
+    )
+
+    plot_loss(
+        train_dice,
+        validate_dice,
+        save_path=Config.PLOT_SAVE_PATH + "diceloss.png",
+        title="Dice Loss over Training Epochs",
+        legend=["Training Dice Loss", "Validation Dice Loss"]
+    )
+
+    print("\tValidate: L=%6.5f D=%6.5f" % (test_loss, test_dice))
+
+    print("Evaluation complete!")
\ No newline at end of file
diff --git a/recognition/2d_oasis-46412384/train.py b/recognition/2d_oasis-46412384/train.py
new file mode 100644
index 000000000..c68829139
--- /dev/null
+++ b/recognition/2d_oasis-46412384/train.py
@@ -0,0 +1,214 @@
+import torch
+import torch.nn as nn
+from torch.optim import Adam
+from dataset import get_oasis_dataloaders, OasisDataset
+from modules import UNet, MCDiceLoss
+import matplotlib.pyplot as plt
+
+
+class Config:
+    # Storage and output paths
+    DATA_DIR = "data/"
+    MODEL_SAVE_FILE = "_tdmodel/model.dat"
+    LOSS_SAVE_FILE = "_tdmodel/loss.dat"
+    DISPLAY_SAVE_PATH = "_tdmodel/disp/"
+    PLOT_SAVE_PATH = "_tdmodel/plot/"
+
+    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Data subset size. Datasets will be trimmed to this size, if necessary. Set to -1 to not trim datasets
+    SUBSET_SIZE = -1
+
+    # Learning rate for the Adam optimiser, scaling each of its adaptive, momentum-based
+    # update steps. Adam's momentum smooths the gradient descent, improving the convergence
+    # rate while avoiding the oscillation and divergence caused by excessive step sizes.
+    # A greater learning rate makes model progression more aggressive, but reduces its stability
+    LEARN_RATE = 1e-3
+
+    # Data parameters. These are unlikely to change for the OASIS dataset
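+    # For a quick smoke run (illustrative values, not the reported configuration),
+    # the knobs below can be dialled down, e.g. SUBSET_SIZE = 64 and NUM_EPOCHS = 2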
+    IN_CHANNELS = 1
+    NUM_CLASSES = 4
+
+    # Model parameters
+    BATCH_SIZE = 4
+    NUM_EPOCHS = 26
+
+    # Output parameters
+    DISPLAY_EVERY = 2
+
+    # Internal variables
+    display_count = 0
+
+
+def train_epoch(model, dl, optim, crit, dice_fn, num_classes, dev="cpu", display=False):
+    """
+    Train a model for a single run of the dataset
+    """
+
+    loss = 0
+    dice = 0
+
+    model.train()
+
+    for bat_idx, (imgs, msks) in enumerate(dl):
+        imgs = imgs.to(dev)
+        msks = msks.to(dev)
+
+        # Reset the optimiser
+        optim.zero_grad()
+
+        # Compute the model predictions
+        logits = model(imgs)
+
+        if (display and bat_idx == 0):
+            display_batch(imgs, msks, logits)
+
+        # Evaluate the prediction loss and backpropagate
+        bat_loss = crit(logits, msks)
+        bat_loss.backward()
+
+        bat_dice = dice_fn(logits, msks)
+
+        # Advance the optimiser
+        optim.step()
+
+        loss += bat_loss.item()
+        dice += bat_dice.item()
+
+    return loss / len(dl), dice / len(dl)
+
+
+def evaluate(model, dl, crit, dice_fn, dev="cpu", display=False):
+    """
+    Test a model for a single run of the dataset
+    """
+
+    loss = 0
+    dice = 0
+
+    model.eval()
+
+    # No gradients are needed during evaluation
+    with torch.no_grad():
+        for bat_idx, (imgs, msks) in enumerate(dl):
+            imgs = imgs.to(dev)
+            msks = msks.to(dev)
+
+            # Compute the model predictions
+            logits = model(imgs)
+
+            if (display and bat_idx == 0):
+                display_batch(imgs, msks, logits, save_file="eval.png")
+
+            # Evaluate the prediction loss
+            bat_loss = crit(logits, msks)
+            bat_dice = dice_fn(logits, msks)
+
+            loss += bat_loss.item()
+            dice += bat_dice.item()
+
+    return loss / len(dl), dice / len(dl)
+
+
+def display_batch(imgs, msks, logits, save_file=None):
+
+    batch_size = len(imgs)
+
+    # squeeze=False keeps the axes grid 2-D even for a batch of one
+    fig, axes = plt.subplots(batch_size, 3, squeeze=False)
+    axes[0][0].set_title("Input Image")
+    axes[0][1].set_title("Ground Truth")
+    axes[0][2].set_title("Prediction")
+
+    imgs_data = imgs.cpu().numpy()
+    msks_data = msks.cpu().numpy()
+    logits_data = logits.argmax(dim=1).cpu().numpy()
+
+    with torch.no_grad():
+        for bat_idx in range(batch_size):
+            img_data = imgs_data[bat_idx, 0]
+            msk_data = msks_data[bat_idx]
+            logit_data = logits_data[bat_idx]
+
+            # Rescale the normalised image into [0, 1] for display
+            img_data = (img_data - img_data.min()) / (img_data.max() - img_data.min() + 1e-8)
+
+            # Re-encode class labels as greyscale shades
+            msk_data = OasisDataset.encode_mask(msk_data)
+            logit_data = OasisDataset.encode_mask(logit_data)
+
+            axes[bat_idx][0].imshow(img_data, cmap="gray")
+            axes[bat_idx][0].axis("off")
+            axes[bat_idx][1].imshow(msk_data, cmap="gray")
+            axes[bat_idx][1].axis("off")
+            axes[bat_idx][2].imshow(logit_data, cmap="gray")
+            axes[bat_idx][2].axis("off")
+
+    plt.tight_layout()
+
+    if (save_file is None):
+        save_file = "%d.png" % Config.display_count
+
+    plt.savefig(Config.DISPLAY_SAVE_PATH + save_file)
+    plt.close()
+
+    Config.display_count += 1
+
+
+def train(model, dl_train, dl_validate, epochs=20, display_every=10):
+    """
+    Run the full training loop for a model
+    """
+
+    model.to(Config.DEVICE)
+    print("\nStarting model training on %s" % Config.DEVICE)
+
+    # Establish model training accessories
+    crit = nn.CrossEntropyLoss()
+    dice_fn = MCDiceLoss()
+    optim = Adam(model.parameters(), lr=Config.LEARN_RATE)
+
+    train_losses = []
+    train_dices = []
+
+    validate_losses = []
+    validate_dices = []
+
+    # Run each training epoch
+    for ep_idx in range(1, epochs + 1):
+        print("\t[ Epoch %d / %d ]" % (ep_idx, epochs), end="", flush=True)
+
+        train_loss, train_dice = train_epoch(
+            model=model,
+            dl=dl_train,
+            optim=optim,
+            crit=crit,
+            dice_fn=dice_fn,
+            num_classes=Config.NUM_CLASSES,
+            dev=Config.DEVICE,
+            display=not bool(ep_idx % display_every)
+        )
+
+        validate_loss, validate_dice = evaluate(
+            model=model,
+            dl=dl_validate,
+            crit=crit,
+            dice_fn=dice_fn,
+            dev=Config.DEVICE,
+            display=False
+        )
+
+        print("\tTrain: L=%6.5f D=%6.5f\tValidate: L=%6.5f D=%6.5f" % (train_loss, train_dice, validate_loss, validate_dice))
+
+        train_losses.append(train_loss)
+        train_dices.append(train_dice)
+        validate_losses.append(validate_loss)
+        validate_dices.append(validate_dice)
+
+    print("Training complete!")
+
+    return train_losses, train_dices, validate_losses, validate_dices
+
+
+if (__name__ == "__main__"):
+    dl_train, dl_validate, _ = get_oasis_dataloaders(Config.DATA_DIR, Config.BATCH_SIZE, Config.SUBSET_SIZE)
+    model = UNet(Config.IN_CHANNELS, Config.NUM_CLASSES)
+    train_loss, train_dice, validate_loss, validate_dice = train(model, dl_train, dl_validate, epochs=Config.NUM_EPOCHS, display_every=Config.DISPLAY_EVERY)
+
+    # Persist the trained model and loss history for predict.py
+    torch.save(model.state_dict(), Config.MODEL_SAVE_FILE)
+    torch.save([train_loss, train_dice, validate_loss, validate_dice], Config.LOSS_SAVE_FILE)
\ No newline at end of file