Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
71a143b
Add MRI dataset loader with DataLoader, transforms and visualisation.
olivermccarthy-uq Sep 19, 2025
8bdde45
Add VAE model skeleton with encoder, decoder, and reparameterization
olivermccarthy-uq Sep 19, 2025
9108dbf
Add VAE training script with loss function and model saving
olivermccarthy-uq Sep 19, 2025
d08113d
Add VAE visualization script to generate and display images from late…
olivermccarthy-uq Sep 19, 2025
efec82e
Enhance VAE training script with validation, final reconstruction dis…
olivermccarthy-uq Sep 19, 2025
988341a
Add epoch vs loss graph and final reconstruction visualization to VAE…
olivermccarthy-uq Sep 19, 2025
4b134b8
Add OASIS Improved UNet initial files
Oct 29, 2025
020106a
Implemented basic OASIS dataset sorting and loading.
Oct 29, 2025
9bbc00d
Updated modules.py to feature a basic improved UNet in PyTorch.
Oct 29, 2025
cafbe1a
Set up train.py script for training with UNet model, data loader, los…
Oct 29, 2025
abd8f7f
Implemented predict.py to run inference with the trained model and vi…
Oct 29, 2025
e53bb3a
Add results image and also implemented a DataLoader for batching as t…
Oct 31, 2025
6de479f
Updated README.md to include required sections for the report.
Oct 31, 2025
f54fb23
Fixed directory structure formatting in the README.md
Oct 31, 2025
8318b37
Merge pull request #1 from olivermccarthy-uq/topic-recognition
olivermccarthy-uq Oct 31, 2025
c15a55e
Updated README after epoch increase to 30. Train.py updated for 30 ep…
Oct 31, 2025
39a5a7e
Merge pull request #2 from olivermccarthy-uq/topic-recognition
olivermccarthy-uq Oct 31, 2025
9e60583
Merge branch 'shakes76:main' into assignment-from-3710
olivermccarthy-uq Nov 29, 2025
5f855cb
Delete scripts/__pycache__ directory
olivermccarthy-uq Dec 2, 2025
3109ae9
Delete recognition/OASIS-ImprovedUNet/__pycache__ directory
olivermccarthy-uq Dec 2, 2025
a430009
Move OASIS-Improved-UNet into new directory
Dec 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
import glob
import imageio

def load_oasis_data(image_dir, label_dir):
    """Load paired OASIS image and segmentation slices into NumPy arrays.

    Files matching ``*.nii.png`` are collected from each directory and sorted
    by path, so the i-th image is paired with the i-th label purely by
    position in the sorted listings.

    Args:
        image_dir: Directory containing the image slices.
        label_dir: Directory containing the segmentation slices.

    Returns:
        tuple: ``(images, labels)``; each is the decoded files stacked along
        a new leading axis.

    Raises:
        ValueError: If no image files are found, or if the two directories
            contain a different number of matching files.
    """
    # Sort both listings so images and labels line up by position.
    image_files = sorted(glob.glob(f"{image_dir}/*.nii.png"))
    label_files = sorted(glob.glob(f"{label_dir}/*.nii.png"))

    if not image_files:
        raise ValueError(f"No '*.nii.png' images found in: {image_dir}")
    # Pairing is positional, so a count mismatch would silently misalign or
    # drop samples; fail loudly instead of truncating to the shorter list.
    if len(image_files) != len(label_files):
        raise ValueError(
            f"Found {len(image_files)} images in {image_dir} but "
            f"{len(label_files)} labels in {label_dir}"
        )

    images = np.stack([imageio.imread(path) for path in image_files], axis=0)
    labels = np.stack([imageio.imread(path) for path in label_files], axis=0)
    return images, labels
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
95 changes: 95 additions & 0 deletions OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages.

    Each stage is Conv2d -> BatchNorm2d -> ReLU -> Dropout. The first stage
    maps ``in_channels`` to ``mid_channels`` and the second maps
    ``mid_channels`` to ``out_channels``. Padding of 1 keeps the spatial
    size unchanged.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, dropout=0.2):
        super().__init__()
        # A falsy mid_channels (the default None) falls back to out_channels.
        hidden = mid_channels or out_channels

        stages = []
        for src, dst in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                nn.Conv2d(src, dst, kernel_size=3, padding=1),
                nn.BatchNorm2d(dst),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages."""
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels, dropout=0.2):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels, dropout=dropout)
        # Kept as a single Sequential so parameter names stay
        # "maxpool_conv.<index>...".
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Downsample ``x`` and apply the double convolution."""
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder step: upsample, merge with the skip connection, double conv.

    With ``bilinear=True`` the upsampling is a fixed bilinear interpolation
    and the DoubleConv uses ``in_channels // 2`` as its middle width.
    Otherwise a learned transposed convolution halves the channel count
    while doubling the spatial size.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, dropout=0.2):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, dropout=dropout)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half, dropout=dropout)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width, concatenate, convolve.

        Args:
            x1: Feature map coming up from the deeper layer.
            x2: Skip-connection feature map from the encoder.
        """
        upsampled = self.up(x1)

        # Size gap between the skip tensor and the upsampled tensor along
        # dims 2 and 3.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        top = gap_h // 2
        left = gap_w // 2

        # Pad list order is [left, right, top, bottom]; any odd remainder
        # goes to the right/bottom side.
        upsampled = nn.functional.pad(
            upsampled, [left, gap_w - left, top, gap_h - top]
        )

        # Skip features come first along the channel dimension.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        """Project ``x`` to ``out_channels`` without changing its spatial size."""
        projected = self.conv(x)
        return projected

class ImprovedUNet(nn.Module):
    """U-Net with four downsampling and four upsampling steps.

    The encoder widens channels 64 -> 128 -> 256 -> 512 -> 1024 (the last
    width is halved when ``bilinear`` is true). The decoder mirrors it and
    merges the matching encoder output at every step. The network returns
    raw logits; no activation is applied to the output.

    Args:
        n_channels: Number of channels in the input image.
        n_classes: Number of output channels produced by the final 1x1 conv.
        bilinear: Use bilinear upsampling instead of transposed convolutions.
        dropout: Dropout probability passed to every DoubleConv.
    """

    def __init__(self, n_channels, n_classes, bilinear=True, dropout=0.2):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder path.
        self.inc = DoubleConv(n_channels, 64, dropout=dropout)
        self.down1 = Down(64, 128, dropout=dropout)
        self.down2 = Down(128, 256, dropout=dropout)
        self.down3 = Down(256, 512, dropout=dropout)

        # Bilinear upsampling keeps the channel count, so the widths feeding
        # each Up block are halved to make the concatenated skip + upsampled
        # tensor match that block's in_channels.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor, dropout=dropout)

        # Decoder path.
        self.up1 = Up(1024, 512 // factor, bilinear, dropout=dropout)
        self.up2 = Up(512, 256 // factor, bilinear, dropout=dropout)
        self.up3 = Up(256, 128 // factor, bilinear, dropout=dropout)
        self.up4 = Up(128, 64, bilinear, dropout=dropout)

        # Output head.
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the logits for input ``x``."""
        # Encoder: keep every intermediate output for the skip connections.
        features = [self.inc(x)]
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            features.append(encoder(features[-1]))

        # Decoder: start at the deepest features and consume the skips from
        # deepest to shallowest.
        x = features.pop()
        for decoder in (self.up1, self.up2, self.up3, self.up4):
            x = decoder(x, features.pop())

        return self.outc(x)
