diff --git a/recognition/UNet_Prostate_47222610/README.md b/recognition/UNet_Prostate_47222610/README.md new file mode 100644 index 000000000..584e27174 --- /dev/null +++ b/recognition/UNet_Prostate_47222610/README.md @@ -0,0 +1,238 @@ +# 2D Prostate Segmentation using Improved UNet on HipMRI Dataset + +**Project**: #3 - HipMRI 2D Segmentation with Improved UNet + +This project implements an **Improved Unet** architecture for automated prostate segmentation from MRI images using the HipMRI Study dataset. The goal is to achieve a Dice similarity coefficient of ≥ 0.75 on the prostate label (Class 3) in the test set. + +## Problem Description + +Medical image segmentation is crucial for radiotherapy planning in prostate cancer. This project segments four anatomical regions from 2D magnetic resonance imaging (MRI) slices: +- Class 0: Background +- Class 1: Body +- Class 2: Bone +- Class 3: **Prostate (Primary Target)** + +The Improved UNet architecture enhances the original UNet through architectural improvements. + +## Model Architecture + +### Improved UNet vs Standard UNet +The improved UNet incorporates several improvements over the original UNet: + +**Key Improvements:** +1. **Deeper Network**: There are 5 levels of encoding/ decoding in Improved UNet, but only 4 in standard UNet. +2. **Residual Connections**: Skip connections using residual blocks for better gradient flow. +3. **Instance Normalization**: More stable than batch normalization for small batch sizes. +4. **Leaky ReLU**: Prevents the ReLU function from failing on negative slopes (alpha = 0.01). +5. **Deep Supervision**: Additional loss at intermediate decoder layers. +6. **Context Module**: Additional context aggregation at bottleneck. + +### Network Architecture: +``` +Input: (N, 1, 256, 128) - Grayscale MRI images + +[Encoder Path - Downsampling] + Level 0: ResidualDoubleConv: 1 -> 64 channels (256×128) + MaxPool2d(2×2) + Level 1: ResidualDoubleConv: 64 -> 128 channels (128×64) + MaxPool2d(2×2) + Level 2: ResidualDoubleConv: 128 -> 256 channels (64×32) + MaxPool2d(2×2) + Level 3: ResidualDoubleConv: 256 -> 512 channels (32×16) + MaxPool2d(2×2) + Level 4 (Bottleneck): ResidualDoubleConv: 512 -> 1024 channels (16×8) + +[Context Aggregation Module] + Parallel dilated convolutions with rates [1, 2, 4, 8] + Receptive fields: 3×3, 7×7, 15×15, 31×31 + Aggregated multi-scale features (1024 channels) + +[Decoder Path with Deep Supervision] + Level 3: TransposeConv + Skip + ResidualDoubleConv: 1024 -> 512 (32×16) + ├─ Auxiliary Output: DSV4 (512 -> 4 classes) + + Level 2: TransposeConv + Skip + ResidualDoubleConv: 512 -> 256 (64×32) + ├─ Auxiliary Output: DSV3 (256 -> 4 classes) + + Level 1: TransposeConv + Skip + ResidualDoubleConv: 256 -> 128 (128×64) + ├─ Auxiliary Output: DSV2 (128 -> 4 classes) + + Level 0: TransposeConv + Skip + ResidualDoubleConv: 128 -> 64 (256×128) + ├─ Auxiliary Output: DSV1 (64 -> 4 classes) + +[Output Layer] + 1×1 Convolution: 64 -> 4 channels + Output: (N, 4, 256, 128) - Class logits +``` + +### Key Architectural Components + +**1. Residual Blocks** +- Two 3x3 convolutions with skip connections +- Enables gradient flow in deep networks + +**2. Instance Normalization** +- Normalizes per sample +- More stable than Batch Normalization for medical imaging + +**3. Context Aggregation** +- Parallel dilated convolutions at bottleneck +- Captures features at multiple scales (3x3 to 31x31) + +**4. 
Deep Supervision** +- Auxiliary outputs at 5 decoder levels +- Loss weights: 1.0, 0.8, 0.6, 0.4, 0.2 + +## Dataset + +**Source**: HipMRI Study on Prostate Cancer + +**Format**: NIfTI (.nii.gz) + +**Data Splits**: +- Training: 11,460 slices +- Validation: 660 slices +- Testing: 540 slices + +**Preprocessing**: +1. Load NIFTI files with nibabel +2. Resize to 256x128 +3. Z-score normalization: '(img - mean) / std' +4. Clean invalid labels (≥4 -> class 0) +5. One-hot encode to 4 classes + +## Dependencies +```bash +torch>=2.0.0 +numpy>=1.24.0 +nibabel>=5.0.0 +matplotlib>=3.7.0 +opencv-python>=4.7.0 +tqdm>=4.65.0 +``` + +## Project Structure +``` +UNet_Prostate_47222610/ +├── README.md # This file +├── dataset.py # Data loading and preprocessing for MRI slices +├── modules.py # Improved UNet architecture +├── predict.py # Testing and visualization +├── train.py # Training with deep supervision +└── Result_Images/ # Visualization results + ├── training_curves.png + ├── prediction_batch_0.png + ├── prediction_batch_1.png + ├── prediction_batch_2.png + ├── prediction_batch_3.png + └── prediction_batch_4.png +``` + +## Usage + +### Training + +```bash +python train.py +``` + +Training parameters are hardcoded: 30 epochs, batch size 16, learning rate 1e-4. + +### Testing + +```bash +python predict.py +``` + +## Training Environment + +- **Platform**: Rangpur HPC (The University of Queensland) +- **GPU**: NVIDIA A100 +- **Training Time**: ~2 hours for 30 epochs + +## Reproducibility + +### Training Configuration + +- Architecture: Improved UNet (5-level encoder/decoder) +- Epochs: 30 +- Batch size: 16 +- Learning rate: 1e-4 (Adam optimizer) +- Weight decay: 1e-5 (L2 regularization) +- Loss function: CrossEntropyLoss + Deep Supervision +- Image size: 256×128 +- Number of classes: 4 + +### File Outputs Summary + +After training and evaluation: +``` +UNet_Prostate_47222610/ +├── improved_unet_best.pth # Best model +├── improved_unet_final.pth # Final model +├── improved_unet_epoch_*.pth # Checkpoints +├── logs/ +│ └── improved_unet_*.out # Training logs (text) +└── Result_Images/ + ├── training_curves.png # Loss/Dice plots + └── prediction_batch_*.png # Sample predictions +``` + +## Results + +### Test Set Performance +| Class | Region | Dice | +|-------|--------|------| +| 0 | Background | 0.9881 | +| 1 | Body | 0.9842 | +| 2 | Bone | 0.9271 | +| 3 | **Prostate (Target)** | **0.9552** | + +**Project Requirement**: Prostate Dice ≥ 0.75 +**Achievement**: **0.9552** (Exceeds requirement by 27.4%) +**Status**: PASSED + +### Visualizations + +![Training Curves](Result_Images/training_curves.png) + +*Figure 1: Training loss and prostate Dice coefficient over 30 epochs.* + +![Sample Predictions](Result_Images/prediction_batch_0.png) +![Sample Predictions](Result_Images/prediction_batch_1.png) +![Sample Predictions](Result_Images/prediction_batch_2.png) + +*Figure 2: Sample predictions on test set. Left: Input MRI, Center: Ground truth, Right: Model prediction.* + +## References +1. **Isensee, F., Kickingereder, P., Wick, W., Bendszus, M., & Maier-Hein, K. H. (2018)**. "Brain Tumor Segmentation and Radiomics Survival Prediction: Contribution to the BRATS 2017 Challenge." arXiv preprint arXiv:1802.10508. + +2. **Ronneberger, O., Fischer, P., & Brox, T. (2015)**. "U-Net: Convolutional Networks for Biomedical Image Segmentation." MICCAI 2015. + +3. **Yu, F., & Koltun, V. (2016)**. "Multi-Scale Context Aggregation by Dilated Convolutions." ICLR 2016. + +4. **COMP3710 Assignment Specification**. 
The University of Queensland, 2025. + +## Academic Integrity +- Code written independently following course materials and cited papers +- AI tools (ChatGPT) were used to assist in understanding and to provide reference material for writing docstrings + +## Author + +**Student Name**: Chia Jou Lu + +**Student ID**: 47222610 + +**Course**: COMP3710 Pattern Recognition + +**Institution**: The University of Queensland + +**Date**: November 2025 + + + + + + + + diff --git a/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_0.png b/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_0.png new file mode 100644 index 000000000..966ca7e42 Binary files /dev/null and b/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_0.png differ diff --git a/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_1.png b/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_1.png new file mode 100644 index 000000000..b7ba2a46c Binary files /dev/null and b/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_1.png differ diff --git a/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_2.png b/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_2.png new file mode 100644 index 000000000..0a75eb523 Binary files /dev/null and b/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_2.png differ diff --git a/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_3.png b/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_3.png new file mode 100644 index 000000000..9222d2563 Binary files /dev/null and b/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_3.png differ diff --git a/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_4.png b/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_4.png new file mode 100644 index 000000000..ba753700a Binary files /dev/null and b/recognition/UNet_Prostate_47222610/Result_Images/prediction_batch_4.png differ diff --git a/recognition/UNet_Prostate_47222610/Result_Images/training_curves.png b/recognition/UNet_Prostate_47222610/Result_Images/training_curves.png new file mode 100644 index 000000000..08b23cb85 Binary files /dev/null and b/recognition/UNet_Prostate_47222610/Result_Images/training_curves.png differ diff --git a/recognition/UNet_Prostate_47222610/dataset.py b/recognition/UNet_Prostate_47222610/dataset.py new file mode 100644 index 000000000..14c4cd803 --- /dev/null +++ b/recognition/UNet_Prostate_47222610/dataset.py @@ -0,0 +1,79 @@ +""" +Dataset Loader for HipMRI Prostate Segmentation + +This module provides data loading utilities for the HipMRI prostate MRI dataset +using ONLY the provided utility functions from the assignment Appendix B. 
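+
+A minimal usage sketch (the glob pattern below is illustrative, not a path this
+module defines):
+
+    import glob
+    image_files = sorted(glob.glob('keras_slices_train/*.nii.gz'))
+    images = load_data_2D(image_files, normImage=True)  # (N, rows, cols) float32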
+ +Author: 47222610 +Date: October 2025 +Assignment: Pattern Recognition Project - 2D Prostate Segmentation + +References: + - Assignment Appendix B: Provided utility functions + - NIfTI file format: https://nifti.nimh.nih.gov/ + - Nibabel library: https://nipy.org/nibabel/ +""" +import numpy as np +import nibabel as nib +from tqdm import tqdm + +def to_channels(arr: np.ndarray, dtype=np.uint8) -> np.ndarray: + channels = np.unique(arr) + res = np.zeros(arr.shape + (len(channels),), dtype=dtype) + for c in channels: + c = int(c) + res[..., c:c+1][arr == c] = 1 + + return res + +# load medical image functions +def load_data_2D(imageNames, normImage=False, categorical=False, + dtype=np.float32, getAffines=False, early_stop=False): + """ + Load medical image data from names, cases list provided into a list for each. + + This function pre-allocates 4D arrays for conv2d to avoid excessive memory ↘ + untitled folder usage. + + normImage: bool (normalise the image 0.0 -1.0) + early_stop: Stop loading pre-maturely, leaves arrays mostly empty, for quick ↘ + loading and testing scripts. + """ + affines = [] + + # get fixed size + num = len(imageNames) + first_case = nib.load(imageNames[0]).get_fdata(caching='unchanged') + if len(first_case.shape) == 3: + first_case = first_case[:, :, 0] + if categorical: + first_case = to_channels(first_case, dtype=dtype) + rows, cols, channels = first_case.shape + images = np.zeros((num, rows, cols, channels), dtype=dtype) + else: + rows, cols = first_case.shape + images = np.zeros((num, rows, cols), dtype=dtype) + + for i, inName in enumerate(tqdm(imageNames, desc='Loading images')): + niftiImage = nib.load(inName) + inImage = niftiImage.get_fdata(caching='unchanged') + affine = niftiImage.affine + if len(inImage.shape) == 3: + inImage = inImage[:, :, 0] + inImage = inImage.astype(dtype) + if normImage: + inImage = (inImage - inImage.mean()) / (inImage.std() + 1e-8) + if categorical: + inImage = to_channels(inImage, dtype=dtype) + images[i, :, :, :] = inImage + else: + images[i, :, :] = inImage + + affines.append(affine) + if i > 20 and early_stop: + break + + if getAffines: + return images, affines + else: + return images diff --git a/recognition/UNet_Prostate_47222610/modules.py b/recognition/UNet_Prostate_47222610/modules.py new file mode 100644 index 000000000..2831ab0b3 --- /dev/null +++ b/recognition/UNet_Prostate_47222610/modules.py @@ -0,0 +1,252 @@ +""" +Improved UNet Model Implementation for 2D Prostate Segmentation + +This module implements the Improved UNet architecture based on Isensee et al. 2018. + +Key improvements over standard UNet: + - Instance Normalization instead of Batch Normalization + - Leaky ReLU instead of ReLU + - Residual connections in encoder/decoder blocks + - Context aggregation module with dilated convolutions + - Deep supervision at multiple scales +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ResidualDoubleConv(nn.Module): + """ + Improved Double Convolution Block with residual connections. + + Key improvements: + - Instance Normalization for better stability with small batches. + - Leaky ReLU to prevent dying neurons. + - Residual connection for better gradient flow. 
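+
+    In effect the block computes out = double_conv(x) + residual_conv(x), where
+    residual_conv is a 1x1 convolution when in_channels != out_channels and the
+    identity otherwise (see forward below).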
+ + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + + """ + + def __init__(self, in_channels, out_channels): + super(ResidualDoubleConv, self).__init__() + + self.double_conv = nn.Sequential( + # First Convolution + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.InstanceNorm2d(out_channels), # Instance Norm instead of Batch Norm + nn.LeakyReLU(negative_slope=0.01, inplace=True), # Leaky ReLU + + # Second Convolution + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.InstanceNorm2d(out_channels), + nn.LeakyReLU(negative_slope=0.01, inplace=True) + ) + + # 1x1 conv for residual if channel dimensions change + self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) \ + if in_channels != out_channels else nn.Identity() + + def forward(self, feature_map): + """ + Forward pass through the double convolution block. + """ + residual = self.residual_conv(feature_map) + out = self.double_conv(feature_map) + return out + residual # Residual connection + + +class ContextModule(nn.Module): + """ + Context Aggregation Module using dilated convolutions. + + Args: + channels (int): Number of channels + """ + + def __init__(self, channels): + super(ContextModule, self).__init__() + + # Dilated convolutions with different rates (1, 2, 4, 8) + self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, dilation=1) + self.conv2 = nn.Conv2d(channels, channels, 3, padding=2, dilation=2) + self.conv4 = nn.Conv2d(channels, channels, 3, padding=4, dilation=4) + self.conv8 = nn.Conv2d(channels, channels, 3, padding=8, dilation=8) + + self.norm = nn.InstanceNorm2d(channels) + self.activation = nn.LeakyReLU(negative_slope=0.01, inplace=True) + + def forward(self, feature_map): + feature_map1 = self.activation(self.norm(self.conv1(feature_map))) + feature_map2 = self.activation(self.norm(self.conv2(feature_map))) + feature_map4 = self.activation(self.norm(self.conv4(feature_map))) + feature_map8 = self.activation(self.norm(self.conv8(feature_map))) + + return feature_map + feature_map1 + feature_map2 + feature_map4 + feature_map8 + + +class Downsampling(nn.Module): + """ + Downsampling block with residual connections. + """ + + def __init__(self, in_channels, out_channels): + super(Downsampling, self).__init__() + + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), # Downsample use 2x2 + ResidualDoubleConv(in_channels, out_channels) + ) + + def forward(self, feature_map): + """ + Forward pass through the downsampling block. + """ + return self.maxpool_conv(feature_map) + + +class Upsampling(nn.Module): + """ + Upsampling block with residual connections. + """ + + def __init__(self, in_channels, out_channels): + super(Upsampling, self).__init__() + + self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, + kernel_size=2, stride=2) + + # DoubleConv after concatenation + self.conv = ResidualDoubleConv(in_channels, out_channels) + + def forward(self, feature_map1, feature_map2): + """ + Forward pass through the upsampling block. 
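+
+        Args:
+            feature_map1: Decoder feature map to be upsampled.
+            feature_map2: Encoder skip-connection feature map at the target size.
+
+        Steps (mirroring the code below): upsample feature_map1 with the 2x2
+        transposed convolution, pad it if odd input sizes left a spatial
+        mismatch with feature_map2, concatenate the two along the channel
+        dimension, then apply the residual double convolution.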
+ """ + feature_map1 = self.upsampling(feature_map1) + + # Handle potential size mismatch due to odd dimensions + # Calculate padding needed to match feature_map2's spatial dimensions + diffY = feature_map2.size()[2] - feature_map1.size()[2] # Height difference + diffX = feature_map2.size()[3] - feature_map1.size()[3] # Width difference + + # Pad feature_map1 if needed + feature_map1 = F.pad(feature_map1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + + # Concatenate + feature_map = torch.cat([feature_map2, feature_map1], dim=1) + + return self.conv(feature_map) + + +class OutConv(nn.Module): + """ + Output convolution layer for final segmentation mask. + """ + + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + + # 1×1 convolution to map to output classes + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, feature_map): + """ + Forward pass through the output convolution. + """ + return self.conv(feature_map) + + +class ImprovedUNet(nn.Module): + """ + Improved UNet architecture. + + Key improvements over standard UNet: + 1. Instance Normalization instead of Batch Normalization. + 2. Leaky ReLU activation. + 3. Residual connections in all conv blocks. + 4. Context aggregation module at bottleneck. + 5. Deep supervision at decoder levels. + + Args: + n_channels (int): Number of input channels + n_classes (int): Number of output classes + deep_supervision (bool): Whether to use deep supervision + """ + + def __init__(self, n_channels, n_classes, deep_supervision=True): + super(ImprovedUNet, self).__init__() + + self.n_channels = n_channels + self.n_classes = n_classes + self.deep_supervision = deep_supervision + + # Encoder + self.initConv = ResidualDoubleConv(n_channels, 64) # Initial convolution + self.down1 = Downsampling(64, 128) # 256×128 -> 128×64 + self.down2 = Downsampling(128, 256) # 128×64 -> 64×32 + self.down3 = Downsampling(256, 512) # 64×32 -> 32×16 + self.down4 = Downsampling(512, 1024) # 32×16 -> 16×8 + + # Context module at bottleneck + self.context = ContextModule(1024) + + # Decoder + self.up1 = Upsampling(1024, 512) # 16×8 -> 32×16 + self.up2 = Upsampling(512, 256) # 32×16 -> 64×32 + self.up3 = Upsampling(256, 128) # 64×32 -> 128×64 + self.up4 = Upsampling(128, 64) # 128×64 -> 256×128 + + # Output + self.outConv = OutConv(64, n_classes) # Final 1×1 conv to n_classes + + # Deep supervision outputs + if self.deep_supervision: + self.dsv4 = nn.Conv2d(512, n_classes, 1) + self.dsv3 = nn.Conv2d(256, n_classes, 1) + self.dsv2 = nn.Conv2d(128, n_classes, 1) + self.dsv1 = nn.Conv2d(64, n_classes, 1) + + def forward(self, feature_map): + """ + Forward pass through the ImprovedUNet. 
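+
+        Args:
+            feature_map: Input batch of shape (N, n_channels, H, W).
+
+        Returns:
+            In training mode with deep_supervision=True: a tuple
+            (logits, dsv1, dsv2, dsv3, dsv4), where each auxiliary output is
+            upsampled to the input resolution. Otherwise: logits of shape
+            (N, n_classes, H, W).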
+ """ + # Encoder + feature_map1 = self.initConv(feature_map) # 64 channels, same size + feature_map2 = self.down1(feature_map1) # 128 channels, 1/2 size + feature_map3 = self.down2(feature_map2) # 256 channels, 1/4 size + feature_map4 = self.down3(feature_map3) # 512 channels, 1/8 size + feature_map5 = self.down4(feature_map4) # 1024 channels, 1/16 size + + # Context aggregation + feature_map5 = self.context(feature_map5) + + # Decoder with skip connections + d4 = self.up1(feature_map5, feature_map4) # 512 channels, 1/8 size + d3 = self.up2(d4, feature_map3) # 256 channels, 1/4 size + d2 = self.up3(d3, feature_map2) # 128 channels, 1/2 size + d1 = self.up4(d2, feature_map1) # 64 channels, original size + + # Output + logits = self.outConv(d1) + + # Deep supervision outputs + if self.deep_supervision and self.training: + # Upsample intermediate outputs to match target size + dsv4 = F.interpolate(self.dsv4(d4), size=feature_map.shape[2:], + mode='bilinear', align_corners=False) + dsv3 = F.interpolate(self.dsv3(d3), size=feature_map.shape[2:], + mode='bilinear', align_corners=False) + dsv2 = F.interpolate(self.dsv2(d2), size=feature_map.shape[2:], + mode='bilinear', align_corners=False) + dsv1 = F.interpolate(self.dsv1(d1), size=feature_map.shape[2:], + mode='bilinear', align_corners=False) + + return logits, dsv1, dsv2, dsv3, dsv4 + else: + return logits + + diff --git a/recognition/UNet_Prostate_47222610/predict.py b/recognition/UNet_Prostate_47222610/predict.py new file mode 100644 index 000000000..0701bde07 --- /dev/null +++ b/recognition/UNet_Prostate_47222610/predict.py @@ -0,0 +1,336 @@ +""" +Prediction, testing, and visualization for UNet prostate segmentation. +""" +import os +import glob +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F +import nibabel as nib +import cv2 +from tqdm import tqdm + +from modules import ImprovedUNet + + +def load_data_with_resize(image_paths, target_size=(256, 128), normImage=True): + """ + Load the image and resize it to a uniform size. + + Args: + image_paths: Paths to the image files to be loaded. + target_size: Desired spatial dimensions for the output images. + normImage: If true, each image is normalized. Default is true. + + Returns: + A NumPy array of shape (N, H, W) containing the processed images. + """ + n = len(image_paths) + images = np.zeros((n, target_size[0], target_size[1]), dtype=np.float32) + + for i, path in enumerate(tqdm(image_paths, desc='Loading images')): + img = nib.load(path).get_fdata(caching='unchanged') + + if len(img.shape) == 3: + img = img[:, :, 0] + + # Resize if needed + if img.shape != target_size: + img = cv2.resize(img, (target_size[1], target_size[0]), + interpolation=cv2.INTER_LINEAR) + + # Normalize + if normImage: + img = (img - img.mean()) / (img.std() + 1e-8) + + images[i] = img + + return images + + +def load_labels_with_resize(seg_paths, target_size=(256, 128), n_classes=4): + """ + Load tags, resize, clean up categories, and perform one-hot encoding. + + Args: + seg_paths: Paths to the segmentation label files to be loaded. + target_size: Desired output dimensions (height, width) after resizing. + n_classes: Number of valid classes for one-hot encoding. + + Returns: + A NumPy array of shape (N, H, W, n_classes) containing the processed + one-hot encoded labels. 
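+
+    Note:
+        Channel c of the last axis is a binary mask for class c, so channel 3
+        is the prostate. Labels outside the valid range (>= n_classes) are
+        mapped to background before encoding.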
+ """ + n = len(seg_paths) + labels = np.zeros((n, target_size[0], target_size[1], n_classes), dtype=np.float32) + + for i, path in enumerate(tqdm(seg_paths, desc='Loading labels')): + label = nib.load(path).get_fdata(caching='unchanged') + + if len(label.shape) == 3: + label = label[:, :, 0] + + # Resize (Use INTER_NEAREST to keep the category unchanged) + if label.shape != target_size: + label = cv2.resize(label, (target_size[1], target_size[0]), + interpolation=cv2.INTER_NEAREST) + + # Clean up redundant categories (should address the issues with categories 4 and 5) + label[label >= n_classes] = 0 + + # One-hot encoding + for c in range(n_classes): + labels[i, :, :, c] = (label == c) + + return labels + + +def dice_coefficient_per_class(predicted, target, n_classes=4): + """ + Calculate Dice coefficient for each class separately. + + Args: + predicted: Model predictions of shape (N, C, H, W). + target: One-hot encoded ground truth labels of shape (N, C, H, W). + n_classes: Number of segmentation classes. + + Returns: + dict: A dictionary mapping each class name to its Dice coefficient. + """ + dice_scores = {} + + for c in range(n_classes): + # Extract specific class + predicted_c = predicted[:, c, :, :] # (N, H, W) + target_c = target[:, c, :, :] # (N, H, W) + + # Calculate intersection and union + intersection = (predicted_c * target_c).sum() + denominator = predicted_c.sum() + target_c.sum() + + # Handle empty cases + if denominator == 0: + dice_scores[f'class_{c}'] = 1.0 # Both empty = perfect match + else: + dice_scores[f'class_{c}'] = (2. * intersection / denominator).item() + + return dice_scores + + +def test(model, data_loader, device, save_path='results', visualize=False): + """ + Test model on dataset and generate visualizations. + + Args: + model: Trained UNet model. + data_loader: DataLoader for test data. + device: The computation device to run the test on (e.g., 'cuda' or 'cpu'). + save_path: Directory to save results. + visualize: Whether to generate visualization images. + + Returns: + dict: average Dice coefficient across all batches. + """ + model.eval() + + if not os.path.exists(save_path): + os.makedirs(save_path) + + running_dice = { + 'class_0': 0.0, # Background + 'class_1': 0.0, # Body + 'class_2': 0.0, # Bone + 'class_3': 0.0 # Prostate (MAIN TARGET) + } + batch_count = 0 + + with torch.no_grad(): + pbar = tqdm(data_loader, desc='Testing') + + for batch_idx, (images, labels) in enumerate(pbar): + images = images.to(device) + labels = labels.to(device) + + outputs = model(images) + probs = torch.softmax(outputs, dim=1) + + dice = dice_coefficient_per_class(probs, labels, n_classes=4) + + for key in running_dice.keys(): + running_dice[key] += dice[key] + + batch_count += 1 + + pbar.set_postfix({'Prostate_Dice': f'{dice["class_3"]:.4f}'}) + + # Visualize first few batches + if visualize and batch_idx < 5: + visualize_prediction(images, labels, probs, batch_idx, save_path) + + avg_dice = {key: value / batch_count for key, value in running_dice.items()} + + return avg_dice + + +def visualize_prediction(images, labels, predicts, batch_idx, save_path): + """ + Save a figure showing the input, ground truth, and prediction. + + Args: + images: Batch of input MRI images. Only the first image in the batch is visualized. + labels: One-hot encoded ground truth masks of shape (N, C, H, W). + predicts: Model output probabilities or logits. + batch_idx: Index of the current batch. + save_path: Directory path where the visualization image will be saved. 
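+
+    Note:
+        Only the first sample of the batch is drawn. Class maps are recovered
+        with argmax over the channel dimension, and the figure is saved as
+        prediction_batch_{batch_idx}.png inside save_path.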
+ """ + # Get first image from batch + img = images[0, 0].cpu().numpy() + label = torch.argmax(labels[0], dim=0).cpu().numpy() + predict = torch.argmax(predicts[0], dim=0).cpu().numpy() + + # Create figure + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + + # Original image + axes[0].imshow(img, cmap='gray') + axes[0].set_title('MRI Image', fontsize=14) + axes[0].axis('off') + + # Ground truth + axes[1].imshow(label, cmap='tab10', vmin=0, vmax=3) + axes[1].set_title('Ground Truth', fontsize=14) + axes[1].axis('off') + + # Prediction + axes[2].imshow(predict, cmap='tab10', vmin=0, vmax=3) + axes[2].set_title('Prediction', fontsize=14) + axes[2].axis('off') + + # Add legend + from matplotlib.patches import Patch + legend_elements = [ + Patch(facecolor='tab:blue', label='Background'), + Patch(facecolor='tab:red', label='Body'), + Patch(facecolor='tab:pink', label='Bone'), + Patch(facecolor='tab:cyan', label='Prostate') + ] + fig.legend(handles=legend_elements, loc='lower center', ncol=4, + bbox_to_anchor=(0.5, -0.05)) + + plt.tight_layout() + + # Save figure + save_file = os.path.join(save_path, f'prediction_batch_{batch_idx}.png') + plt.savefig(save_file, bbox_inches='tight', dpi=150) + plt.close() + + print(f" Saved visualization: {save_file}") + + +def plot_curves(history, save_path='results'): + """ + Plot training curves. + + Args: + history: Dictionary containing 'train_loss', 'val_loss', 'train_dice', 'val_dice'. + save_path: Directory to save the plot. + """ + if not os.path.exists(save_path): + os.makedirs(save_path) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) + + # Loss plot + ax1.plot(history['train_loss'], label='Training') + if 'val_loss' in history: + ax1.plot(history['val_loss'], label='Validation') + ax1.set_xlabel('Epoch') + ax1.set_ylabel('Loss') + ax1.set_title('Loss over epochs') + ax1.legend() + ax1.grid(True) + + # Dice plot + ax2.plot(history['train_dice'], label='Training') + if 'val_dice' in history: + ax2.plot(history['val_dice'], label='Validation') + ax2.set_xlabel('Epoch') + ax2.set_ylabel('Dice Coefficient (Prostate)') + ax2.set_title('Dice over epochs') + ax2.legend() + ax2.grid(True) + + plt.tight_layout() + save_file = os.path.join(save_path, 'training_curves.png') + plt.savefig(save_file, dpi=150) + plt.close() + + print(f"Saved training curves to {save_file}") + + +def main(): + # Configuration + data_path = "/home/groups/comp3710/HipMRI_Study_open/keras_slices_data" + checkpoint_path = "unet_final.pth" + test_split = "test" + batch_size = 16 + visualize = True + + # Set device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}\n") + + # Load model + print("Loading model...") + model = ImprovedUNet(n_channels=1, n_classes=4, deep_supervision=False).to(device) + checkpoint = torch.load('improved_unet_final.pth', map_location=device) + model.load_state_dict(checkpoint['model_state_dict'], strict=False) + + # Load test data + print(f"Loading {test_split} dataset...") + from torch.utils.data import DataLoader, TensorDataset + + test_images = sorted(glob.glob(f'{data_path}/keras_slices_{test_split}/*.nii.gz')) + test_labels = sorted(glob.glob(f'{data_path}/keras_slices_seg_{test_split}/*.nii.gz')) + + X_test = load_data_with_resize(test_images) + y_test = load_labels_with_resize(test_labels) + + test_dataset = TensorDataset( + torch.from_numpy(X_test).unsqueeze(1).float(), + torch.from_numpy(y_test).permute(0, 3, 1, 2).float() + ) + test_loader = DataLoader(test_dataset, batch_size=batch_size, 
shuffle=False) + + # Evaluate on test set + print(f"\nEvaluating on {test_split} set...") + results = test(model, test_loader, device, save_path='results', visualize=visualize) + + # Print results + print("\n" + "="*60) + print("TEST SET RESULTS") + print("="*60) + class_names = ['Background', 'Body', 'Bone', 'Prostate'] + for i, name in enumerate(class_names): + print(f"{name:20s}: Dice = {results[f'class_{i}']:.4f}") + print("="*60) + print(f"\n{'Prostate (Target)':20s}: {results['class_3']:.4f} (Requirement: ≥ 0.75)") + + if results['class_3'] >= 0.75: + print("PASSED - Prostate segmentation meets requirement!") + else: + print("FAILED - Below requirement") + print("="*60) + + if visualize: + print("\nVisualizations saved to 'results/' directory") + + print("\nEvaluation complete!") + + +if __name__ == "__main__": + main() + diff --git a/recognition/UNet_Prostate_47222610/train.py b/recognition/UNet_Prostate_47222610/train.py new file mode 100644 index 000000000..006d3804e --- /dev/null +++ b/recognition/UNet_Prostate_47222610/train.py @@ -0,0 +1,374 @@ +""" +Train Improved Unet for prostate segmentation. + +This is the training script for the HipMRI dataset. +""" +import numpy as np +import nibabel as nib +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset +import glob +from tqdm import tqdm +import cv2 + +from modules import ImprovedUNet + +def load_data_with_resize(image_paths, target_size=(256, 128), normImage=True): + """ + Load the image and resize it to a uniform size. + + Args: + image_paths: Paths to the image files to be loaded. + target_size: Desired spatial dimensions for the output images. + normImage: If true, each image is normalized. Default is true. + + Returns: + A NumPy array of shape (N, H, W) containing the processed images. + """ + n = len(image_paths) + images = np.zeros((n, target_size[0], target_size[1]), dtype=np.float32) + + for i, path in enumerate(tqdm(image_paths, desc='Loading images')): + img = nib.load(path).get_fdata(caching='unchanged') + + if len(img.shape) == 3: + img = img[:, :, 0] + + # Resize if needed + if img.shape != target_size: + img = cv2.resize(img, (target_size[1], target_size[0]), + interpolation=cv2.INTER_LINEAR) + + # Normalize + if normImage: + img = (img - img.mean()) / (img.std() + 1e-8) + + images[i] = img + + return images + + +def load_labels_with_resize(seg_paths, target_size=(256, 128), n_classes=4): + """ + Load tags, resize, clean up categories, and perform one-hot encoding. + + Args: + seg_paths: Paths to the segmentation label files to be loaded. + target_size: Desired output dimensions (height, width) after resizing. + n_classes: Number of valid classes for one-hot encoding. + + Returns: + A NumPy array of shape (N, H, W, n_classes) containing the processed + one-hot encoded labels. 
+ """ + n = len(seg_paths) + labels = np.zeros((n, target_size[0], target_size[1], n_classes), dtype=np.float32) + + for i, path in enumerate(tqdm(seg_paths, desc='Loading labels')): + label = nib.load(path).get_fdata(caching='unchanged') + + if len(label.shape) == 3: + label = label[:, :, 0] + + # Resize (Use INTER_NEAREST to keep the category unchanged) + if label.shape != target_size: + label = cv2.resize(label, (target_size[1], target_size[0]), + interpolation=cv2.INTER_NEAREST) + + # Clean up redundant categories (should address the issues with categories 4 and 5) + label[label >= n_classes] = 0 + + # One-hot encoding + for c in range(n_classes): + labels[i, :, :, c] = (label == c) + + return labels + + +def dice_coefficient_per_class(predicted, target, n_classes=4): + """ + Calculate Dice coefficient for each class separately. + + Args: + predicted: Predicted probabilities (N, C, H, W) + target: One-hot encoded ground truth (N, C, H, W) + n_classes: Number of classes + + Returns: + dict: Dice scores for each class (class_0, class_1, class_2, class_3) + """ + dice_scores = {} + + for c in range(n_classes): + # Extract specific class + predicted_c = predicted[:, c, :, :] # (N, H, W) + target_c = target[:, c, :, :] # (N, H, W) + + # Calculate intersection and union + intersection = (predicted_c * target_c).sum() + denominator = predicted_c.sum() + target_c.sum() + + # Handle empty cases + if denominator == 0: + dice_scores[f'class_{c}'] = 1.0 # Both empty = perfect match + else: + dice_scores[f'class_{c}'] = (2. * intersection / denominator).item() + + return dice_scores + + +def train_one_epoch(model, data_loader, loss_fn, optimizer, device): + """ + This function trains the model for one epoch on the given data loader + with deep supervision. + + Deep supervision weights decrease for deeper layers: + - Main output: weight = 1.0 + - DSV1: weight = 0.8 + - DSV2: weight = 0.6 + - DSV3: weight = 0.4 + - DSV4: weight = 0.2 + + Args: + model: The neural network model to be trained. Its forward may return a tensor + or a dict of tensors for deep supervision. + data_loader: Iterable providing batches of (images, targets). Images are tensors + of shape (N, C_in, H, W). Targets are class indices (N, H, W) or + one-hot masks depending on loss_fn requirements. + loss_fn: The criterion used to measure prediction error (e.g., CrossEntropyLoss + or Dice-based losses). Must accept per-head logits and matched-size targets. + optimizer: Optimizer used to update model parameters based on computed gradients. + device: Computation device (e.g., 'cuda' or 'cpu'). Model, images, and targets + are moved to this device. + Returns: + A tuple (avg_loss, avg_dice), where avg_loss (float) is the average loss across + all batches, and avg_dice (float) is the average Dice coefficient across all batches, + used for segmentation performance monitoring. 
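+
+    Note:
+        With deep supervision active, the batch loss combines all five heads:
+        loss = 1.0*CE(main, y) + 0.8*CE(dsv1, y) + 0.6*CE(dsv2, y)
+               + 0.4*CE(dsv3, y) + 0.2*CE(dsv4, y),
+        where y are class indices obtained by argmax over the one-hot labels
+        and CE denotes the criterion passed as loss_fn (CrossEntropyLoss in
+        this script).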
+ """ + model.train() + + running_loss = 0.0 + running_dice = { + 'class_0': 0.0, # Background + 'class_1': 0.0, # Body + 'class_2': 0.0, # Bone + 'class_3': 0.0 # Prostate (MAIN TARGET) + } + batch_count = 0 + + # Deep supervision weights + ds_weights = [1.0, 0.8, 0.6, 0.4, 0.2] + + pbar = tqdm(data_loader, desc='Training') + + for images, labels in pbar: + # Move data to device (GPU/CPU) + images = images.to(device) + labels = labels.to(device) + + # Forward + outputs = model(images) + + # Loss Function + class_indices = torch.argmax(labels, dim=1).long() # (N, H, W) + + if isinstance(outputs, tuple): + # Deep supervision is active + main_output, dsv1, dsv2, dsv3, dsv4 = outputs + + loss = (ds_weights[0] * loss_fn(main_output, class_indices) + + ds_weights[1] * loss_fn(dsv1, class_indices) + + ds_weights[2] * loss_fn(dsv2, class_indices) + + ds_weights[3] * loss_fn(dsv3, class_indices) + + ds_weights[4] * loss_fn(dsv4, class_indices)) + else: + # Inference mode, no deep supervision + loss = loss_fn(outputs, class_indices) + main_output = outputs + + # Backpropagation + optimizer.zero_grad() + loss.backward() + + # Gradient clipping for stability + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + optimizer.step() + + # Use Dice to do monitor (no need gradient here) + with torch.no_grad(): + probs = torch.softmax(main_output, dim=1) + dice = dice_coefficient_per_class(probs, labels, n_classes=4) + + # Accumulate dice scores + for key in running_dice.keys(): + running_dice[key] += dice[key] + + running_loss += loss.item() + batch_count += 1 + + # Update progress bar with prostate (class_3) Dice + pbar.set_postfix({ + 'Loss': f'{loss.item():.4f}', + 'Prostate_Dice': f'{dice["class_3"]:.4f}' + }) + + avg_loss = running_loss / batch_count + avg_dice = {key: value / batch_count for key, value in running_dice.items()} + + return avg_loss, avg_dice + + +def validate(model, data_loader, loss_fn, device): + """ + Validate on validation set. + Similar to training but no backpropagation. + + Args: + model: The neural network model being evaluated. + data_loader: Iterable that provides batches of validation data. + loss_fn: The loss function used to compute prediction error during validation. + device: The computation device to run the validation on (e.g., 'cuda' or 'cpu'). + + Returns: + A tuple (avg_loss, avg_dice), where avg_loss (float) is the average loss across + all batches, and avg_dice (float) is the average Dice coefficient across all batches, + used for segmentation performance monitoring. 
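+
+    Note:
+        Runs with model.eval() under torch.no_grad(). In eval mode the model's
+        forward returns only the main output, so no deep-supervision terms
+        contribute to the validation loss.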
+ """ + model.eval() + + running_loss = 0.0 + running_dice = { + 'class_0': 0.0, # Background + 'class_1': 0.0, # Peripheral Zone + 'class_2': 0.0, # Transition Zone + 'class_3': 0.0 # Prostate (MAIN TARGET) + } + batch_count = 0 + + with torch.no_grad(): + pbar = tqdm(data_loader, desc='Validation') + + for images, labels in pbar: + images = images.to(device) + labels = labels.to(device) + + outputs = model(images) + + class_indices = torch.argmax(labels, dim=1) + loss = loss_fn(outputs, class_indices) + + probs = torch.softmax(outputs, dim=1) + dice = dice_coefficient_per_class(probs, labels, n_classes=4) + + # Accumulate dice scores + for key in running_dice.keys(): + running_dice[key] += dice[key] + + running_loss += loss.item() + batch_count += 1 + + # Update progress bar with prostate (class_3) Dice + pbar.set_postfix({ + 'Loss': f'{loss.item():.4f}', + 'Prostate_Dice': f'{dice["class_3"]:.4f}' + }) + + avg_loss = running_loss / batch_count + avg_dice = {key: value / batch_count for key, value in running_dice.items()} + + return avg_loss, avg_dice + + +if __name__ == "__main__": + # Configuration + data_path = "/home/groups/comp3710/HipMRI_Study_open/keras_slices_data" + num_epochs = 30 + batch_size = 16 + learning_rate = 1e-4 + weight_decay = 1e-5 + target_size = (256, 128) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}\n") + + # Load data + print("Loading training data...") + train_images = sorted(glob.glob(f'{data_path}/keras_slices_train/*.nii.gz')) + train_labels = sorted(glob.glob(f'{data_path}/keras_slices_seg_train/*.nii.gz')) + + X_train = load_data_with_resize(train_images, target_size=target_size) + y_train = load_labels_with_resize(train_labels, target_size=target_size) + + train_dataset = TensorDataset( + torch.from_numpy(X_train).unsqueeze(1).float(), + torch.from_numpy(y_train).permute(0, 3, 1, 2).float() + ) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + # Initialize model + print("Initializing model...") + model = ImprovedUNet(n_channels=1, n_classes=4, deep_supervision=True).to(device) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) + + # Learning rate scheduler + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode='min', factor=0.5, patience=5 + ) + + # Training loop + print(f"\nTraining for {num_epochs} epochs...") + best_dice = 0.0 + + for epoch in range(num_epochs): + train_loss, train_dice = train_one_epoch(model, train_loader, criterion, optimizer, device) + + scheduler.step(train_loss) + + print(f"Epoch [{epoch+1}/{num_epochs}]") + print(f" Loss: {train_loss:.4f}") + print(f" Prostate Dice: {train_dice['class_3']:.4f}") + print(f" Learning Rate: {optimizer.param_groups[0]['lr']:.6f}") + + # Save best model + if train_dice['class_3'] > best_dice: + best_dice = train_dice['class_3'] + torch.save({ + 'epoch': epoch + 1, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'train_loss': train_loss, + 'train_dice': train_dice, + }, 'improved_unet_best.pth') + print(f"Saved best model (Dice: {best_dice:.4f})") + + # Save checkpoints + if (epoch + 1) % 10 == 0: + torch.save({ + 'epoch': epoch + 1, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'train_loss': train_loss, + 'train_dice': train_dice, + }, f'improved_unet_epoch_{epoch+1}.pth') + print(f"Saved checkpoint: 
improved_unet_epoch_{epoch+1}.pth") + + print() + + # Save final model + torch.save({ + 'epoch': num_epochs, + 'model_state_dict': model.state_dict(), + 'train_dice': train_dice, + }, 'improved_unet_final.pth') + + print(f"\nTraining complete!") + print(f"Best Prostate Dice: {best_dice:.4f}") + print(f"Final Prostate Dice: {train_dice['class_3']:.4f}") + print(f"Model saved to: improved_unet_final.pth") + +