101 changes: 101 additions & 0 deletions recognition/2D_OASIS_s43971125/README.md
# Table of Contents
- [Improved 2D UNet](#improved-2d-unet)
- [Problem](#problem)
- [Model](#model)
- [Loading Data](#loading-data)
- [Training](#training)
- [Loss](#loss)
- [Optimiser](#optimiser)
- [Testing](#testing)
- [Result](#result)
- [References](#references)
- [Dependencies](#dependencies)

# Improved 2D UNet
A UNet is a convolutional neural network that is designed for image segmentation. This means it
classifies every pixel of an image into one of several different categories. This is particularly useful
for medical images, where you might want to classify which parts of an image correspond to certain types of tissue.

## Problem
The problem being solved is the segmentation of 2D MRI images of brains into 4 different types of brain tissue.
The goal is for all 4 labels to reach a minimum Dice similarity coefficient of 0.9. (The Dice similarity coefficient
measures overlap between a predicted segmentation and the ground truth: twice the area of their intersection divided
by the sum of their areas, so 1.0 is a perfect match.) Once the model has finished training, it can be applied to any
2D MRI image of a brain to accurately identify which parts of the image correspond to each type of tissue. This could
be useful for identifying potential abnormalities when scanning for illness.

## Model
[modules.py](modules.py)

An Improved 2D UNet is similar in basic structure to a UNet. It follows the same U-shaped architecture: in the
encoder, the image is gradually reduced in spatial size while the number of channels is increased (essentially
zooming in on the image, in a way). In the decoder, the feature maps are upscaled back to the original size of the
image while retaining the learned features, until eventually you end up with a segmented image of the same size
as the original. UNets also use skip connections, which take the outputs from earlier levels of the encoder
and hand them directly to the decoder at the same level. This helps to reconstruct boundaries, because reducing
spatial dimensions leads to information loss.
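
To make the skip connection concrete, here is a minimal sketch of a single encoder/decoder level in PyTorch. This is illustrative only, not the code in [modules.py](modules.py); `DoubleConv`, `OneLevelUNet`, and the channel sizes are assumptions.

```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by normalisation and activation."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.InstanceNorm2d(out_ch), nn.LeakyReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.InstanceNorm2d(out_ch), nn.LeakyReLU(),
        )
    def forward(self, x):
        return self.block(x)

class OneLevelUNet(nn.Module):
    """A single encoder/decoder level, showing the skip connection."""
    def __init__(self, in_ch=1, base_ch=16, n_classes=4):
        super().__init__()
        self.enc = DoubleConv(in_ch, base_ch)
        self.down = nn.MaxPool2d(2)                      # halve spatial size
        self.bottom = DoubleConv(base_ch, base_ch * 2)   # smaller map, more channels
        self.up = nn.ConvTranspose2d(base_ch * 2, base_ch, 2, stride=2)
        self.dec = DoubleConv(base_ch * 2, base_ch)      # consumes skip + upsampled maps
        self.head = nn.Conv2d(base_ch, n_classes, 1)

    def forward(self, x):
        skip = self.enc(x)                               # encoder output, saved for the skip
        x = self.bottom(self.down(skip))
        x = self.up(x)
        x = self.dec(torch.cat([skip, x], dim=1))        # skip connection: concatenate channels
        return self.head(x)                              # per-pixel class scores
```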

![UNet Architecture](documentation/UNet.png)

An Improved 2D UNet has some enhancements, however. Instead of the standard convolutional blocks, it uses residual
blocks with pre-activation (normalisation and activation before convolution). It has optional dropout layers to avoid
overfitting (unused here), instance normalisation instead of batch normalisation (better stability with small batch
sizes), and deep supervision. Deep supervision combines the outputs from all the decoder levels to give the final
output. This provides better gradient flow (because gradients arrive from every level of the decoder), faster
convergence (training stabilises faster, which speeds it up), and improved accuracy (because outputs are taken from
all levels of the decoder, intermediate features contribute directly to the final prediction).
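
A minimal sketch of a pre-activation residual block with instance normalisation follows. It illustrates the idea only; the real block lives in [modules.py](modules.py) and may differ.

```python
import torch.nn as nn

class PreActResBlock(nn.Module):
    """Pre-activation residual block: norm -> activation -> conv, twice, plus identity."""
    def __init__(self, channels, dropout_p=0.0):
        super().__init__()
        self.body = nn.Sequential(
            nn.InstanceNorm2d(channels), nn.LeakyReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.Dropout2d(dropout_p),                      # optional dropout; 0.0 disables it
            nn.InstanceNorm2d(channels), nn.LeakyReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        )

    def forward(self, x):
        return x + self.body(x)  # residual addition around the pre-activated body
```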

![Improved UNet Architecture](documentation/Improved_UNet.png)

# Loading Data
[dataset.py](dataset.py)

For this problem, loading the data is relatively simple. No transforms need to be applied, and the data is already
separated into images and labels, and into training/testing/validation groups. All that needs to be done is to convert
each image and label to a PyTorch tensor so that they can be processed by the Improved UNet.
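
The core conversion amounts to a couple of lines. This is a sketch of the essence; [dataset.py](dataset.py) does this per item, plus the label remapping, and `slice.png` is a hypothetical path.

```python
import torch
import numpy as np
from PIL import Image

img = Image.open("slice.png").convert("L")  # hypothetical greyscale slice
tensor = torch.tensor(np.array(img), dtype=torch.float32).unsqueeze(0) / 255.0
```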

# Training
Parameters used were as follows (these can be adjusted within train.py):
- Batch size: 4
- Number of epochs: 50
- Learning rate: 1e-4
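
A minimal sketch of a training loop under these settings, shown below; train.py is the authoritative version, and `ImprovedUNet` / `DiceLoss` are assumed class names from modules.py.

```python
import torch
from dataset import get_dataloaders
from modules import ImprovedUNet, DiceLoss  # assumed module/class names

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedUNet(in_channels=1, num_classes=4).to(device)  # assumed signature
criterion = DiceLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-4)
train_loader, val_loader, _ = get_dataloaders(batch_size=4)

for epoch in range(50):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimiser.zero_grad()
        loss = criterion(model(images), labels)  # multiclass Dice loss over 4 classes
        loss.backward()
        optimiser.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")
```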


## Loss
This model uses Dice similarity coefficient loss (DiceLoss). It computes the similarity between our segmented images
and the manually delineated labels, based on how much they overlap. Specifically, we use multiclass Dice loss, where
we compute the Dice score for each class separately and then take the average.
The DiceLoss function was generated by AI, then adjusted based on the information from
the lecture slides.
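
A minimal sketch of a multiclass Dice loss of this kind (an illustration of the idea, not the generated function used in this project):

```python
import torch
import torch.nn.functional as F

def multiclass_dice_loss(logits, targets, eps=1e-6):
    """logits: [N, C, H, W] raw scores; targets: [N, H, W] integer class labels."""
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)
    # One-hot encode targets to [N, C, H, W] to match the probability maps.
    onehot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # sum over batch and spatial dims, keeping the class dim
    intersection = (probs * onehot).sum(dims)
    union = probs.sum(dims) + onehot.sum(dims)
    dice_per_class = (2 * intersection + eps) / (union + eps)
    return 1 - dice_per_class.mean()  # average over classes; loss = 1 - Dice
```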

## Optimiser
This model uses the Adam (Adaptive Moment Estimation) optimiser, which is built into PyTorch. It works by
adjusting the learning rate for each parameter during training: it keeps track of how the gradients are changing
for each weight and scales how much each weight is updated. It is a very common choice for deep segmentation models
such as UNet.
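
In PyTorch this is a single constructor call, using the learning rate from the training parameters above:

```python
optimiser = torch.optim.Adam(model.parameters(), lr=1e-4)
```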

# Testing
The model was tested by measuring the Dice score for each class. Each Dice score is calculated independently to
avoid errors caused by class imbalance. See [Result](#result) for an example of the input and output, the probability
maps for each class, and the Dice scores per class.
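
A minimal sketch of how per-class Dice scores can be computed over the test set (illustrative, reusing the assumed names from the sketches above):

```python
import torch

@torch.no_grad()
def dice_per_class(model, loader, num_classes=4, device="cpu"):
    """Accumulate intersection/union per class over the whole test set."""
    inter = torch.zeros(num_classes)
    union = torch.zeros(num_classes)
    model.eval()
    for images, labels in loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        for c in range(num_classes):
            p, t = (preds == c), (labels == c)
            inter[c] += (p & t).sum()
            union[c] += p.sum() + t.sum()
    return (2 * inter / union.clamp(min=1)).tolist()  # Dice = 2|A∩B| / (|A|+|B|)
```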

# Result
Dice scores per class:
- Class 0: 0.9993
- Class 1: 0.9585
- Class 2: 0.9653
- Class 3: 0.9783
![Dice Scores](documentation/dice_dcores.png)

Example input-output
![Visualisation](documentation/prediction_example.png)

Example per-class visualisation
![Per-class visualisation](documentation/prediction_per_class.png)

# References
Lecture slides, for the UNet and Improved UNet architecture PNGs.

# Dependencies
- pytorch=2.6.0
- torchvision=0.21.0
- numpy=1.25.0
- pillow=9.5.0
- matplotlib=3.8.0
83 changes: 83 additions & 0 deletions recognition/2D_OASIS_s43971125/dataset.py
# Contains the data loader

import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import glob

'''
OASISDataset class.
Loads the 2D_OASIS dataset, converts the images and labels to pytorch tensors,
remaps the label values so that one-hot encoding (in the DiceLoss function) works correctly,
optionally transforms the image if 'transform' is set, then returns the images and labels.
'''
class OASISDataset(Dataset):
    '''Current setup assumes that the 2D_OASIS dataset is located in the specified root_dir,
    with the images and labels already split into train, test, and validation directories.
    Also assumes that all labels and images are .png files:
    /OASIS
        - keras_png_slices_train
        - keras_png_slices_test
        - keras_png_slices_validate
        - keras_png_slices_seg_train
        - keras_png_slices_seg_test
        - keras_png_slices_seg_validate
    If the dataset is in a different location, adjust root_dir (and hence image_dir
    and label_dir) accordingly.
    '''
    def __init__(self, root_dir="/home/groups/comp3710/OASIS", split="train", categorical=False, transform=None):
        # (the 'categorical' flag is currently unused)
        image_dir = os.path.join(root_dir, f"keras_png_slices_{split}")
        label_dir = os.path.join(root_dir, f"keras_png_slices_seg_{split}")

        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")))
        self.label_paths = sorted(glob.glob(os.path.join(label_dir, "*.png")))

        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load images and labels as greyscale
        image = Image.open(self.image_paths[idx]).convert("L")
        label = Image.open(self.label_paths[idx]).convert("L")

        # Convert to np arrays
        image_np = np.array(image)
        label_np = np.array(label, dtype=np.uint8)

        # Convert to tensors; images gain a channel dimension and are scaled to [0, 1]
        image = torch.tensor(image_np, dtype=torch.float32).unsqueeze(0) / 255.0
        label = torch.tensor(label_np, dtype=torch.long)

        # Remap labels from [0, 85, 170, 255] to [0, 1, 2, 3], so that one-hot encoding of
        # targets (in the DiceLoss function) works correctly. We do this by dividing every
        # label by 85 and rounding down.
        label = torch.div(label, 85, rounding_mode='floor')

        # Optional transformation step - only used if transform is set in __init__
        if self.transform:
            image = self.transform(image)

        return image, label

# Helper function that is called by train.py to get all 3 dataloaders
def get_dataloaders(root_dir="/home/groups/comp3710/OASIS", batch_size=4):
    # Define the 3 datasets
    train_dataset = OASISDataset(root_dir=root_dir, split="train")
    val_dataset = OASISDataset(root_dir=root_dir, split="validate")
    test_dataset = OASISDataset(root_dir=root_dir, split="test")

    # Construct the 3 dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Return the 3 dataloaders
    return train_loader, val_loader, test_loader
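
# Minimal usage sketch: smoke-test the loaders defined above, assuming the
# OASIS dataset exists at the default root_dir.
if __name__ == "__main__":
    train_loader, val_loader, test_loader = get_dataloaders(batch_size=4)
    images, labels = next(iter(train_loader))
    # Expect images of shape [4, 1, H, W] in [0, 1] and labels of shape [4, H, W] in {0, 1, 2, 3}
    print(images.shape, float(images.min()), float(images.max()))
    print(labels.shape, labels.unique())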
