92 commits
f7853da
Initial commit
shakes76 Sep 21, 2025
ba3550f
Updated text in README
shakes76 Sep 21, 2025
6d8b074
Added recognition branch for tasks
shakes76 Sep 21, 2025
4b689d2
init folders
CleoBriez Oct 15, 2025
0d79af7
init folders
CleoBriez Oct 15, 2025
b8f3812
added header blocks to each file and initialised the README
CleoBriez Nov 2, 2025
c4ee232
added header blocks to each file and initialised the README
CleoBriez Nov 2, 2025
b65f142
added sample load 2D Data code
CleoBriez Nov 2, 2025
13266c1
added sample load 2D Data code
CleoBriez Nov 2, 2025
4943208
updated sample code formatting
CleoBriez Nov 3, 2025
e5c345f
updated sample code formatting
CleoBriez Nov 3, 2025
59c6b32
added data locally for testing
CleoBriez Nov 3, 2025
23d6bef
added data locally for testing
CleoBriez Nov 3, 2025
bf707ca
realised adding the data locally was dumb, removed now
CleoBriez Nov 3, 2025
83fa8cb
realised adding the data locally was dumb, removed now
CleoBriez Nov 3, 2025
049c5fd
loaded dataset into model
CleoBriez Nov 3, 2025
ae9e975
loaded dataset into model
CleoBriez Nov 3, 2025
7b87e56
ghost commit to remove changes
CleoBriez Nov 3, 2025
e458267
ghost commit to remove changes
CleoBriez Nov 3, 2025
88fe28b
Added dataset class to dataset.py
CleoBriez Nov 3, 2025
786e96d
Added dataset class to dataset.py
CleoBriez Nov 3, 2025
4ae95fa
Refined dataset class to return slices properly
CleoBriez Nov 3, 2025
462f63d
Refined dataset class to return slices properly
CleoBriez Nov 3, 2025
15944bb
Dataset draft complete, starting on modules
CleoBriez Nov 3, 2025
95d5b1a
Dataset draft complete, starting on modules
CleoBriez Nov 3, 2025
031838c
Cleaning up emoty files
CleoBriez Nov 3, 2025
a7d9fa5
Cleaning up emoty files
CleoBriez Nov 3, 2025
5805f9c
Save to restart for cuda install
CleoBriez Nov 4, 2025
7ea1d24
Save to restart for cuda install
CleoBriez Nov 4, 2025
20b7c36
Initial train and predict and module update
CleoBriez Nov 4, 2025
e5f0d67
Initial train and predict and module update
CleoBriez Nov 4, 2025
ddaa356
Training and prediction *refactor*
CleoBriez Nov 4, 2025
8417fa2
Training and prediction *refactor*
CleoBriez Nov 4, 2025
3df6c75
Added dependencies
CleoBriez Nov 4, 2025
7ad9af3
Added dependencies
CleoBriez Nov 4, 2025
375d937
Fixed indentation issue
CleoBriez Nov 4, 2025
f1f91a0
Fixed indentation issue
CleoBriez Nov 4, 2025
ef466dd
Added util files for visualisations
CleoBriez Nov 5, 2025
2c3a8ee
Added util files for visualisations
CleoBriez Nov 5, 2025
bd2b070
Visualisation debugging
CleoBriez Nov 5, 2025
42a84c1
Visualisation debugging
CleoBriez Nov 5, 2025
d5f6130
Path update
CleoBriez Nov 5, 2025
dda5f78
Path update
CleoBriez Nov 5, 2025
f937421
Updated testing
CleoBriez Nov 5, 2025
1d9ed17
Updated testing
CleoBriez Nov 5, 2025
e4f80e4
Dataset updated to adapt to more classes as needed
CleoBriez Nov 6, 2025
da23250
Dataset updated to adapt to more classes as needed
CleoBriez Nov 6, 2025
f7ff621
Dataset accounts for mismatches nifti files
CleoBriez Nov 6, 2025
501a33f
Dataset accounts for mismatches nifti files
CleoBriez Nov 6, 2025
da4e186
Fixed first_n issue not setting num properly
CleoBriez Nov 6, 2025
4050dbb
Fixed first_n issue not setting num properly
CleoBriez Nov 6, 2025
4ac676d
Added Attention Games to module
CleoBriez Nov 6, 2025
32f81bf
Added Attention Games to module
CleoBriez Nov 6, 2025
aba195b
Developed training script WiP
CleoBriez Nov 6, 2025
92e69ce
Developed training script WiP
CleoBriez Nov 6, 2025
3969f21
Training script WiP
CleoBriez Nov 6, 2025
4c51d8c
Training script WiP
CleoBriez Nov 6, 2025
1217a54
Fixed errors in the train script
CleoBriez Nov 6, 2025
6563a57
Fixed errors in the train script
CleoBriez Nov 6, 2025
8033925
Fixed structure across docs and added comments for dataset.py
CleoBriez Nov 6, 2025
d6a1a5a
Fixed structure across docs and added comments for dataset.py
CleoBriez Nov 6, 2025
69aa0b4
Added comments to train.py
CleoBriez Nov 6, 2025
29ad993
Added comments to train.py
CleoBriez Nov 6, 2025
2a4a31f
Updated utils.pu docstrings (I remember what they're called now)
CleoBriez Nov 6, 2025
c1cf65c
Updated utils.pu docstrings (I remember what they're called now)
CleoBriez Nov 6, 2025
0cb9b81
Cleaned up docsctrongs and imports
CleoBriez Nov 6, 2025
8a04942
Cleaned up docsctrongs and imports
CleoBriez Nov 6, 2025
ad67aad
Consistency pass
CleoBriez Nov 6, 2025
3a77454
Consistency pass
CleoBriez Nov 6, 2025
0867f13
Comments updated and predict WiP
CleoBriez Nov 6, 2025
35fc6dd
Comments updated and predict WiP
CleoBriez Nov 6, 2025
fd46ccc
Added gitignore for pycache files
CleoBriez Nov 6, 2025
68f9ee7
Added gitignore for pycache files
CleoBriez Nov 6, 2025
fbefbc8
Added mean and std dev for the dataset for the predict
CleoBriez Nov 6, 2025
b3ac6b7
Added mean and std dev for the dataset for the predict
CleoBriez Nov 6, 2025
9e09689
I think I die here
CleoBriez Nov 6, 2025
02cb89c
I think I die here
CleoBriez Nov 6, 2025
1223141
Unified testing across the project
CleoBriez Nov 7, 2025
7dbf81d
Unified testing across the project
CleoBriez Nov 7, 2025
ef630df
Updated dataset DataLoader and enabled saving of sufficient models
CleoBriez Nov 7, 2025
ad9609d
Updated dataset DataLoader and enabled saving of sufficient models
CleoBriez Nov 7, 2025
2b672f5
Added space in training
CleoBriez Nov 7, 2025
d9cd130
Added space in training
CleoBriez Nov 7, 2025
236cd93
Developed predict and finalised training
CleoBriez Nov 7, 2025
16b54e7
Developed predict and finalised training
CleoBriez Nov 7, 2025
c1488c5
Saved model and pred visualisations
CleoBriez Nov 7, 2025
f02de06
Updated pred visualisation
CleoBriez Nov 7, 2025
22af77b
Updated versions and readme WiP
CleoBriez Nov 7, 2025
5af130f
Deleted pycache and vscode files
CleoBriez Nov 7, 2025
888f835
Updated readme for submission
CleoBriez Nov 7, 2025
fbc5636
removed model for upload
CleoBriez Nov 22, 2025
f797587
Merge branch 'topic-recognition' of https://github.com/CleoBriez/3710…
CleoBriez Nov 22, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/
92 changes: 92 additions & 0 deletions recognition/43711451_HipMRI2D_AttentionUNET/README.MD
@@ -0,0 +1,92 @@
# Title
2D HipMRI Dataset Segmented using Attention U-Net

# Problem Description
Segment the HipMRI Study on Prostate Cancer (see Appendix for link) using the processed 2D slices with a 2D Attention U-Net, with all labels achieving a minimum Dice similarity coefficient of 0.75 on the test set for the prostate label. The data is supplied in the NIfTI file format and is loaded by adapting the provided sample code.
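
For reference, the Dice similarity coefficient between a predicted mask P and a ground-truth mask G is DSC = 2|P ∩ G| / (|P| + |G|). A minimal sketch of how it can be computed for binary NumPy masks (the function name is illustrative, not part of this repository):

```python
import numpy as np

def dice_coefficient(pred: np.ndarray, truth: np.ndarray, eps: float = 1e-8) -> float:
    """DSC = 2 * |pred AND truth| / (|pred| + |truth|) for binary masks."""
    pred = pred.astype(bool)
    truth = truth.astype(bool)
    intersection = np.logical_and(pred, truth).sum()
    return (2.0 * intersection) / (pred.sum() + truth.sum() + eps)
```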

# Algorithm Description
## Data Preparation
Data was pre-sorted into training, validation, and testing sets, each with corresponding segmentation masks.
Each slice was wrapped as if it were 3D (rebuilt with shape (H, W, 1)) to standardise the inputs, since the NIfTI files mix 2D and 3D shapes; the masks were one-hot encoded as they are multiclass rather than binary.
After wrapping, the training data was shuffled, the images were normalised, and both were put into a Dataset and DataLoader for training and validation.

## Model Architecture
### Filters
The encoder uses the filter sizes [64, 128, 256, 512, 1024]; the decoder mirrors them in reverse as [1024, 512, 256, 128, 64].
### Downsampling (Encoder)
The encoder passes the input through a double convolution at each filter size, saving each stage's output as a skip connection before downsampling.
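
As a rough sketch of this double-convolution building block (assuming standard PyTorch; the class name and use of batch normalisation are illustrative, not necessarily what modules.py does):

```python
import torch.nn as nn

FILTERS = [64, 128, 256, 512, 1024]  # encoder channels; the decoder mirrors them in reverse

class DoubleConv(nn.Module):
    """Two 3x3 conv -> BatchNorm -> ReLU blocks, applied at each encoder stage."""
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)
```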

### Upsampling (Decoder)
The decoder mirrors the encoder: it upsamples the bottleneck output and, at each stage, combines it with the corresponding skip connection before applying a double convolution.

### Bottleneck
The bottleneck applies the double convolution with 3x3 kernels at the deepest filter size.
Its output is then upsampled back through the decoder.

### Attention Gates
Attention gates are what make this an Attention U-Net: alongside the regular skip connections, each gate maps semantic context from the decoder onto spatial detail from the encoder, producing an attention map that re-weights the skip features before they are passed on.
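
A minimal sketch of the standard additive attention gate (following the Oktay et al. Attention U-Net formulation; the names, and the assumption that the gating signal has already been resized to the skip connection's spatial size, are illustrative):

```python
import torch.nn as nn

class AttentionGate(nn.Module):
    """Re-weights encoder skip features using a decoder gating signal."""
    def __init__(self, gate_ch: int, skip_ch: int, inter_ch: int):
        super().__init__()
        self.w_gate = nn.Sequential(nn.Conv2d(gate_ch, inter_ch, kernel_size=1), nn.BatchNorm2d(inter_ch))
        self.w_skip = nn.Sequential(nn.Conv2d(skip_ch, inter_ch, kernel_size=1), nn.BatchNorm2d(inter_ch))
        self.psi = nn.Sequential(nn.Conv2d(inter_ch, 1, kernel_size=1), nn.BatchNorm2d(1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip):
        # gate: semantic context from the decoder; skip: spatial detail from the encoder
        # (both assumed to share the same spatial size here)
        attention = self.psi(self.relu(self.w_gate(gate) + self.w_skip(skip)))  # (N, 1, H, W)
        return skip * attention  # suppress irrelevant skip features before concatenation
```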

## Model Training
The model was trained using all training samples and validated using the validation samples.
A random seed was also applied to keep runs reproducible while trying to make the training generalise rather than just memorise the training data.
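
A minimal seeding sketch (the seed value and where it is set are illustrative; the project keeps its hyperparameters at the top of the scripts):

```python
import random
import numpy as np
import torch

SEED = 42  # illustrative value

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # no-op when CUDA is unavailable
```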

## Model Performance
The model was run and saved using batches of size 16 and initially 25 epochs, using the full training and validation samples.
I actually got all 25 epochs out of it, but messed up my save location so the model didn't save. I really would've liked to refine that :c.
I was running on an Nvidia RTX 3060 Ti with 8GB of VRAM.
After ~20 epochs the model was overfitting; this could have been optimised better by experimenting with smaller subsets and increased epochs.

## Model Testing
After training, the model was tested using the predict function, which writes the predicted masks and their visualisations out to the mask_output folder.
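
A minimal sketch of that prediction step (assuming a trained `model`, a `(1, 1, H, W)` input tensor, and the slice's affine matrix; the function name and signature are illustrative, not predict.py's actual API):

```python
import nibabel as nib
import numpy as np
import torch

@torch.no_grad()
def save_prediction(model, image, affine, out_path):
    """Run the model on one slice and save the argmax label mask as NIfTI."""
    model.eval()
    logits = model(image)                   # (1, C, H, W) class scores
    mask = logits.argmax(dim=1).squeeze(0)  # (H, W) integer labels
    nib.save(nib.Nifti1Image(mask.cpu().numpy().astype(np.uint8), affine), out_path)
```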

# How it Works
## File Structure
```
.
├── recognition
│   └── 43711451_HipMRI2D_AttentionUNET
│       ├── mask_output
│       │   ├── pred_mask_case_040_week_0_slice_0.nii.gz
│       │   ├── pred_mask_case_040_week_0_slice_0.nii.png
│       │   ├── pred_mask_case_040_week_1_slice_0.nii.gz
│       │   ├── pred_mask_case_040_week_1_slice_0.nii.png
│       │   ├── pred_mask_case_040_week_2_slice_0.nii.gz
│       │   └── pred_mask_case_040_week_2_slice_0.nii.png
│       ├── saved_models
│       │   └── full_set_5_epochs.pth
│       ├── dataset.py
│       ├── modules.py
│       ├── predict.py
│       ├── README.MD
│       ├── train.py
│       └── utils.py
├── LICENSE
└── README.md
```

## Requirements & Dependencies
- PyTorch (for neural network layers, datasets/DataLoaders, and tensor operations)
- NumPy (for numerical operations)
- Matplotlib (for plotting)
- NiBabel (for neuroimaging data handling)
- NiLearn (for image resampling)
- tqdm (for progress bars)
- pathlib (standard library; for filesystem path manipulation)
- random (standard library; for random seeding)

## Future Improvements
I honestly would really like to adapt this towards 3D datasets and to visualise the results in nicer/more fun ways.
I would also like to clean up the hard-coded parameters and make the project runnable from one central place.
I would've liked to add functionality to the predict section to compare the generated masks to the provided ones more clearly, for example by analysing which regions were segmented accurately.
Honestly, a lot of the predict section was really exciting, but I didn't expand on it as much as I would've liked to.
Tbh this doc too ;~;

## Usage & Reproduction Steps
For training:
```
python3 train.py
```
For predictions and visualisations:
```
python3 predict.py
```

Hyperparameters are defined at the top of both scripts to allow customisation of how they run.
Set the dataset location in the `path` variable in dataset.py too.
The `SUBSET` parameter controls how many samples are loaded (0 = all the data).

# Visualisations
Output prediction visualisations are saved in the `recognition/43711451_HipMRI2D_AttentionUNET/mask_output` folder.
274 changes: 274 additions & 0 deletions recognition/43711451_HipMRI2D_AttentionUNET/dataset.py
@@ -0,0 +1,274 @@
# recognition\43711451_HipMRI2D_AttentionUNET\dataset.py
"""
Contains the data loader and preprocessing for the HipMRI 2D Slice Dataset to be used by the model
"""

import numpy as np
import nibabel as nib
from nibabel import Nifti1Image
from nilearn.image import resample_to_img
from tqdm import tqdm
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader

__author__ = "Cleodora Kizmann"
__copyright__ = "Copyright 2025, Cleodora Kizmann"
__credits__ = ["Cleodora Kizmann"]
__license__ = "Apache License 2.0"
__version__ = "1.0.1"
__maintainer__ = "Cleodora Kizmann"
__email__ = "[email protected]"
__status__ = "Prototype"

# Dataset path
path = "D:/keras_slices_data/keras_slices_" # Adjust this path as needed

# Hyperparameters
BATCH_SIZE = 16 # I got 8GB VRAM on my GPU so I might be pushing this a little
SUBSET = 25

def to_channels(arr: np.ndarray, num_classes: int, dtype=np.uint8) -> np.ndarray:
    """
    Converts an integer label array into a one-hot encoded array.

    Args:
        arr: The input 2D mask array (H, W).
        num_classes: The total number of classes.
        dtype: The data type of the output array.

    Returns:
        A one-hot encoded array of shape (H, W, num_classes).
    """
    res = np.zeros(arr.shape + (num_classes,), dtype=dtype)

    for c in range(num_classes):
        # Set channel 'c' to 1 where the input array has label 'c'
        res[..., c] = (arr == c)
    return res

def standardise(img_path):
    """
    Helper: if the file is 2D with shape (H, W), rebuild it as 3D with shape (H, W, 1).

    Args:
        img_path: Path to the NIfTI image file.
    Returns:
        A Nifti1Image object with standardised dimensions.
    """
    nii = nib.load(img_path)

    if len(nii.shape) == 2:
        # Data is 2D (H, W). We need to make it 3D (H, W, 1).
        data_2d = nii.get_fdata(caching="unchanged")  # Shape (H, W)
        data_3d = np.expand_dims(data_2d, axis=-1)    # Shape (H, W, 1)

        # Re-create the NIfTI object with the new 3D data
        new_nii = Nifti1Image(data_3d, nii.affine, nii.header)

        # Manually update the header to reflect the 3D shape
        new_nii.header.set_data_shape(data_3d.shape)
        return new_nii

    # Already 3D (or higher): return unchanged
    return nii

# load medical image functions
def load_data_2D(imageNames, normalise=False, categorical=False, num_classes=None, dtype=np.float32, getAffines=False, first_n=0):
    """
    Load medical image data from the list of file names provided into a single array.
    Altered to account for slices being different sizes, by resampling to a template image.

    Args:
        imageNames: list of paths to NIfTI image files
        normalise: bool (z-score normalise each image to zero mean, unit variance)
        categorical: bool (If True, 'num_classes' must also be provided)
        num_classes: int (The total number of classes for one-hot encoding, e.g., 6)
        getAffines: bool (Return the affine matrices along with the images)
        first_n: int (Stop loading after n images for quick loading and testing scripts)

    Returns:
        images: np.ndarray of shape (N, H, W) or (N, H, W, C) depending on 'categorical'
        affines: list of affine matrices (if getAffines is True)
    """
    # Validate mask and classes inputs
    if categorical and num_classes is None:
        raise ValueError("You should specify the number of classes when loading categorical mask data.")

    affines = []  # Spatial coordinates list

    # Load a template image to get dimensions
    try:
        template_nifti = standardise(imageNames[0])
    except Exception as e:
        print(f"Error loading template image: {imageNames[0]}. {e}")
        return

    num = len(imageNames) if first_n == 0 else first_n

    first_case = template_nifti.get_fdata(caching="unchanged")

    if len(first_case.shape) == 3:
        first_case = first_case[:, :, 0]  # sometimes extra dims, remove to keep 2D slice

    rows, cols = first_case.shape
    if categorical:
        images = np.zeros((num, rows, cols, num_classes), dtype=dtype)
        interpolation = "nearest"  # Preserve integer labels
    else:
        images = np.zeros((num, rows, cols), dtype=dtype)
        interpolation = "linear"  # Average pixels for a smooth image

    for i, inName in enumerate(tqdm(imageNames[:num])):  # slicing to num already enforces first_n
        niftiImage = standardise(inName)  # Load the image
        # Resample the NIfTI to match the template
        resampled_nifti = resample_to_img(
            niftiImage,
            template_nifti,
            interpolation=interpolation,
            # Suppressing annoying warnings
            force_resample=True,
            copy_header=True
        )
        # Get data from the *resampled* image
        inImage = resampled_nifti.get_fdata(caching="unchanged")  # read disk only
        affine = resampled_nifti.affine
        if len(inImage.shape) == 3:
            inImage = inImage[:, :, 0]  # sometimes extra dims in HipMRI_study data
        inImage = inImage.astype(dtype)

        if normalise and not categorical:
            # z-score normalisation: zero mean, unit variance
            inImage = (inImage - inImage.mean()) / inImage.std()
        elif normalise and categorical:
            raise ValueError("You probably didn't mean to normalise categorical mask data.")

        if categorical:
            inImage = to_channels(inImage, num_classes=num_classes, dtype=dtype)
            images[i, :, :, :] = inImage
        else:
            images[i, :, :] = inImage
        affines.append(affine)

    if getAffines:
        return images, affines
    else:
        return images

class HipMRI2D(Dataset):
    """
    Dataset class for segmentation of the HipMRI 2D dataset.

    Args:
        dataset: str, one of "train", "validate", "test" to specify which dataset to load.
        first_n: int, number of samples to load for quick testing (default: 0, load all).
    """
    def __init__(self, dataset, first_n=0):
        """
        Initialize the HipMRI2D dataset.

        Args:
            dataset: str, one of "train", "validate", "test" to specify which dataset to load.
            first_n: int, number of samples to load for quick testing (default: 0, load all).
        """
        self.dataset = load_data_2D(sorted(Path(path + dataset).glob("*.gz")), normalise=True, categorical=False, first_n=first_n)  # Shape (N, H, W)
        self.mask_data = load_data_2D(sorted(Path(path + "seg_" + dataset).glob("*.gz")), normalise=False, categorical=True, num_classes=6, first_n=first_n)  # Shape (N, H, W, C)
        self.num_classes = self.mask_data.shape[-1]

        print(f"Image array shape: {self.dataset.shape}")  # e.g., (100, 256, 128)
        print(f"Mask array shape: {self.mask_data.shape}")  # e.g., (100, 256, 128, 6)

    def __len__(self):
        """
        Returns the total number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Retrieve the image and corresponding mask at the specified index.

        Args:
            index: Index of the sample to retrieve.

        Returns:
            A tuple (image, mask) where:
                - image is the preprocessed image tensor of shape (1, H, W).
                - mask is the one-hot encoded multiclass mask tensor of shape (C, H, W).
        """
        # Fetch the image and mask arrays at this index
        image_np = self.dataset[index]   # Shape (H, W)
        mask_np = self.mask_data[index]  # Shape (H, W, C)

        image_tensor = torch.from_numpy(image_np).float()
        mask_tensor = torch.from_numpy(mask_np).float()

        image_tensor = image_tensor.unsqueeze(0)  # (H, W) -> (1, H, W)

        # Permute the mask from channels-last to channels-first because PyTorch expects (C, H, W)
        mask_tensor = mask_tensor.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)

        return image_tensor, mask_tensor

    def get_mean(self):
        """Mean intensity across the whole image set."""
        return np.mean(self.dataset)

    def get_std(self):
        """Standard deviation of intensity across the whole image set."""
        return np.std(self.dataset)

class LoadData(DataLoader):
    """
    Custom DataLoader for the HipMRI2D dataset.
    Inherits from torch.utils.data.DataLoader.
    """
    def __init__(self, dataset, first_n=SUBSET, batch_size=BATCH_SIZE, shuffle=True):
        """
        Initialize the DataLoader.

        Args:
            dataset: str, one of "train", "validate", "test"; the split to wrap in a HipMRI2D dataset.
            first_n: Number of samples to load (default: SUBSET; 0 = all).
            batch_size: Number of samples per batch to load (default: 16).
            shuffle: Whether to shuffle the data at every epoch (default: True).
        """
        super().__init__(dataset=HipMRI2D(dataset, first_n=first_n), batch_size=batch_size, shuffle=shuffle)

if __name__ == "__main__":
    print("💛 Loading training data 💛")
    LoadData(dataset="train", first_n=SUBSET, batch_size=BATCH_SIZE, shuffle=True)
    print("💚 Training data loading complete 💚")

    print("💛 Loading validation data 💛")
    LoadData(dataset="validate", first_n=SUBSET, batch_size=BATCH_SIZE, shuffle=False)
    print("💚 Validation data loading complete 💚")