54 changes: 54 additions & 0 deletions OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
import matplotlib.pyplot as plt
from modules import ImprovedUNet
from dataset import load_oasis_data
from torch.utils.data import DataLoader, TensorDataset

def dice_coefficient(pred, target, epsilon=1e-6):
    """Return the mean per-sample Dice score between two mask tensors.

    Both tensors are cast to integers with ``.int()`` (fractional values are
    truncated) and combined with a bitwise AND. Sums run over dimensions
    1, 2 and 3, so 4-D inputs are expected.

    Args:
        pred: Predicted mask tensor.
        target: Ground-truth mask tensor.
        epsilon: Added to numerator and denominator so two empty masks score
            1 instead of dividing by zero.

    Returns:
        float: Dice score averaged over the first dimension.
    """
    reduce_dims = (1, 2, 3)
    pred_mask = pred.int()
    target_mask = target.int()

    overlap = torch.bitwise_and(pred_mask, target_mask).sum(dim=reduce_dims)
    total = pred_mask.sum(dim=reduce_dims) + target_mask.sum(dim=reduce_dims)

    per_sample = (2 * overlap + epsilon) / (total + epsilon)
    return per_sample.mean().item()

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the network and restore the trained weights from the checkpoint.
model = ImprovedUNet(n_channels=1, n_classes=1).to(device)
model.load_state_dict(torch.load('unet_epoch15.pth', map_location=device))
model.eval()  # evaluation mode for the Dropout / BatchNorm layers

test_images, test_labels = load_oasis_data(
    '/home/groups/comp3710/OASIS/keras_png_slices_test',
    '/home/groups/comp3710/OASIS/keras_png_slices_seg_test'
)

# Insert a channel dimension at position 1 and scale pixel values by 1/255.
images_tensor = torch.tensor(test_images, dtype=torch.float32).unsqueeze(1)/255.0
# NOTE(review): dice_coefficient casts the target with .int(), so after this
# scaling only label pixels equal to 255 count as foreground -- confirm this
# matches how the model was trained.
labels_tensor = torch.tensor(test_labels, dtype=torch.float32).unsqueeze(1)/255.0

# Use DataLoader for batching; shuffle=False keeps predictions in the same
# order as labels_tensor.
test_ds = TensorDataset(images_tensor, labels_tensor)
test_dl = DataLoader(test_ds, batch_size=8, shuffle=False)
preds = []

with torch.no_grad():
    for xb, _ in test_dl:
        out = model(xb.to(device))
        # Sigmoid turns logits into probabilities; threshold at 0.5 for a
        # boolean mask.
        out = torch.sigmoid(out) > 0.5
        preds.append(out.cpu())
