26 commits
c934485  Initialised project files (s4684, Oct 16, 2025)
9f578ea  Adjusted project directory init structure (s4684, Oct 17, 2025)
1613f75  Adjusted project directory init structure (part 2) (s4684, Oct 17, 2025)
4caf70c  Implemented dataset loaders for the OASIS dataset (s4684, Nov 3, 2025)
95b5e29  Implemented UNet modules (s4684, Nov 3, 2025)
e0e3854  Fixed typo error in dataset loader module (s4684, Nov 3, 2025)
c9fe8a3  Fixed faulty loading in dataset loader module (s4684, Nov 3, 2025)
1367fba  Implement model training basics (s4684, Nov 3, 2025)
f91a814  Implemented simple status updates and model saving for training module (s4684, Nov 3, 2025)
6e67ad4  Adjusted minor formatting problems (s4684, Nov 3, 2025)
67a3674  Established simple model prediction validation (s4684, Nov 3, 2025)
a3a7465  Implemented prediction visualisation in the training and evaluation m… (s4684, Nov 4, 2025)
b40a754  Swapped the roles of test and validation datasets (s4684, Nov 4, 2025)
d635550  Implemented CE loss function alongside multi-class dice for training … (s4684, Nov 4, 2025)
604abd1  Removed testing config parameters from train module (s4684, Nov 4, 2025)
be15ac0  Establish project README file (s4684, Nov 4, 2025)
a3f1981  Implemented loss modelling figures in evaluation module (s4684, Nov 5, 2025)
d32408f  Removed axes from batch visualisation plots (s4684, Nov 5, 2025)
938620b  Removed testing configuration from training module (s4684, Nov 5, 2025)
4123d2a  Added figures in README (s4684, Nov 5, 2025)
4a10be0  Updated the comments for training module (s4684, Nov 5, 2025)
5c96216  Updated the comments for dataset loading module (s4684, Nov 5, 2025)
0f4de89  Updated the comments for model modules (s4684, Nov 5, 2025)
e643170  Updated README to include final model results. Changed README figure … (s4684, Nov 5, 2025)
4a70957  Fixed typo in README (s4684, Nov 5, 2025)
5e06ab8  Changed final dice losses to dice coefficients in README (s4684, Nov 5, 2025)
3 changes: 3 additions & 0 deletions recognition/2d_oasis-46412384/.gitignore
@@ -0,0 +1,3 @@
data
__pycache__
_tdmodel*
133 changes: 133 additions & 0 deletions recognition/2d_oasis-46412384/README.md
@@ -0,0 +1,133 @@
# 2D OASIS brain data segmentation with Improved UNet

## Project Synopsis

This project segments raw 2D brain MRI slices into distinct tissue classes.
Using an improved UNet model, the success criterion was a minimum Dice similarity coefficient of 0.9 across the final predictions.\
Each 2D slice is segmented into 4 classes, distinguished by equally distributed greyscale shades in the segmentation mask:

1. Background (black)
2. Cerebro-spinal fluid (dark grey)
3. Grey matter (light grey)
4. White matter (white)

For example:

![example batch prediction visualisation](figs/ex-batch_preds.png)


## Model Analysis

### UNet Model

The Improved UNet algorithm is a reliable model architecture for 2D image segmentation tasks, particularly in medical applications.
In summary, its design revolves around applying numerous small convolutional filters to discern patterns across the image.
These filters are applied repeatedly over several layers, between which the data is compressed, yielding a smaller result each time.
This act of compression, called pooling, allows the algorithm to capture both fine detail and abstract patterns in the data.
Finally, the data is upscaled back through each layer, and the predictions from each layer are combined into a final mask matching the resolution of the input image.
This compaction down to a final bottleneck, followed by the mirrored reconstruction, gives the architecture its U shape, from which the UNet takes its name.
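
A quick sketch of the shapes through the U, assuming the 256x256 inputs and the 64-channel base width used in `modules.py` below:

```
import torch
from modules import UNet

model = UNet(in_chls=1, num_classes=4, base_chls=64)
x = torch.randn(1, 1, 256, 256)  # one greyscale slice
# The encoder halves the resolution at each of 4 pooling steps:
#   256 -> 128 -> 64 -> 32 -> 16 (the bottleneck)
# The decoder then upsamples back up, concatenating the matching
# encoder output (skip connection) at each resolution on the way.
logits = model(x)
print(logits.shape)  # torch.Size([1, 4, 256, 256]) - one logit map per class
```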

### Dataset

This model is trained on the pre-split OASIS dataset. Training comprised 9664 slices, validation used 1120 slices, and final evaluation used 544 test slices (11328 slices in total).
This gives a distribution of approximately 85 / 10 / 5 respectively, on the train-heavy side of typical split ratios.

The dataset consists of 256x256 greyscale (single-channel) images. Loading each dataset consisted of:

- Sequentially pairing each slice with its respective mask
- Applying z-score normalisation to each slice, minimising instability in training values
- Decoding the masks from greyscale values to class-label identifiers (integers 0 - 3), as sketched below

This processing was computed on the fly, as data was requested by the model.
This increases training time, since the data is re-processed on each epoch, in exchange for lower memory demands.
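
For illustration, a minimal sketch of the decoding step, mirroring `decode_mask` in `dataset.py` below (the four greyscale values here are assumed examples of equally spaced shades):

```
import numpy as np

# Hypothetical 2x2 mask holding the four equally spaced greyscale shades
enc_msk = np.array([[0, 85], [170, 255]])

# Map each unique greyscale value to its class index (0 - 3)
dec_msk = np.zeros_like(enc_msk)
for idx, val in enumerate(np.unique(enc_msk)):
    dec_msk[enc_msk == val] = idx

print(dec_msk)  # [[0 1]
                #  [2 3]]
```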

### Model Operation

Training consisted of 26 epochs with batches of 4 images, and validation ran after each epoch.
Model evaluation used a combination of multi-class Dice loss and cross-entropy loss calculated at each epoch; testing, however, was governed by cross-entropy loss.
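
A minimal sketch of one such training step (`train.py` itself is not shown in this diff, so the optimiser choice and learning rate here are assumptions; `UNet` and `MCDiceLoss` come from `modules.py`):

```
import torch
from modules import UNet, MCDiceLoss

device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet(in_chls=1, num_classes=4).to(device)
optimiser = torch.optim.Adam(model.parameters(), lr=1e-4)  # assumed optimiser and rate
ce_loss = torch.nn.CrossEntropyLoss()
dice_loss = MCDiceLoss()

def train_step(imgs, msks):
    # imgs: (4, 1, 256, 256) float slices; msks: (4, 256, 256) integer class labels
    imgs, msks = imgs.to(device), msks.to(device)
    logits = model(imgs)
    loss = ce_loss(logits, msks)  # cross-entropy governs the optimisation
    with torch.no_grad():
        dice = dice_loss(logits, msks)  # Dice loss tracked alongside for reporting
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    return loss.item(), dice.item()
```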
Train the model using:
```
python3 train.py
```
The losses for each epoch are displayed, and the model state and training losses are saved to disk (model: `_tdmodel/model.dat`; losses: `_tdmodel/loss.dat`).\
Snapshots of the predictions throughout training are also saved locally (`_tdmodel/disp/`).

### Model Evaluation

After training, evaluate the model using:
```
python3 predict.py
```
This will load the saved model and test it on the as-yet-unseen test dataset.
The final result is displayed in the output, and a visualisation of the predictions in the first batch is saved to disk.\
For example:

![example evaluation batch visualisation](figs/ex-eval_preds.png)

This module also plots the loss results from training and validation (`_tdmodel/plot/`).\
For example:

![example cross-entropy loss plot](figs/ex-celoss.png)


## Model Results

The final result of this model achieved a loss score of **0.29355** and a Dice coefficient of **0.59782** (the Dice coefficient is 1 minus the Dice loss, i.e. 1 - 0.40218 = 0.59782 in the evaluation output below).

> Training loss scores continued to decrease with each epoch, reaching a minimum of **0.20137**, while validation scores converged towards approximately **0.29**.

> Similarly, training Dice coefficients improved across epochs, reaching **0.66733**, while validation coefficients converged around **0.6**.

While the training loss would likely converge further given more epochs, this would only overfit the model, as the validation scores had already reached their convergence points.\
The resulting output is:

```
Starting model training on cuda
[ Epoch 1 / 26 ] Train: L=0.98436 D=0.71914 Validate: L=0.86103 D=0.67615
[ Epoch 2 / 26 ] Train: L=0.62303 D=0.62290 Validate: L=0.53432 D=0.58800
[ Epoch 3 / 26 ] Train: L=0.45187 D=0.55552 Validate: L=0.40059 D=0.52977
[ Epoch 4 / 26 ] Train: L=0.37325 D=0.51251 Validate: L=0.36206 D=0.51142
[ Epoch 5 / 26 ] Train: L=0.33565 D=0.48701 Validate: L=0.35009 D=0.48460
[ Epoch 6 / 26 ] Train: L=0.31543 D=0.47188 Validate: L=0.35551 D=0.48481
[ Epoch 7 / 26 ] Train: L=0.29792 D=0.45573 Validate: L=0.31159 D=0.46544
[ Epoch 8 / 26 ] Train: L=0.28726 D=0.44615 Validate: L=0.30409 D=0.45165
[ Epoch 9 / 26 ] Train: L=0.27587 D=0.43496 Validate: L=0.29401 D=0.43792
[ Epoch 10 / 26 ] Train: L=0.26944 D=0.42707 Validate: L=0.28949 D=0.43535
[ Epoch 11 / 26 ] Train: L=0.26269 D=0.41961 Validate: L=0.28887 D=0.43687
[ Epoch 12 / 26 ] Train: L=0.25914 D=0.41463 Validate: L=0.27945 D=0.42804
[ Epoch 13 / 26 ] Train: L=0.25293 D=0.40722 Validate: L=0.29024 D=0.42730
[ Epoch 14 / 26 ] Train: L=0.24776 D=0.40157 Validate: L=0.28522 D=0.43205
[ Epoch 15 / 26 ] Train: L=0.24375 D=0.39455 Validate: L=0.28957 D=0.43101
[ Epoch 16 / 26 ] Train: L=0.24061 D=0.38900 Validate: L=0.28890 D=0.42225
[ Epoch 17 / 26 ] Train: L=0.23789 D=0.38711 Validate: L=0.30815 D=0.42445
[ Epoch 18 / 26 ] Train: L=0.23599 D=0.38272 Validate: L=0.29255 D=0.41271
[ Epoch 19 / 26 ] Train: L=0.22989 D=0.37462 Validate: L=0.28706 D=0.41631
[ Epoch 20 / 26 ] Train: L=0.22484 D=0.36830 Validate: L=0.28170 D=0.40779
[ Epoch 21 / 26 ] Train: L=0.21971 D=0.35926 Validate: L=0.28826 D=0.41555
[ Epoch 22 / 26 ] Train: L=0.21698 D=0.35655 Validate: L=0.30154 D=0.41290
[ Epoch 23 / 26 ] Train: L=0.21315 D=0.34834 Validate: L=0.29658 D=0.40792
[ Epoch 24 / 26 ] Train: L=0.20954 D=0.34459 Validate: L=0.29707 D=0.40281
[ Epoch 25 / 26 ] Train: L=0.20679 D=0.33885 Validate: L=0.30125 D=0.41382
[ Epoch 26 / 26 ] Train: L=0.20137 D=0.33267 Validate: L=0.29598 D=0.40516
Training complete!
```
```
Starting model evaluation
[ Eval ] Validate: L=0.29355 D=0.40218
Evaluation complete!
```

Yielding the loss plots:

![final model result dice loss scores](figs/fin-diceloss.png)

![final model result CE loss scores](figs/fin-celoss.png)


## Dependencies

- `python : 3.9.23`
- `pytorch : 2.5.1`
- `numpy : 2.0.1`
- `pillow : 11.3.0`
- `matplotlib : 3.9.2`
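
These can be installed, for example, via pip (an illustrative command; a conda environment works equally well):

```
pip install torch==2.5.1 numpy==2.0.1 pillow==11.3.0 matplotlib==3.9.2
```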
112 changes: 112 additions & 0 deletions recognition/2d_oasis-46412384/dataset.py
@@ -0,0 +1,112 @@
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from os import listdir
from PIL import Image


class DataError(Exception):
    """
    Raised when a problem regarding the dataset is encountered
    """
    pass


class OasisDataset(Dataset):
    """
    Dataset for OASIS greyscale images and segmentations
    """

    def __init__(self, data_dir, mask_dir, subset_size=-1):
        # Find all data files; sort so each slice pairs with its matching mask
        data_files = sorted(listdir(data_dir))
        mask_files = sorted(listdir(mask_dir))
        if subset_size >= 0:
            data_files = data_files[:subset_size]
            mask_files = mask_files[:subset_size]
        self.data_paths = [data_dir + "/" + file for file in data_files]
        self.mask_paths = [mask_dir + "/" + file for file in mask_files]

        # Establish dataset size consistency between images, masks and subset size
        if subset_size < 0 or len(self.data_paths) < subset_size:
            self.subset_size = len(self.data_paths)
        else:
            self.subset_size = subset_size

        if len(self.mask_paths) != self.subset_size:
            raise DataError("Non-matching subset sizes between data (length %d) and mask (length %d)."
                            % (len(self.data_paths), len(self.mask_paths)))

    def __len__(self):
        return self.subset_size

    def __getitem__(self, idx):
        if idx < 0 or idx >= self.subset_size:
            raise IndexError("Index %d out of range for subset size %d" % (idx, self.subset_size))

        # Read and load image data
        img = Image.open(self.data_paths[idx])
        img_data = np.asarray(img, dtype=np.float32)

        # Apply z-score normalisation to the image data
        # This stabilises all images to a common scale, maintaining consistency in model training
        img_data = (img_data - img_data.mean()) / (img_data.std() + 1e-6)
        # Add a channel dimension so batches have shape (B, 1, H, W)
        img_data = torch.from_numpy(img_data).unsqueeze(0).float()

        # Load and decode mask
        msk = Image.open(self.mask_paths[idx])
        msk_data = torch.from_numpy(OasisDataset.decode_mask(np.asarray(msk, dtype=np.int64)))

        return img_data, msk_data

    @staticmethod
    def decode_mask(enc_msk):
        """
        Decode greyscale mask values from a numpy array to segment classification integers
        """
        # Collect the sorted unique values in the mask
        vals = np.unique(enc_msk)

        # Replace each value with a corresponding identifying index
        dec_msk = np.zeros_like(enc_msk)
        for idx, val in enumerate(vals):
            dec_msk[enc_msk == val] = idx

        return dec_msk

    @staticmethod
    def encode_mask(dec_msk):
        """
        Encode segment classification integers from a numpy array into greyscale values of equal distribution
        """
        # Collect the sorted unique class label indices
        vals = np.unique(dec_msk)

        # Generate an equidistant greyscale value for each label
        enc_msk = np.array(dec_msk) * 255 / (len(vals) - 1)

        return enc_msk


def get_oasis_dataloaders(data_dir, batch_size, subset_size=-1):
    IMG_TRAIN_PATH = "keras_png_slices_train"
    MSK_TRAIN_PATH = "keras_png_slices_seg_train"
    IMG_VAL_PATH = "keras_png_slices_validate"
    MSK_VAL_PATH = "keras_png_slices_seg_validate"
    IMG_TEST_PATH = "keras_png_slices_test"
    MSK_TEST_PATH = "keras_png_slices_seg_test"

    # Gather data files; validation and test subsets are half the training subset size
    ds_train = OasisDataset(data_dir + IMG_TRAIN_PATH, data_dir + MSK_TRAIN_PATH, subset_size=subset_size)
    ds_validate = OasisDataset(data_dir + IMG_VAL_PATH, data_dir + MSK_VAL_PATH,
                               subset_size=subset_size // 2 if subset_size > 0 else -1)
    ds_test = OasisDataset(data_dir + IMG_TEST_PATH, data_dir + MSK_TEST_PATH,
                           subset_size=subset_size // 2 if subset_size > 0 else -1)

    # Generate data loaders; only the training set is shuffled
    dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
    dl_validate = DataLoader(ds_validate, batch_size=batch_size, shuffle=False)
    dl_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False)

    return dl_train, dl_validate, dl_test
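

# Usage sketch (illustrative addition, not part of the original file): assumes
# the OASIS folders sit under "data/", the directory ignored in .gitignore above.
# Note the trailing slash: directory paths are joined by plain concatenation.
if __name__ == "__main__":
    dl_train, dl_validate, dl_test = get_oasis_dataloaders("data/", batch_size=4)
    imgs, msks = next(iter(dl_train))
    print(imgs.shape, msks.shape)  # torch.Size([4, 1, 256, 256]) torch.Size([4, 256, 256])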
Binary file added recognition/2d_oasis-46412384/figs/ex-celoss.png
Binary file added recognition/2d_oasis-46412384/figs/fin-celoss.png
131 changes: 131 additions & 0 deletions recognition/2d_oasis-46412384/modules.py
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class InConv(nn.Module):
    def __init__(self, in_chls, out_chls):
        super().__init__()

        # Use convolutions with a 3x3 kernel, padding width of 1 and no dilation.
        # This suits the MRI data, as its patterns relate closely to near
        # neighbours rather than spanning the entire image.

        # Use the ReLU activation function as a suitable compromise between
        # computational efficiency and model complexity.
        self.net = nn.Sequential(
            nn.Conv2d(in_chls, out_chls, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_chls),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_chls, out_chls, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_chls),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.net(x)


class DownLayer(nn.Module):
    def __init__(self, in_chls, out_chls):
        super().__init__()
        # Use 2x2 max pooling for data downscaling
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv = InConv(in_chls, out_chls)

    def forward(self, x):
        return self.conv(self.pool(x))


class UpLayer(nn.Module):
    def __init__(self, in_chls, out_chls):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_chls, in_chls // 2, kernel_size=2, stride=2)
        self.conv = InConv(in_chls, out_chls)

    def forward(self, x, skip):
        x = self.up(x)

        # Pad if the upsampled tensor is smaller than the skip connection
        dx = skip.size(3) - x.size(3)
        dy = skip.size(2) - x.size(2)
        if dx or dy:
            x = F.pad(x, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])

        # Concatenate the skip connection along the channel dimension
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_chls, out_chls):
        super().__init__()
        # Final 1x1 convolution to establish logits
        self.conv = nn.Conv2d(in_chls, out_chls, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):

    def __init__(self, in_chls=1, num_classes=4, base_chls=64):
        super().__init__()

        # Channel widths double at each of the 5 resolution levels
        c = [base_chls * 2**i for i in range(5)]

        # Network architecture
        self.input = InConv(in_chls, base_chls)
        self.down1 = DownLayer(c[0], c[1])
        self.down2 = DownLayer(c[1], c[2])
        self.down3 = DownLayer(c[2], c[3])
        self.down4 = DownLayer(c[3], c[4])

        self.up1 = UpLayer(c[4], c[3])
        self.up2 = UpLayer(c[3], c[2])
        self.up3 = UpLayer(c[2], c[1])
        self.up4 = UpLayer(c[1], c[0])
        self.output = OutConv(c[0], num_classes)

    def forward(self, x):
        x1 = self.input(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Upsampling with skips from the downsampling path
        return self.output(
            self.up4(
                self.up3(
                    self.up2(
                        self.up1(x5, x4),
                        x3),
                    x2),
                x1)
        )


class MCDiceLoss(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, target):
        # Convert raw output to probabilities using the softmax function,
        # suited to predictions over several classes
        probs = logits.softmax(dim=1)
        # Batch size, num classes, height, width
        B, C, H, W = probs.shape

        # One-hot encode the target labels to match the prediction shape
        target_oh = F.one_hot(target, num_classes=C).permute(0, 3, 1, 2).float()

        probs_flat = probs.reshape(B, C, -1)
        target_flat = target_oh.reshape(B, C, -1)

        # The element-wise product yields the intersection
        intersect = (probs_flat * target_flat).sum(dim=-1)
        total = probs_flat.sum(dim=-1) + target_flat.sum(dim=-1)

        dice = (2 * intersect + self.eps) / (total + self.eps)

        return 1 - dice.mean()