Changes from all commits (27 commits)
e256674
First commit with readme and required python files
baulch-m Oct 26, 2025
4662353
added a working data loader for the OASIS dataset
baulch-m Oct 26, 2025
c4d7339
added a rough convolution component for unet model
baulch-m Oct 26, 2025
517f486
Updated readme and added gitignore
baulch-m Oct 27, 2025
4164be2
re-formatted modules.py to include encoding and decoding and added a …
baulch-m Oct 27, 2025
328ff5b
added a trainer for the unet model
baulch-m Oct 27, 2025
be0b61b
added hipmri study to gitignore
baulch-m Oct 30, 2025
8bfbe7b
changing task to using 2D UNet for HipMRI Prostate Cancer study - made
baulch-m Oct 30, 2025
fc65488
added segmentation masks to the loader
baulch-m Oct 30, 2025
a9edd13
converted hipmri data into pytorch dataloaders for compatibility
baulch-m Oct 30, 2025
7849a5d
cached loaded datasets and changed train.py to work with hipmri datas…
baulch-m Oct 30, 2025
693782d
deleted diceloss class, added residual blocks in module.py
baulch-m Nov 2, 2025
99257e2
reconfigured dataset.py to account for HIPMRI folder, deleted old OAS…
baulch-m Nov 2, 2025
2c63713
added plotting to train.py and reconfigured for dice loss function
baulch-m Nov 2, 2025
06070ee
removed dice loss calculations and obsolete DoubleConv class
baulch-m Nov 2, 2025
afc2007
added prediction functionality (model evaluation)
baulch-m Nov 2, 2025
3f60a55
removed DoubleConv class for real this time
baulch-m Nov 2, 2025
72cdd1f
updated readme and output images (training and dice loss and best model)
baulch-m Nov 2, 2025
79c9e1f
added some output predictions to /outputs for visualisation examples
baulch-m Nov 3, 2025
24ab356
Added functionality for prediction visualisation
baulch-m Nov 3, 2025
b9b7262
Refactored some class names in dataset.py
baulch-m Nov 3, 2025
352390f
updated README
baulch-m Nov 3, 2025
71dacab
finalised module and training architectures
baulch-m Nov 3, 2025
b9bc413
updated README
baulch-m Nov 3, 2025
0172757
updated gitignore
baulch-m Nov 3, 2025
b866cae
test commit to get pr to work
baulch-m Nov 3, 2025
fa9f31e
dummy commit
baulch-m Nov 3, 2025
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
keras_png_slices_data/
HipMRI_Study_open/
__pycache__/
308 changes: 293 additions & 15 deletions README.md
@@ -1,20 +1,298 @@
# HipMRI_Study Segmentation with Improved U-Net (Task 3)

## Author

Marcus Baulch (47445464)
COMP3710 - Pattern Recognition and Analysis
The University of Queensland

## Overview

This project implements an Improved U-Net for multi-class semantic segmentation of prostate MRI images. The model segments anatomical structures in 2D MRI slices into 6 distinct classes, achieving strong performance through residual connections, batch normalisation, and a combined loss function.

### Key Features

- **Residual U-Net Architecture**: Enhanced U-Net with ResNet-style skip connections within encoder/decoder blocks
- **Multi-Class Segmentation**: 6-class semantic segmentation
- **Combined Loss Function**: 60% Dice Loss + 40% Cross-Entropy for balanced optimisation
- **Data Augmentation**: Random flips and rotations during training
- **Comprehensive Evaluation**: Dice coefficient metrics with visualisation capabilities

---

## Dataset and Preprocessing

### Hip MRI Study Dataset

The project uses the HipMRI Study Open Dataset, which contains MRI scans with semantic labels for male pelvises. The data was retrieved from https://data.csiro.au/collection/csiro:51392v2?redirected=true (see reference at the end).

**Dataset Structure:**
```
HipMRI_Study_open/
├── keras_slices_data/
│ ├── keras_slices_train/ # Training images
│ ├── keras_slices_seg_train/ # Training masks
│ ├── keras_slices_validate/ # Validation images
│ ├── keras_slices_seg_validate/ # Validation masks
│ ├── keras_slices_test/ # Test images
│ └── keras_slices_seg_test/ # Test masks
└── semantic_labels_only/ # Original 3D NIfTI files
```

### Preprocessing

1. **Image Loading**: NIfTI (.nii.gz) files loaded with `nibabel`
2. **Normalisation**: Images standardised using z-score normalisation: `(x - mean) / std` [1]
3. **One-Hot Encoding**: Masks converted to 6-channel one-hot format `[B, 6, H, W]`
4. **Data Augmentation** (training only; see the sketch after this list) [2]:
- Random horizontal/vertical flips (50% probability each)
- Random rotation of ±15 degrees
- Geometric transforms applied consistently to image-mask pairs
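
A minimal sketch of steps 1-4, assuming 2D single-slice `.nii.gz` files. The function names here are hypothetical, not the repository's actual API:

```python
import nibabel as nib
import numpy as np
import torch
import torchvision.transforms.functional as TF

NUM_CLASSES = 6

def load_pair(image_path, mask_path):
    """Load a 2D NIfTI slice and its mask, then normalise and one-hot encode."""
    image = nib.load(image_path).get_fdata().astype(np.float32)
    mask = nib.load(mask_path).get_fdata().astype(np.int64)

    # Z-score normalisation: (x - mean) / std
    image = (image - image.mean()) / (image.std() + 1e-8)

    image = torch.from_numpy(image).unsqueeze(0)   # [1, H, W]
    # One-hot encode the mask: [H, W] with values 0..5 -> [6, H, W]
    one_hot = torch.nn.functional.one_hot(torch.from_numpy(mask), NUM_CLASSES)
    one_hot = one_hot.permute(2, 0, 1).float()
    return image, one_hot

def augment(image, one_hot):
    """Random flips and a +/-15 degree rotation, applied to image and mask together."""
    if torch.rand(1) < 0.5:
        image, one_hot = image.flip(-1), one_hot.flip(-1)   # horizontal flip
    if torch.rand(1) < 0.5:
        image, one_hot = image.flip(-2), one_hot.flip(-2)   # vertical flip
    angle = float(torch.empty(1).uniform_(-15, 15))
    # rotate defaults to nearest-neighbour interpolation, which keeps masks discrete
    return TF.rotate(image, angle), TF.rotate(one_hot, angle)
```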

### Train/Validation/Test Split

The dataset uses a predefined split provided by the HipMRI Study Open Dataset:
- **Training set**: 11,460 images
- **Validation set**: 660 images
- **Test set**: 540 images

**Justification**:
- The dataset came pre-split, so no manual splitting was required
- The resulting ~90/5/5 split is standard for medical imaging datasets
- The training set is large enough to learn robust features
- The validation set (660 samples) is sufficient for reliable model selection between epochs
- The test set (540 samples) provides a statistically meaningful evaluation


## Model Architecture

### Residual U-Net

The model improves upon standard U-Net with residual blocks and batch normalisation:

```
Input (1 channel, grayscale MRI)

Encoder Path (with residual blocks):
ResBlock(1→64) → MaxPool
ResBlock(64→128) → MaxPool
ResBlock(128→256) → MaxPool
ResBlock(256→512) → MaxPool
ResBlock(512→1024) [Bottleneck]

Decoder Path (with skip connections):
UpConv + Concat → ResBlock(1024→512)
UpConv + Concat → ResBlock(512→256)
UpConv + Concat → ResBlock(256→128)
UpConv + Concat → ResBlock(128→64)

Final Conv(64→6)

Output (6 channels, class logits)
```

### Residual Block Details

Each ResidualBlock consists of:
```
Input
├─ Conv3x3 → BatchNorm → ReLU → Conv3x3 → BatchNorm → (+)
└─ [1x1 Conv if channels mismatch] ────────────────────→ ReLU → Output
```
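
A minimal PyTorch sketch of this block, read straight off the diagram (illustrative, not necessarily the exact code in `modules.py`):

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN, plus an identity/1x1 shortcut."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 conv on the shortcut only when the channel counts differ
        self.shortcut = (nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
                         if in_channels != out_channels else nn.Identity())

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.shortcut(x))  # add the skip, then the final ReLU
```

The first encoder stage in the architecture above would then be, e.g., `ResidualBlock(1, 64)`.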


## Training

### Configuration

| Parameter | Value |
|-----------|-------|
| **Batch Size** | 16 |
| **Epochs** | 20 |
| **Learning Rate** | 1e-4 |
| **Optimiser** | Adam |
| **Loss Function** | 60% Dice + 40% CrossEntropy |
| **Device** | CUDA (if available) / CPU |
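
Expressed in code, this configuration amounts to the following (an illustrative sketch; the constant names are not taken from `train.py`):

```python
import torch

# Hyperparameters from the table above (names are illustrative)
BATCH_SIZE = 16
EPOCHS = 20
LEARNING_RATE = 1e-4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optimiser: Adam over the model's parameters, e.g.
# optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
```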

### Loss Function

The combined loss leverages the strengths of both components (a sketch follows this list):

- **Dice Loss**: Directly optimises the evaluation metric (Dice coefficient)
- **Cross-Entropy**: Provides stable pixel-wise classification gradients
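
A minimal sketch of such a combined loss; the 0.6/0.4 weights come from the configuration table above, but the function itself is illustrative, not necessarily the exact code in `train.py`:

```python
import torch
import torch.nn.functional as F

def dice_loss(logits, target_one_hot, eps=1e-6):
    """Soft Dice loss averaged over classes; logits and target are [B, 6, H, W]."""
    probs = torch.softmax(logits, dim=1)
    intersection = (probs * target_one_hot).sum(dim=(2, 3))
    denominator = probs.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))
    dice = (2 * intersection + eps) / (denominator + eps)
    return 1 - dice.mean()

def combined_loss(logits, target_one_hot):
    """60% Dice + 40% Cross-Entropy, as in the training configuration."""
    ce = F.cross_entropy(logits, target_one_hot.argmax(dim=1))  # class indices [B, H, W]
    return 0.6 * dice_loss(logits, target_one_hot) + 0.4 * ce
```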


### Training Script

```bash
python train.py
```

**Outputs:**
- `outputs/best_model.pth` - Best model checkpoint (highest validation Dice)
- `outputs/prediction_XXX.png` - Prediction visualisations (saved PNGs)
- `outputs/training_curves.png` - Loss and Dice score plots

### Output Visualisations

![Prediction - slice 01](outputs/prediction_000.png)
*Input | Ground truth | Model prediction*

![Prediction - slice 02](outputs/prediction_001.png)


![Prediction - slice 03](outputs/prediction_002.png)

![Dice Loss Curves](outputs/training_curves.png)
*Training and validation loss and Dice curves*

---

## Evaluation

### Metrics

The final model was trained for only 5 epochs (rather than the configured 20), as it surpasses the required minimum Dice coefficient of 0.75 very quickly.

**Dice Coefficient** (primary metric):
```
Dice = (2 × |Prediction ∩ Ground Truth|) / (|Prediction| + |Ground Truth|)
```

Calculated per class and averaged across all 6 classes for the final score.
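
A sketch of that per-class computation on hard (argmax) predictions, assuming class-index maps rather than one-hot tensors:

```python
import torch

def dice_per_class(pred, target, num_classes=6, eps=1e-6):
    """Mean Dice over classes; pred and target are [H, W] tensors of class indices."""
    scores = []
    for c in range(num_classes):
        p = (pred == c).float()
        t = (target == c).float()
        intersection = (p * t).sum()
        denom = p.sum() + t.sum()
        scores.append(((2 * intersection + eps) / (denom + eps)).item())
    return sum(scores) / num_classes
```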

### Running Evaluation

```bash
python predict.py
```

**Features:**
- Loads best model from `outputs/best_model.pth`
- Evaluates on test set
- Reports mean, std, min, max Dice scores
- Saves prediction visualisations


## Results

### Performance Metrics

The following is sample console output from a training and evaluation run:
```
======================================================================
TRAINING COMPLETED
======================================================================
Best Validation Dice: 0.8654
======================================================================

======================================================================
TEST SET EVALUATION
======================================================================

Test Loss: 0.2150
Test Dice: 0.8777
======================================================================

Training curves saved to: outputs/training_curves.png
FINAL SUMMARY
======================================================================
Best Validation Dice: 0.8654
Test Set Dice: 0.8777
Model saved to: ./outputs/best_model.pth
Plots saved to: ./outputs/training_curves.png
======================================================================
```
The model achieved a mean test Dice coefficient of 0.8777 (averaged over all 6 classes), which exceeds the 0.75 Dice coefficient requirement for this task.

### Training Curves

Training and validation loss/Dice curves are automatically saved to `outputs/training_curves.png` after training completes.


## Project Structure

```
COMP3710-Report/
├── train.py # Training script
├── predict.py # Evaluation script
├── modules.py # Residual U-Net architecture
├── dataset.py # Dataset loader with augmentations
├── utils_visualize.py # Visualisation utilities
├── check_predictions.py # Quick prediction checker
├── README.md # This file
├── LICENSE # Project license
└── outputs/ # Training outputs
├── best_model.pth
├── training_curves.png
└── prediction_XXX.png # variable number of prediction visualisations

```

---

## Requirements

### Python Dependencies

```
torch>=1.9.0
torchvision>=0.10.0
numpy>=1.19.0
nibabel>=3.2.0
matplotlib>=3.3.0
tqdm>=4.60.0
scipy>=1.5.0
```

### Installation

```bash
pip install torch torchvision numpy nibabel matplotlib tqdm scipy
```

---

## Usage

### Hardware and Runtime

This project was run on UQ's Rangpur cluster using an NVIDIA A100 GPU, via the following SLURM batch script (submitted with `sbatch`):
```bash
#!/bin/bash
#SBATCH --partition=a100
#SBATCH --gres=gpu:1
#SBATCH --job-name=hipmri
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8

#SBATCH --output=task3.out
#SBATCH --error=task3.err

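# Activate the PyTorch conda environment (assumes conda is initialised for batch shells)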
conda activate torch
python train.py
```

### Runtime Estimates

| Task | GPU Time | CPU Time |
|------|----------|----------|
| **Training (20 epochs)** | ~10-15 min | ~1-2 hours |
| **Evaluation (test set)** | ~10-30 sec | ~1-2 min |
| **Single prediction** | <1 sec | ~1 sec |


### Device Selection

The code automatically detects and uses CUDA if available:

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

---

## References

COMP3710 Teaching Team (2025). Retrieved from https://colab.research.google.com/drive/1VOsZSyRhyuHLmgoqGriQk01ub4bKNmZ1?usp=sharing

Dowling, J. & Greer, P. (2014). Labelled weekly MR images of the male pelvis. Retrieved from https://data.csiro.au/collection/csiro:51392v2?redirected=true