preds = torch.cat(preds, dim=0)

# Compute Dice coefficient
dice = dice_coefficient(preds, labels_tensor)
print(f"Dice coefficient on test set: {dice:.4f}")

# Visualize up to 3 sample predictions; the cap avoids an IndexError when the
# test set holds fewer than three slices.
num_samples = min(3, len(preds))
for i in range(num_samples):
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(test_images[i], cmap='gray')
    axs[0].set_title('Input')
    axs[1].imshow(test_labels[i], cmap='gray')
    axs[1].set_title('Ground Truth')
    axs[2].imshow(preds[i][0], cmap='gray')
    axs[2].set_title('Prediction')
    plt.savefig(f'prediction_{i}.png')
    plt.close(fig)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""
MRI Dataset Loading and Preprocessing Module

PURPOSE: This file handles the loading and preprocessing of MRI brain scan images for training a VAE.
It creates a custom PyTorch Dataset class that can load PNG images from directories, apply
transforms (resizing, normalization), and create DataLoaders for efficient batch processing.

WHY IT'S NEEDED:
- Raw MRI images need to be preprocessed (resized to consistent dimensions, normalized)
- PyTorch requires a custom Dataset class to work with our file structure
- DataLoaders enable efficient batch processing during training
- Visualisation helps verify the data is loaded correctly
"""

import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt

# -----------------------------
# Custom Dataset Class for MRI Images
# -----------------------------
class MRIDataset(Dataset):
    """
    PyTorch Dataset that serves the PNG images stored in one folder.

    File paths are collected once at construction time. Each image is opened
    on access, converted to single-channel grayscale and passed through the
    optional transform.
    """

    def __init__(self, folder, transform=None):
        """
        Collect the PNG files in ``folder``.

        Args:
            folder (str): Path to directory containing PNG images
            transform (callable, optional): Transform to apply to each image

        Raises:
            ValueError: If the folder contains no file ending in ".png"
        """
        self.folder = folder
        self.transform = transform

        # Sorted full paths give a stable index -> file mapping.
        png_paths = []
        for entry in os.listdir(folder):
            if entry.endswith(".png"):
                png_paths.append(os.path.join(folder, entry))
        png_paths.sort()
        self.images = png_paths

        # An empty dataset is almost certainly a wrong path, so fail early.
        if not self.images:
            raise ValueError(f"No PNG images found in folder: {folder}")

    def __len__(self):
        """Return how many images were found in the folder."""
        return len(self.images)

    def __getitem__(self, idx):
        """
        Load the image stored at position ``idx``.

        Args:
            idx (int): Index of the image to load

        Returns:
            The grayscale image after ``transform`` has been applied, or the
            untransformed PIL image when no transform was given.
        """
        path = self.images[idx]
        grayscale = Image.open(path).convert('L')  # 'L' mode = 8-bit grayscale
        return self.transform(grayscale) if self.transform else grayscale

# -----------------------------
# Data Path Configuration
# -----------------------------
# Get the directory where this script is located
# Absolute path of the folder that holds this script
script_dir = os.path.dirname(os.path.abspath(__file__))

# The data lives two directory levels above the script folder
data_base = os.path.abspath(
    os.path.join(script_dir, "..", "..", "keras_png_slices_data")
)

# One sub-folder per split
train_folder = os.path.join(data_base, "keras_png_slices_train")
validate_folder = os.path.join(data_base, "keras_png_slices_validate")
test_folder = os.path.join(data_base, "keras_png_slices_test")

# Quick sanity check that the training path resolves to a real location
print("Train folder:", train_folder)
print("Exists?", os.path.exists(train_folder))

# -----------------------------
# Image Preprocessing Pipeline
# -----------------------------
# Steps run in order on every image returned by the datasets below
transform = transforms.Compose([
    transforms.Resize(size=(64, 64)),                    # fixed 64x64 resolution
    transforms.ToTensor(),                               # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=(0.5,), std=(0.5,)),       # (x - 0.5) / 0.5 -> range [-1, 1]
])

# -----------------------------
# Dataset Creation
# -----------------------------
# Every split shares the same preprocessing pipeline
train_dataset = MRIDataset(train_folder, transform=transform)
val_dataset = MRIDataset(validate_folder, transform=transform)
test_dataset = MRIDataset(test_folder, transform=transform)

# Report how many samples each split contains
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# -----------------------------
# Data Visualization
# -----------------------------
# Show the first four training images in a 2x2 grid as a loading check
plt.figure(figsize=(8, 8))
for i in range(4):
    plt.subplot(2, 2, i + 1)
    # Drop the channel dimension for display: (1, 64, 64) -> (64, 64)
    sample = train_dataset[i].squeeze()
    plt.imshow(sample, cmap='gray')
    plt.axis('off')
plt.show()

# -----------------------------
# DataLoader Creation
# -----------------------------
# Batches of 16; only the training split is shuffled
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)
Loading