diff --git a/recognition/2D_OASIS_s43971125/README.md b/recognition/2D_OASIS_s43971125/README.md
new file mode 100644
index 000000000..66c1311e2
--- /dev/null
+++ b/recognition/2D_OASIS_s43971125/README.md
@@ -0,0 +1,101 @@
+# Table of Contents
+- [Improved 2D UNet](#improved-2d-unet)
+  - [Problem](#problem)
+  - [Model](#model)
+- [Loading Data](#loading-data)
+- [Training](#training)
+  - [Loss](#loss)
+  - [Optimiser](#optimiser)
+- [Testing](#testing)
+- [Result](#result)
+- [References](#references)
+- [Dependencies](#dependencies)
+
+# Improved 2D UNet
+A UNet is a convolutional neural network designed for image segmentation: it classifies every pixel of an image
+into one of several categories. This is particularly useful for medical images, where you might want to identify
+which parts of an image correspond to certain types of tissue.
+
+## Problem
+The problem being solved is the segmentation of 2D MRI images of brains into 4 different types of brain tissue.
+The goal is for all 4 labels to reach a minimum Dice similarity coefficient of 0.9, so that once training is done
+the model can be applied to any 2D MRI brain slice and accurately identify which parts of the image correspond to
+each tissue type. This could be useful for identifying potential abnormalities when scanning for illness.
+
+## Model
+[modules.py](modules.py)
+
+An Improved 2D UNet is similar in basic structure to a UNet. It follows the same U-shaped architecture: in the
+encoder, the image is gradually reduced in spatial size while the number of channels is increased (essentially
+zooming in on the image, in a way). In the decoder, the feature maps are upscaled back to the original image size
+while retaining the learned features, until you end up with a segmented image of the same size as the original.
+UNets also use skip connections, which take the outputs from each encoder level and hand them directly to the
+decoder at the same level - this helps to reconstruct boundaries, because reducing spatial dimensions causes
+information loss.
+
+![UNet Architecture](documentation/UNet.png)
+
+An Improved 2D UNet adds some enhancements. Instead of the normal convolutional blocks, it uses residual blocks
+with pre-activation (activation before convolution). It has optional dropout layers to avoid overfitting (unused here),
+instance normalisation instead of batch normalisation (better stability with small batch sizes), and deep supervision.
+Deep supervision combines the outputs from all the decoder levels to produce the final output. This gives better
+gradient flow (gradients are provided directly to every level of the decoder), faster convergence (training stabilises
+sooner, which speeds it up), and improved accuracy (because outputs are taken from all levels of the decoder,
+intermediate features contribute more directly, leading to better final predictions).
+
+![Improved UNet Architecture](documentation/Improved_UNet.png)
+
+# Loading Data
+[dataset.py](dataset.py)
+
+For this problem, loading the data is relatively simple. No transforms need to be applied, and the data is already
+separated into images and labels, and into training/testing/validation groups. All that needs to be done is to convert
+each image and label to a PyTorch tensor so that they can be processed by the Improved UNet.
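+
+A minimal usage sketch is shown below (it assumes the dataset is available at the default `root_dir` used in
+`dataset.py`, and that `dataset.py` and `modules.py` are importable from the working directory):
+
+```python
+from dataset import get_dataloaders
+from modules import build_improved_unet
+
+# Build the dataloaders and the Improved UNet
+train_loader, val_loader, test_loader = get_dataloaders(batch_size=4)
+model = build_improved_unet(in_channels=1, out_channels=4)
+
+# One batch: images are [B, 1, H, W], labels are [B, H, W] with class indices 0-3
+images, labels = next(iter(train_loader))
+logits = model(images)  # [B, 4, H, W] class logits
+```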
+
+# Training
+Parameters used were as follows (these can be adjusted within train.py):
+- Batch size: 4
+- Number of epochs: 50
+- Learning rate: 1e-4
+
+## Loss
+This model uses Dice similarity coefficient loss (DiceLoss). It measures the similarity between our segmented images
+and the manually delineated images, based on how much they overlap. Specifically, we use multiclass Dice loss, where
+the Dice score is computed for each class separately and then averaged.
+The DiceLoss function was generated by AI, then adjusted based on the information from the lecture slides.
+
+## Optimiser
+This model uses the Adam (Adaptive Moment Estimation) optimiser, which is built into PyTorch. It works by adapting
+the learning rate for each parameter during training: it keeps track of how the gradients are changing for each
+weight and adjusts how much each weight is updated. It is a very common choice for deep segmentation models such
+as UNet.
+
+# Testing
+The model was tested by measuring the Dice score for each class. Each Dice score is calculated independently to
+avoid errors caused by class imbalance. See [Result](#result) for an example of the input and output, the probability
+maps for each class, and the Dice scores per class.
+
+# Result
+Dice scores per class:
+- Class 0: 0.9993
+- Class 1: 0.9585
+- Class 2: 0.9653
+- Class 3: 0.9783
+
+![Dice Scores](documentation/dice_scores.png)
+
+Example input-output
+![Visualisation](documentation/prediction_example.png)
+
+Example per-class visualisation
+![Per-class visualisation](documentation/prediction_per_class.png)
+
+# References
+Lecture slides, for the UNet and Improved UNet architecture diagrams
+
+# Dependencies
+- pytorch=2.6.0
+- torchvision=0.21.0
+- numpy=1.25.0
+- pillow=9.5.0
+- matplotlib=3.8.0
diff --git a/recognition/2D_OASIS_s43971125/dataset.py b/recognition/2D_OASIS_s43971125/dataset.py
new file mode 100644
index 000000000..f2f3fd277
--- /dev/null
+++ b/recognition/2D_OASIS_s43971125/dataset.py
@@ -0,0 +1,83 @@
+# Contains the data loader
+
+
+import os
+import torch
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+from PIL import Image
+import glob
+
+'''
+OASISDataset class.
+Loads the 2D_OASIS_dataset, converts the images and labels to pytorch tensors, +Remaps the labels to one-hot encoding, so that they are compatible with the DiceLoss function +Optionally transforms the image if 'transform' is set, then returns the images and labels +''' +class OASISDataset(Dataset): + '''Current setup assumes that 2D_OASIS dataset is located in the specified root_dir + with the images and labels already split into train, test, and validation directory + Also assumes that all labels and images are .png files + /Oasis- + -keras_png_slices_train + -keras_png_slices_test + -keras_png_slices_val + -keras_png_slices_seg_train + -keras_png_slices_seg_test + -keras-png_slices_seg_val + If there is a different location, need to adjust root_dir, image_dir and label_dir accordingly + ''' + def __init__(self, root_dir="/home/groups/comp3710/OASIS", split="train", categorical=False, transform=None): + + image_dir = os.path.join(root_dir, f"keras_png_slices_{split}") + label_dir = os.path.join(root_dir, f"keras_png_slices_seg_{split}") + + self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png"))) + self.label_paths = sorted(glob.glob(os.path.join(label_dir, "*.png"))) + + self.transform = transform + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + # load images and labels + image = Image.open(self.image_paths[idx]).convert("L") + label = Image.open(self.label_paths[idx]).convert("L") + + # Convert to np arrays + image_np = np.array(image) + label_np = np.array(label, dtype=np.uint8) + + + + # Convert to tensors + image = torch.tensor(image_np, dtype=torch.float32).unsqueeze(0) / 255.0 + label = torch.tensor(label_np, dtype=torch.long) + + #remap labels from [0, 85, 170, 255] to [0, 1, 2, 3], so that one-hot encoding of + #targets (in the diceloss function) works correctly. We do this by just dividing every label + #by 85, and rounding down. 
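+        # e.g. a raw label value of 170 maps to 170 // 85 = 2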
+ label = torch.div(label, 85, rounding_mode='floor') + + #optional transformation step - only used if transform is set in __init__ + if self.transform: + image = self.transform(image) + + return image, label + +# Helper function that is called by train.py to get all 3 dataloaders +def get_dataloaders(root_dir="/home/groups/comp3710/OASIS", batch_size=4): + # define the 3 datasets + train_dataset = OASISDataset(root_dir=root_dir, split="train") + val_dataset = OASISDataset(root_dir=root_dir, split="validate") + test_dataset = OASISDataset(root_dir=root_dir, split="test") + + #Construct the 3 dataloaders + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) + + #Return the 3 dataloaders + return train_loader, val_loader, test_loader + diff --git a/recognition/2D_OASIS_s43971125/documentation/Improved_UNet.png b/recognition/2D_OASIS_s43971125/documentation/Improved_UNet.png new file mode 100644 index 000000000..b5ee13fbc Binary files /dev/null and b/recognition/2D_OASIS_s43971125/documentation/Improved_UNet.png differ diff --git a/recognition/2D_OASIS_s43971125/documentation/UNet.png b/recognition/2D_OASIS_s43971125/documentation/UNet.png new file mode 100644 index 000000000..ffbc1bfa9 Binary files /dev/null and b/recognition/2D_OASIS_s43971125/documentation/UNet.png differ diff --git a/recognition/2D_OASIS_s43971125/documentation/dice_scores.png b/recognition/2D_OASIS_s43971125/documentation/dice_scores.png new file mode 100644 index 000000000..bace36e08 Binary files /dev/null and b/recognition/2D_OASIS_s43971125/documentation/dice_scores.png differ diff --git a/recognition/2D_OASIS_s43971125/documentation/prediction_example.png b/recognition/2D_OASIS_s43971125/documentation/prediction_example.png new file mode 100644 index 000000000..6731a86d7 Binary files /dev/null and b/recognition/2D_OASIS_s43971125/documentation/prediction_example.png differ diff --git a/recognition/2D_OASIS_s43971125/documentation/prediction_per_class.png b/recognition/2D_OASIS_s43971125/documentation/prediction_per_class.png new file mode 100644 index 000000000..cab5fca41 Binary files /dev/null and b/recognition/2D_OASIS_s43971125/documentation/prediction_per_class.png differ diff --git a/recognition/2D_OASIS_s43971125/modules.py b/recognition/2D_OASIS_s43971125/modules.py new file mode 100644 index 000000000..bd3f4fdfd --- /dev/null +++ b/recognition/2D_OASIS_s43971125/modules.py @@ -0,0 +1,272 @@ +# Contains the source code of the model. 
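+# Defines a baseline UNet (ConvBlock, DownBlock, UpBlock, UNet) and an Improved UNet
+# (PreActResBlock, DownResBlock, UpResBlock, ImprovedUNet) that adds pre-activation residual
+# blocks, instance normalisation, optional spatial dropout and deep supervision.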
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + #Very basic ConvBlock, based on lecture example + super(ConvBlock, self).__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) + return x + +class DownBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(DownBlock, self).__init__() + self.conv = ConvBlock(in_channels, out_channels) + self.pool = nn.MaxPool2d(2) + + def forward(self, x): + x = self.conv(x) + p = self.pool(x) + return x, p # return features before pooling for skip connection + +class UpBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(UpBlock, self).__init__() + self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) + self.conv = ConvBlock(in_channels, out_channels) # in_channels includes skip connection + + def forward(self, x, skip): + x = self.up(x) + x = torch.cat([x, skip], dim=1) # concatenate along channel dimension + x = self.conv(x) + return x + +''' +Basic UNet - this code was adapted from the Lectures +''' + +class UNet(nn.Module): + def __init__(self, in_channels=1, out_channels=4, base_filters=64): + super(UNet, self).__init__() + # Encoder + self.down1 = DownBlock(in_channels, base_filters) + self.down2 = DownBlock(base_filters, base_filters*2) + self.down3 = DownBlock(base_filters*2, base_filters*4) + self.down4 = DownBlock(base_filters*4, base_filters*8) + + # Bottleneck + self.bottleneck = ConvBlock(base_filters*8, base_filters*16) + + # Decoder + self.up4 = UpBlock(base_filters*16, base_filters*8) + self.up3 = UpBlock(base_filters*8, base_filters*4) + self.up2 = UpBlock(base_filters*4, base_filters*2) + self.up1 = UpBlock(base_filters*2, base_filters) + + # Final conv + self.final_conv = nn.Conv2d(base_filters, out_channels, kernel_size=1) + pass + + def forward(self, x): + # Encoder + s1, p1 = self.down1(x) + s2, p2 = self.down2(p1) + s3, p3 = self.down3(p2) + s4, p4 = self.down4(p3) + + # Bottleneck + b = self.bottleneck(p4) + + # Decoder + d4 = self.up4(b, s4) + d3 = self.up3(d4, s3) + d2 = self.up2(d3, s2) + d1 = self.up1(d2, s1) + + out = self.final_conv(d1) + return out + + +#Below are all the requirements for the improved UNet + + +class PreActResBlock(nn.Module): + #dropout_prob is optional tuning metric, used to regularise the block (prevents overfitting) + def __init__(self, in_ch, out_ch, dropout_prob=0.0): + super().__init__() + self.in_ch = in_ch + self.out_ch = out_ch + self.dropout_prob = dropout_prob + + #First Normalisation + convolution + self.norm1 = nn.InstanceNorm2d(in_ch, affine=False) + self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False) + + #Second normalisation + convolution + self.norm2 = nn.InstanceNorm2d(out_ch, affine=False) + self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False) + + self.dropout = nn.Dropout2d(p=dropout_prob) if dropout_prob > 0 else nn.Identity() + + #check that input can be added to output + self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else nn.Identity() + + def forward(self, x): + out = self.conv1(F.relu(self.norm1(x))) + out = self.dropout(out) + out = 
self.conv2(F.relu(self.norm2(out))) + # residual add + return out + self.skip(x) + +class DownResBlock(nn.Module): + def __init__(self, in_ch, out_ch, dropout_prob=0.0): + super().__init__() + self.res = PreActResBlock(in_ch, out_ch, dropout_prob=dropout_prob) + self.pool = nn.MaxPool2d(kernel_size=2) + + def forward(self, x): + features = self.res(x) + pooled = self.pool(features) + return features, pooled + +class UpResBlock(nn.Module): + def __init__(self, in_ch, skip_ch, out_ch, dropout_prob=0.0): + """ + in_ch = channels of decoder input (from previous layer) + skip_ch = channels from encoder skip connection + out_ch = desired output channels after block + First upsample (in_ch -> out_ch), then concatenate with skip (skip_ch), + then use a PreActResBlock with in_ch = out_ch + skip_ch, out_ch = out_ch + """ + super().__init__() + self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2) + self.conv = PreActResBlock(out_ch + skip_ch, out_ch, dropout_prob=dropout_prob) + + def forward(self, x, skip): + x = self.up(x) + # If shapes differ by one pixel due to odd sizes, center-crop or pad (common safe approach) + if x.shape[-2:] != skip.shape[-2:]: + # simple center crop/pad to match skip spatial dims + target_h, target_w = skip.shape[-2], skip.shape[-1] + x = F.interpolate(x, size=(target_h, target_w), mode='bilinear', align_corners=False) + x = torch.cat([x, skip], dim=1) + x = self.conv(x) + return x + +""" +Improved UNet: +- PreAct residual blocks +- InstanceNorm +- Dropout2d +- Deep supervision +Parameters: +in_channels: input channels (e.g. 1) +out_channels: number of segmentation classes (4 for 2D_OASIS) +base_filters: number of filters at first level (commonly 32 or 64) +dropout_prob: probability for spatial dropout inside blocks +deep_supervision: whether to include deep supervision +""" +class ImprovedUNet(nn.Module): + def __init__(self, in_channels=1, out_channels=4, base_filters=64, + dropout_prob=0.1, deep_supervision=True): + super().__init__() + f = base_filters + self.deep_supervision = deep_supervision + + # Encoder + self.enc1 = DownResBlock(in_channels, f, dropout_prob=dropout_prob) + self.enc2 = DownResBlock(f, f*2, dropout_prob=dropout_prob) + self.enc3 = DownResBlock(f*2, f*4, dropout_prob=dropout_prob) + self.enc4 = DownResBlock(f*4, f*8, dropout_prob=dropout_prob) + + # Bottleneck + self.bottleneck = PreActResBlock(f*8, f*16, dropout_prob=dropout_prob) + + # Decoder (note channel bookkeeping) + self.up4 = UpResBlock(f*16, skip_ch=f*8, out_ch=f*8, dropout_prob=dropout_prob) + self.up3 = UpResBlock(f*8, skip_ch=f*4, out_ch=f*4, dropout_prob=dropout_prob) + self.up2 = UpResBlock(f*4, skip_ch=f*2, out_ch=f*2, dropout_prob=dropout_prob) + self.up1 = UpResBlock(f*2, skip_ch=f, out_ch=f, dropout_prob=dropout_prob) + + # Final 1x1 conv to logits + self.final_conv = nn.Conv2d(f, out_channels, kernel_size=1) + + # Auxiliary heads for deep supervision - map intermediate decoder features to logits + if self.deep_supervision: + self.aux4 = nn.Conv2d(f*8, out_channels, kernel_size=1) # from d4 + self.aux3 = nn.Conv2d(f*4, out_channels, kernel_size=1) # from d3 + self.aux2 = nn.Conv2d(f*2, out_channels, kernel_size=1) # from d2 + # We do not need aux for d1 since final_conv handles it + + def forward(self, x): + # Encoder Step + # each call of self.enc(1-4) is an instance of DownResBlock + #s1-4 are saved so they can be used for skip connections + #p1-4 are the progressively downsampled features of the input + s1, p1 = self.enc1(x) # s1: f + s2, p2 = self.enc2(p1) # 
s2: f*2 + s3, p3 = self.enc3(p2) # s3: f*4 + s4, p4 = self.enc4(p3) # s4: f*8 + + # Bottleneck + # This is the deepest layer of the network (i.e we encode down to here, do feature extraction + # at the lowest level (most abstracted features), then decode back up to higher level. + b = self.bottleneck(p4) # f*16 + + # Decoder + #Each call of up(1-4) is an instance of upResBlock + #it takes a feature map and a skip connection, upsamples the feature map, concatenates + #it with the skip connection, then merges the features. + # each call utilises the previous feature map that has been upsampled + d4 = self.up4(b, s4) # f*8 + d3 = self.up3(d4, s3) # f*4 + d2 = self.up2(d3, s2) # f*2 + d1 = self.up1(d2, s1) # f + + # Final convolution + # We take our final feature map from the decode, and map it to the number of output classes + # we can then feed this result into DiceLoss + out_final = self.final_conv(d1) # [B, out_channels, H, W] + + if not self.deep_supervision: + return out_final + + # Deep supervision + # create extra predictions at lower decoder levels (i.e not d1) + # These are essentially "early" predictions, at lower resolutions + aux4 = self.aux4(d4) + aux3 = self.aux3(d3) + aux2 = self.aux2(d2) + + # Upsample Aux + # because each decoder layer has a different spatial size, have to upsample all of them so they match + # out_final's size + target_size = out_final.shape[-2:] + aux4_up = F.interpolate(aux4, size=target_size, mode='bilinear', align_corners=False) + aux3_up = F.interpolate(aux3, size=target_size, mode='bilinear', align_corners=False) + aux2_up = F.interpolate(aux2, size=target_size, mode='bilinear', align_corners=False) + + # Sum elementwise to combine them. + # By doing this, gives stronger supervision to early layers + # also means we can swap out standard UNet and Improved UNet in train.py with no changes + # since they both return a single output this way + combined = out_final + aux2_up + aux3_up + aux4_up + return combined + + + +def build_unet(in_channels=1, out_channels=4): + return UNet(in_channels, out_channels) + +def build_improved_unet(in_channels=1, out_channels=4, base_filters=32, dropout_prob=0.1, deep_supervision=True): + return ImprovedUNet(in_channels=in_channels, out_channels=out_channels, + base_filters=base_filters, dropout_prob=dropout_prob, + deep_supervision=deep_supervision) + +#Forward pass test +if __name__ == "__main__": + x = torch.randn(1, 1, 256, 256) # batch_size=1, channels=1, H=W=256 + model = build_unet() + y = model(x) + print("Input shape:", x.shape) + print("Output shape:", y.shape) # should be [1, out_channels, 256, 256] diff --git a/recognition/2D_OASIS_s43971125/predict.py b/recognition/2D_OASIS_s43971125/predict.py new file mode 100644 index 000000000..b2bedbe92 --- /dev/null +++ b/recognition/2D_OASIS_s43971125/predict.py @@ -0,0 +1,133 @@ +# Contains example usage of the model + +import torch +from modules import UNet +from modules import ImprovedUNet +from dataset import OASISDataset +import matplotlib.pyplot as plt +from train import dice_coefficient + +def load_model(model_path="best_improved.pth", in_channels=1, out_channels=4, device=None): + """ + Loads a trained UNet model from the specified checkpoint path. 
+ """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model = ImprovedUNet(in_channels=in_channels, out_channels=out_channels).to(device) + model.load_state_dict(torch.load(model_path, map_location=device)) + model.eval() + return model, device + + +def predict_and_visualise(model, dataset_split="test", sample_index=0, save_path="prediction_example.png", num_classes=4, device=None): + """ + Runs inference on one sample from the test dataset and visualises input, ground truth, and output (prediction). Saves the + generated image to save_path. + """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load one sample from dataset + dataset = OASISDataset(split=dataset_split) + image, label = dataset[sample_index] + + # Add batch dimension and move to device + image = image.unsqueeze(0).to(device) + label = label.unsqueeze(0).to(device) + + # Forward pass + with torch.no_grad(): + output = model(image) + probs = torch.softmax(output, dim=1) + predicted = torch.argmax(output, dim=1) + + # Compute Dice per class + dice_scores = dice_coefficient(output, label, num_classes=num_classes) + print(f"Dice per Class (sample {sample_index}):", dice_scores) + + # Visualise results + plt.figure(figsize=(12, 4)) + + plt.subplot(1, 3, 1) + plt.title("Input") + plt.imshow(image[0, 0].cpu(), cmap='gray') + + plt.subplot(1, 3, 2) + plt.title("Ground Truth") + plt.imshow(label.squeeze().cpu(), cmap='jet') + + plt.subplot(1, 3, 3) + plt.title("Prediction") + plt.imshow(predicted[0].cpu(), cmap='jet') + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + #Note: return dice scores so that we can plot them using a different function + return dice_scores + +def predict_and_visualise_per_class(model, dataset_split="test", sample_index=0, save_path="prediction_per_class.png", num_classes=4, device=None): + """ + Runs inference on one sample from the dataset and visualises per-class probability maps. + """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load one sample from dataset + dataset = OASISDataset(split=dataset_split) + image, label = dataset[sample_index] + + # Add batch dimension and move to device + image = image.unsqueeze(0).to(device) + label = label.unsqueeze(0).to(device) + + # Forward pass + with torch.no_grad(): + output = model(image) + probs = torch.softmax(output, dim=1) # [B, C, H, W] + predicted = torch.argmax(output, dim=1) + + # Visualise results + n_cols = 3 + num_classes # input + GT + overall prediction + per-class + plt.figure(figsize=(4 * n_cols, 4)) + + # Each Class Probability Map + for c in range(num_classes): + plt.subplot(1, n_cols, 4 + c) + plt.title(f"Class {c} Prob") + plt.imshow(probs[0, c].cpu(), cmap='inferno') + plt.axis('off') + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + + +def plot_dice_scores(dice_scores, save_path="dice_scores.png"): + """ + Plots a bar chart of Dice scores for each class. 
+ """ + dice_scores_np = dice_scores.cpu().numpy() + classes = [f"Class {i}" for i in range(len(dice_scores_np))] + plt.figure(figsize=(6, 4)) + plt.bar(classes, dice_scores_np, color="skyblue", edgecolor="black") + plt.ylim(0, 1.05) + plt.ylabel("Dice Score") + plt.title("Dice Score per Class") + plt.grid(axis='y', linestyle='--', alpha=0.6) + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + print(f"Saved Dice score plot to {save_path}") + +if __name__ == "__main__": + model, device = load_model() + #Note: plot_dice_scores requires dice_scores as input + #for simplicity,we just return them when we are plotting the example input-output + dice_scores = predict_and_visualise(model, sample_index=5, save_path="prediction_example.png", device=device) + predict_and_visualise_per_class(model, sample_index=5, save_path="prediction_per_class.png", device=device) + plot_dice_scores(dice_scores, save_path="dice_scores.png") + + + diff --git a/recognition/2D_OASIS_s43971125/train.py b/recognition/2D_OASIS_s43971125/train.py new file mode 100644 index 000000000..4a0d19f16 --- /dev/null +++ b/recognition/2D_OASIS_s43971125/train.py @@ -0,0 +1,162 @@ +# Contains the source code for training and validation + +import torch +import torch.nn as nn +import torch.optim as optim +import os +import sys + +from modules import UNet +from modules import ImprovedUNet +from dataset import get_dataloaders + +#Create the DiceLoss functionality, as explained in the lectures +class DiceLoss(nn.Module): + #Subclass behaves like other Pytorch Loss functions + #smooth is a small constant - used in the calculation step to avoid potential division by zero + def __init__(self, smooth=1e-6): + super(DiceLoss, self).__init__() + self.smooth = smooth + + def forward(self, preds, targets): + #preds = predictions, in form [BatchSize, Classes, Height, Width] = [B, C, H, W] + #targets are in form [B, H, W] + + #torch.softmax turns the UNet output scores into probabilities, which sum to 1 + preds = torch.softmax(preds, dim=1) + + #Convert targets to one-hot encoding. This ensures that when we flatten it into + # a 1D tensor, it's the same size as the predictions tensor, so we can calculate + #the intersection. + targets_onehot = torch.nn.functional.one_hot(targets, num_classes=preds.shape[1]) + # This line swaps the order around, so it matches preds order (which is the just + # the raw output of UNet, which gives [BatchSize, NumClasses, ImageHeight, ImageWidth] + targets_onehot = targets_onehot.permute(0, 3, 1, 2).float() + + #Calculate the Dice score for each class + intersection = (preds * targets_onehot).sum(dim=(2, 3)) + union = preds.sum(dim=(2, 3)) + targets_onehot.sum(dim=(2,3)) + dice_score = (2.0 * intersection + self.smooth) / (union + self.smooth) + + #Subtract the mean dice score across all classes from 1 to find the average DiceLoss + loss = 1 - dice_score.mean() + return loss + +#Same as above, but this is used during the evaluation step, while the DiceLoss is used for training +#Calculates dice_coefficient per class +def dice_coefficient(preds, targets, num_classes=4, smooth=1e-6): + preds = torch.softmax(preds, dim=1) + preds = torch.argmax(preds, dim=1) + + dice_scores = [] + # for each class, compute the dice score, then add them to a tensor. 
+ for i in range(num_classes): + pred_i = (preds == i).float() + target_i = (targets == i).float() + + intersection = (pred_i * target_i).sum() + union = pred_i.sum() + target_i.sum() + + dice = (2.0 * intersection + smooth) / (union + smooth) + dice_scores.append(dice.item()) + #return the tensor that contains the dice score for each class + return torch.tensor(dice_scores, device=preds.device) + +#Note: wrap the training loop so it's only run when train.py is called, not when it's imported +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + #device = torch.device("cpu") + + # Set the training paremeters + num_epochs = 50 + batch_size = 4 + learning_rate = 1e-4 + save_path = "best_improved.pth" + + # Load in the datasets + train_load, val_load, test_load = get_dataloaders(batch_size=batch_size) + + # Define the model, loss (DiceLoss) and optimiser (Adam) + + model = ImprovedUNet(in_channels=1, out_channels=4).to(device) + criterion = DiceLoss() + optimiser = optim.Adam(model.parameters(), lr=learning_rate) + + + # set the initial best dice score to 0. Gets updated each time the model + # finds a better score + best_val_dice = 0.0 + + # Training loop + # Note all print statements have no impact on training, are optional to track progress + for epoch in range(num_epochs): + model.train() + running_loss = 0.0 + + for images, labels in train_load: + images = images.to(device) + labels = labels.to(device) + + optimiser.zero_grad() + outputs = model(images) + loss = criterion(outputs, labels) + loss.backward() + optimiser.step() + + running_loss += loss.item() + + avg_loss = running_loss / len(train_load) + #print(f"Epoch {epoch+1} finished, Avg Loss: {avg_loss:.4f}") + + #Validation Step + model.eval() + dice_scores = [] + with torch.no_grad(): + for images, labels in val_load: + images = images.to(device) + labels = labels.to(device) + outputs = model(images) + dice_batch = dice_coefficient(outputs, labels, num_classes=4) + dice_scores.append(dice_batch) + + dice_scores = torch.stack(dice_scores) + mean_dice = dice_scores.mean(dim=0) + + # Compute overall mean Dice score + mean_val_dice = mean_dice.mean().item() + + # Print per-class results and average results + #print(f"Validation Dice per class: {mean_dice.cpu().numpy()}") + #print(f"Mean Validation Dice: {mean_val_dice:.4f}") + + # Mean probability for each class + probs = torch.softmax(outputs, dim=1) + mean_probs = probs.mean(dim=(0, 2, 3)) + #print(f"Mean predicted probabilities per class: {mean_probs.cpu().numpy()}") + + + # Save the best model so far, based on the average Dice score + if mean_val_dice > best_val_dice: + torch.save(model.state_dict(), save_path) + best_val_dice = mean_val_dice + best_per_class = mean_dice.cpu().numpy() + #print(f"✅ New best model saved (Avg Dice: {mean_val_dice:.4f}, per class: {best_per_class})") + + #blank line to separate each epoch printout in the log + #print() + + # Final test evaluation to find Dice Coefficient for each class + model.load_state_dict(torch.load(save_path)) + model.eval() + dice_totals = torch.zeros(4, device=device) + with torch.no_grad(): + for images, labels in test_load: + images = images.to(device) + labels = labels.to(device) + outputs = model(images) + dice_scores = dice_coefficient(outputs, labels, num_classes=4) + dice_totals += dice_scores.to(device) + dice_averages = dice_totals / len(test_load) + + #for i, dice in enumerate(dice_averages): + # print(f"Class {i} Dice: {dice:.4f}")