diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..4367d2f6a --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +keras_png_slices_data/ +HipMRI_Study_open/ +__pycache__/ \ No newline at end of file diff --git a/README.md b/README.md index 272bba4fa..698cc288c 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,298 @@ -# Pattern Analysis -Pattern Analysis of various datasets by COMP3710 students in 2025 at the University of Queensland. +# HipMRI_Study Segmentation with Improved U-Net (Task 3) -We create pattern recognition and image processing library for Tensorflow (TF), PyTorch or JAX. +## Author -This library is created and maintained by The University of Queensland [COMP3710](https://my.uq.edu.au/programs-courses/course.html?course_code=comp3710) students. +Marcus Baulch (47445464) +COMP3710 - Pattern Recognition and Analysis +The University of Queensland -The library includes the following implemented in Tensorflow: -* fractals -* recognition problems +## Overview -In the recognition folder, you will find many recognition problems solved including: -* segmentation -* classification -* graph neural networks -* StyleGAN -* Stable diffusion -* transformers -etc. +This project implements an Improved U-Net for multi-class semantic segmentation of Prostate MRI images. The model segments anatomical structures in 2D MRI slices into 6 distinct classes, achieving performance through residual connections, batch normalisation, and a combined loss function. +### Key Features + +- **Residual U-Net Architecture**: Enhanced U-Net with ResNet-style skip connections within encoder/decoder blocks +- **Multi-Class Segmentation**: 6-class semantic segmentation +- **Combined Loss Function**: 60% Dice Loss + 40% Cross-Entropy for balanced optimisation +- **Data Augmentation**: Random flips and rotations during training +- **Comprehensive Evaluation**: Dice coefficient metrics with visualisation capabilities + +--- + +## Dataset and Preprocessing + +### Hip MRI Study Dataset + +The project uses the HipMRI Study Open Dataset, which contains MRI scans with semantic labels for male pelvises. The data was retrieved from https://data.csiro.au/collection/csiro:51392v2?redirected=true (see reference at the end). + +**Dataset Structure:** +``` +HipMRI_Study_open/ +├── keras_slices_data/ +│ ├── keras_slices_train/ # Training images +│ ├── keras_slices_seg_train/ # Training masks +│ ├── keras_slices_validate/ # Validation images +│ ├── keras_slices_seg_validate/ # Validation masks +│ ├── keras_slices_test/ # Test images +│ └── keras_slices_seg_test/ # Test masks +└── semantic_labels_only/ # Original 3D NIfTI files +``` + +### Preprocessing + +1. **Image Loading**: NIfTI (.nii.gz) files loaded with `nibabel` +2. **Normalisation**: Images standardised using z-score normalisation: `(x - mean) / std` [1] +3. **One-Hot Encoding**: Masks converted to 6-channel one-hot format `[B, 6, H, W]` +4. 
**Data Augmentation** (training only) [2]: + - Random horizontal/vertical flips (50% probability each) + - Random rotation of ±15 degrees + - Geometric transforms applied consistently to image-mask pairs + +### Train/Validation/Test Split + +The dataset uses a predefined split provided by the HipMRI Study Open Dataset: +- **Training set**: 11,460 images +- **Validation set**: 660 images +- **Test set**: 540 images + +**Justification**: +- The dataset came pre-split, so no manual splitting was required +- 90/5/5 split is standard for medical imaging datasets +- Training set is large enough to learn robust features +- Validation set (660 samples) is sufficient +- Test set (540 samples) provides statistically meaningful evaluation + + +## Model Architecture + +### Residual U-Net + +The model improves upon standard U-Net with residual blocks and batch normalisation: + +``` +Input (1 channel, grayscale MRI) + +Encoder Path (with residual blocks): + ResBlock(1→64) → MaxPool + ResBlock(64→128) → MaxPool + ResBlock(128→256) → MaxPool + ResBlock(256→512) → MaxPool + ResBlock(512→1024) [Bottleneck] + +Decoder Path (with skip connections): + UpConv + Concat → ResBlock(1024→512) + UpConv + Concat → ResBlock(512→256) + UpConv + Concat → ResBlock(256→128) + UpConv + Concat → ResBlock(128→64) + +Final Conv(64→6) + +Output (6 channels, class logits) +``` + +### Residual Block Details + +Each ResidualBlock consists of: +``` +Input + ├─ Conv3x3 → BatchNorm → ReLU → Conv3x3 → BatchNorm → (+) + └─ [1x1 Conv if channels mismatch] ────────────────────→ ReLU → Output +``` + + +## Training + +### Configuration + +| Parameter | Value | +|-----------|-------| +| **Batch Size** | 16 | +| **Epochs** | 20 | +| **Learning Rate** | 1e-4 | +| **Optimiser** | Adam | +| **Loss Function** | 60% Dice + 40% CrossEntropy | +| **Device** | CUDA (if available) / CPU | + +### Loss Function + +The combined loss leverages strengths of both components: + +- **Dice Loss**: Directly optimises the evaluation metric (Dice coefficient) +- **Cross-Entropy**: Provides stable pixel-wise classification gradients + + +### Training Script + +```bash +python train.py +``` + +**Outputs:** +- `outputs/best_model.pth` - Best model checkpoint (highest validation Dice) +- `outputs/prediction_XXX.png/` - Predicted visualisations (saved PNGs) + +### Output Visualisations + +![Prediction - slice 01](outputs/prediction_000.png) +*Input | Ground truth | Model prediction* + +![Prediction - slice 02](outputs/prediction_001.png) + + +![Prediction - slice 03](outputs/prediction_002.png) + +- `outputs/training_curves.png` - Loss and Dice score plots +![Dice Loss Curves](outputs/training_curves.png) + +--- + +## Evaluation + +### Metrics + +This model was trained on only 5 epochs, as it reaches the minimum dice coefficient of 0.75 very quickly. + +**Dice Coefficient** (primary metric): +``` +Dice = (2 × |Prediction ∩ Ground Truth|) / (|Prediction| + |Ground Truth|) +``` + +Calculated per-class and averaged across all 6 classes for final score. 
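+
+For reference, the sketch below shows how this per-class averaging works. It is a simplified version of `compute_dice_score()` in `predict.py`, not the exact implementation:
+
+```python
+import torch
+
+def mean_dice(logits: torch.Tensor, one_hot_masks: torch.Tensor, eps: float = 1e-6) -> float:
+    """Mean Dice over all classes for a batch of logits and one-hot masks."""
+    preds = torch.argmax(logits, dim=1)          # [B, H, W] predicted class indices
+    target = torch.argmax(one_hot_masks, dim=1)  # [B, H, W] ground-truth class indices
+    scores = []
+    for c in range(logits.shape[1]):             # loop over the 6 classes
+        pred_c = (preds == c).float()
+        target_c = (target == c).float()
+        intersection = (pred_c * target_c).sum()
+        union = pred_c.sum() + target_c.sum()
+        # A class absent from both prediction and ground truth counts as a perfect 1.0
+        scores.append(1.0 if union == 0 else ((2 * intersection + eps) / (union + eps)).item())
+    return sum(scores) / len(scores)
+```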
+ +### Running Evaluation + +```bash +python predict.py +``` + +**Features:** +- Loads best model from `outputs/best_model.pth` +- Evaluates on test set +- Reports mean, std, min, max Dice scores +- Saves prediction visualisations + + +## Results + +### Performance Metrics + +The following is an output from predict.py: +``` +====================================================================== +TRAINING COMPLETED +====================================================================== +Best Validation Dice: 0.8654 +====================================================================== + +====================================================================== +TEST SET EVALUATION +====================================================================== + +Test Loss: 0.2150 +Test Dice: 0.8777 +====================================================================== + +Training curves saved to: outputs/training_curves.png +FINAL SUMMARY +====================================================================== +Best Validation Dice: 0.8654 +Test Set Dice: 0.8777 +Model saved to: ./outputs/best_model.pth +Plots saved to: ./outputs/training_curves.png +====================================================================== +``` +The model provided an average Dice coefficient of 0.877 per label (averaged over 6 classes), which exceeds the 0.75 dice coefficient requirement for this task. + +### Training Curves + +Training and validation loss/Dice curves are automatically saved to `outputs/training_curves.png` after training completes. + + +## Project Structure + +``` +COMP3710-Report/ +├── train.py # Training script +├── predict.py # Evaluation script +├── modules.py # Residual U-Net architecture +├── dataset.py # Dataset loader with augmentations +├── utils_visualize.py # Visualisation utilities +├── check_predictions.py # Quick prediction checker +├── README.md # This file +├── LICENSE # Project license +└── outputs/ # Training outputs + ├── best_model.pth + ├── training_curves.png + └── prediction_XXX.png #variable amount of prediction visualisations + +``` + +--- + +## Requirements + +### Python Dependencies + +``` +torch>=1.9.0 +torchvision>=0.10.0 +numpy>=1.19.0 +nibabel>=3.2.0 +matplotlib>=3.3.0 +tqdm>=4.60.0 +scipy>=1.5.0 +``` + +### Installation + +```bash +pip install torch torchvision numpy nibabel matplotlib tqdm scipy +``` + +--- + +## Usage + +## Hardware + Runtime + +This project made use of UQ's Rangpur cluster, namely an a100 GPU. The following bash script was used to run it: +``` +#!/bin/bash +#SBATCH --partition=a100 +#SBATCH --gres=gpu:1 +#SBATCH --job-name=hipmri +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=8 + +#SBATCH --output=task3.out +#SBATCH --error=task3.err + +conda activate torch +python train.py +``` + +### Runtime Estimates + +| Task | GPU Time | CPU Time | +|------|----------|----------| +| **Training (20 epochs)** | ~10-15 min | ~1-2 hours | +| **Evaluation (test set)** | ~10-30 sec | ~1-2 min | +| **Single prediction** | <1 sec | ~1 sec | + + +### Device Selection + +The code automatically detects and uses CUDA if available: + +```python +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +``` + +--- + +### References +COMP3710 Teaching Team, 2025. Retrieved from https://colab.research.google.com/drive/1VOsZSyRhyuHLmgoqGriQk01ub4bKNmZ1?usp=sharing + +Dowling, J. & Greer, P. (2014). Labelled weekly MR images of the male pelvis. 
Retrieved from https://data.csiro.au/collection/csiro:51392v2?redirected=true \ No newline at end of file diff --git a/dataset.py b/dataset.py new file mode 100644 index 000000000..2ac4ec9c0 --- /dev/null +++ b/dataset.py @@ -0,0 +1,280 @@ +import numpy as np +import nibabel as nib +from tqdm import tqdm +from glob import glob +import torch +from torch.utils.data import Dataset +import torchvision.transforms.functional as F +from torchvision import transforms +import random +import argparse +from torch.utils.data import DataLoader + +def to_channels(arr: np.ndarray, dtype=np.uint8, num_classes=None) -> np.ndarray: + if num_classes is None: + channels = np.unique(arr) + num_classes = len(channels) + else: + channels = range(num_classes) + + res = np.zeros(arr.shape + (num_classes,), dtype=dtype) + for c in np.unique(arr): + c_as_int = int(c) + res[..., c_as_int] = (arr == c_as_int).astype(dtype) + return res + +def load_data_2D(imageNames, normImage=False, categorical=False, dtype=np.float32, + getAffines=False, early_stop=False, num_classes = None): + ''' + Load medical image data from names, cases list provided into a list for each. + This function pre-allocates 4D arrays for conv2d to avoid excessive memory usage. + normImage: bool (normalise the image 0.0-1.0) + early_stop: Stop loading prematurely, leaves arrays mostly empty for quick loading and testing scripts. + ''' + from scipy.ndimage import zoom + + affines = [] + + # get fixed size + num = len(imageNames) + img = nib.load(imageNames[0]) + first_case = img.get_fdata(caching='unchanged') + if len(first_case.shape) == 3: + first_case = first_case[:, :, 0] # sometimes extra dims, remove + if categorical: + first_case = to_channels(first_case, dtype=dtype, num_classes=num_classes) + rows, cols, channels = first_case.shape + images = np.zeros((num, rows, cols, channels), dtype=dtype ) + else: + rows, cols = first_case.shape + images = np.zeros((num, rows, cols), dtype=dtype) + + for i, inName in enumerate(tqdm(imageNames)): + niftiImage = nib.load(inName) + inImage = niftiImage.get_fdata(caching='unchanged') # read disk only + affine = niftiImage.affine + if len(inImage.shape) == 3: + inImage = inImage[:, :, 0] # sometimes extra dims in HipMRI_study data + + # Resize image if dimensions don't match + if inImage.shape != (rows, cols): + zoom_factors = (rows / inImage.shape[0], cols / inImage.shape[1]) + inImage = zoom(inImage, zoom_factors, order=1) # bilinear interpolation + + inImage = inImage.astype(dtype) + if normImage: + #~ inImage = inImage / np.linalg.norm ( inImage ) + #~ inImage = 255. * inImage / inImage . max () + inImage = (inImage - inImage.mean()) / inImage.std() + if categorical: + inImage = to_channels(inImage, dtype=dtype, num_classes=num_classes) # There are 6 classes in the dataset + images[i, :, :, :] = inImage + else : + images[i, :, :] = inImage + + affines.append(affine) + if i > 20 and early_stop: + break + + if getAffines: + return images, affines + else: + return images + +class DataSegmenter(Dataset): + """Custom PyTorch Dataset for 2D medical image segmentation with NIfTI files.""" + + def __init__(self, image_path, mask_path, subset_size=None, start_index=0, augment=False): + """ + Create a dataset from glob patterns for images and masks. 
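+
+        Example (illustrative only; the glob patterns and subset size below are
+        placeholders, not the project's actual dataset paths):
+
+            ds = DataSegmenter('keras_slices_train/*.nii.gz',
+                               'keras_slices_seg_train/*.nii.gz',
+                               subset_size=8, augment=True)
+            image, mask = ds[0]   # image: [1, H, W], mask: [6, H, W]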
+ + Parameters: + image_path: Glob pattern for image files (e.g., 'Data/.../train/*.nii.gz') + mask_path: Glob pattern for mask files (e.g., 'Data/.../seg_train/*.nii.gz') + subset_size: Optional limit on number of samples to load + start_index: Skip samples before this index in sorted file list + augment: Whether to apply data augmentation during __getitem__ + """ + # Collect and sort file paths + self.image_paths = sorted(glob(image_path)) + self.mask_paths = sorted(glob(mask_path)) + self.augment = augment + + # Validate matching counts + num_images, num_masks = len(self.image_paths), len(self.mask_paths) + if num_images != num_masks: + print(f"Warning. The number of images ({num_images}) does not equal the number of masks ({num_masks})") + + # Calculate initial dataset size + self.dataset_size = min(num_images, num_masks) + + # Apply start index filtering + self._apply_start_index_filter(start_index) + + # Apply subset size limiting + self._apply_subset_size_limit(subset_size) + + # Preload data into memory as tensors + self._preload_data_to_memory() + + def _apply_start_index_filter(self, start_index): + """Remove samples before the specified start index.""" + if self.dataset_size <= start_index: + print(f"Warning: start index ({start_index}) >= the size of the data ({self.dataset_size}). No data will be stored.") + + self.image_paths = self.image_paths[start_index:] + self.mask_paths = self.mask_paths[start_index:] + self.dataset_size = min(len(self.image_paths), len(self.mask_paths)) + print(f"There are {self.dataset_size} samples beyond the start index") + + def _apply_subset_size_limit(self, subset_size): + """Limit dataset to a subset if specified.""" + if subset_size is not None and subset_size < self.dataset_size: + original_size = self.dataset_size + self.image_paths = self.image_paths[:subset_size] + self.mask_paths = self.mask_paths[:subset_size] + self.dataset_size = min(len(self.image_paths), len(self.mask_paths)) + print(f"Using subset of {subset_size} samples (out of {original_size} total)") + else: + print(f"Using all {self.dataset_size} samples") + + def _preload_data_to_memory(self): + """Load all images and masks from disk into memory as PyTorch tensors.""" + # Load normalized images + image_array = load_data_2D(self.image_paths, normImage=True) + self.images = torch.from_numpy(image_array) + + # Load categorical masks (one-hot encoded with 6 classes) + mask_array = load_data_2D(self.mask_paths, categorical=True, num_classes=6) + self.masks = torch.from_numpy(mask_array) + + # Transpose masks from [B, H, W, C] to PyTorch format [B, C, H, W] + self.masks = torch.permute(self.masks, (0, 3, 1, 2)) + + print("Images shape: ", self.images.shape) + print("Masks shape: ", self.masks.shape) + + def __len__(self): + return self.dataset_size + + def __getitem__(self, idx): + """Retrieve a single image-mask pair, optionally with augmentation.""" + image, mask = self.images[idx], self.masks[idx] + + # Ensure image has channel dimension + if image.ndim == 2: + image = image.unsqueeze(0) + + # Apply augmentations if enabled + if self.augment: + image, mask = self._apply_augmentations(image, mask) + + return image, mask + + def _apply_augmentations(self, image, mask): + """Apply random geometric augmentations to both image and mask.""" + # Horizontal flip with 50% probability + if random.random() > 0.5: + image = F.hflip(image) + mask = F.hflip(mask) + + # Vertical flip with 50% probability + if random.random() > 0.5: + image = F.vflip(image) + mask = F.vflip(mask) + + # Random 
rotation between -15 and +15 degrees + rotation_angle = random.uniform(-15, 15) + image = F.rotate(image, rotation_angle, interpolation=transforms.InterpolationMode.BILINEAR) + mask = F.rotate(mask, rotation_angle, interpolation=transforms.InterpolationMode.NEAREST) + + return image, mask + + +def get_image_path(dataset, type): + """ + Construct glob pattern for HipMRI Study dataset files. + + Parameters: + dataset: Which split to use ('train', 'validate', or 'test') + type: Either 'image' or 'mask' for respective file types + + Returns: + String glob pattern suitable for DataSegmenter initialization + + """ + # Define base directory structure + base_path = "home/groups/comp3710/HipMRI_Study_open/keras_slices_data" + + # Build path based on dataset split and type + split_mapping = { + 'train': 'train', + 'validate': 'validate', + 'test': 'test' + } + + # Default to train if invalid + split = split_mapping.get(dataset, 'train') + + # Construct the appropriate subdirectory + if type == 'mask': + subdir = f"keras_slices_seg_{split}" + else: + subdir = f"keras_slices_{split}" + + # Return full glob pattern + return f"{base_path}/{subdir}/*.nii.gz" + + +if __name__ == "__main__": + # Command-line interface for testing DataSegmenter + parser = argparse.ArgumentParser(description="Quick test / demo for DataSegmenter") + parser.add_argument("--dataset", choices=("train", "validate", "test"), default="train", + help="Which HipMRI split to use") + parser.add_argument("--subset", type=int, default=4, help="How many samples to load (for quick tests)") + parser.add_argument("--start", type=int, default=0, help="Start index into the sorted file list") + parser.add_argument("--augment", action="store_true", help="Enable simple augmentations") + parser.add_argument("--batch-size", type=int, default=2, help="Batch size for the DataLoader") + args = parser.parse_args() + + # Get file patterns for selected dataset split + img_pattern = get_image_path(args.dataset, "image") + mask_pattern = get_image_path(args.dataset, "mask") + + print("Image pattern:", img_pattern) + print("Mask pattern:", mask_pattern) + + # Instantiate dataset + ds = DataSegmenter( + img_pattern, + mask_pattern, + subset_size=args.subset, + start_index=args.start, + augment=args.augment + ) + + # Create DataLoader + dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False) + + # Test loading a batch + try: + batch = next(iter(dl)) + images, masks = batch + print("Loaded batch:") + print(" images shape:", images.shape, "dtype:", images.dtype) + print(" masks shape:", masks.shape, "dtype:", masks.dtype) + + # Display basic statistics + try: + img_min, img_max, img_mean = float(images.min()), float(images.max()), float(images.mean()) + print(f" image min/max/mean: {img_min} {img_max} {img_mean}") + + # Calculate class pixel counts across batch and spatial dimensions + class_counts = masks.sum(dim=(0, 2, 3)).tolist() + print(" mask class counts (per channel):", class_counts) + except Exception: + pass + except StopIteration: + print("No data returned by DataLoader (empty dataset).") + except Exception as e: + print("Error while loading a batch:", repr(e)) \ No newline at end of file diff --git a/modules.py b/modules.py new file mode 100644 index 000000000..4c7464e8e --- /dev/null +++ b/modules.py @@ -0,0 +1,115 @@ +import torch + +import torch.nn as nn +import torch.nn.functional as F + +# The following code is inspired by Unet Segmentation code demo on Blackboard +# 
https://colab.research.google.com/drive/1VOsZSyRhyuHLmgoqGriQk01ub4bKNmZ1?usp=sharing + +class ResidualBlock(nn.Module): + """Residual block with skip connection for U-Net.""" + def __init__(self, in_channels, out_channels): + super(ResidualBlock, self).__init__() + + # Main path: Conv -> BN -> ReLU -> Conv -> BN + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(out_channels) + + # Skip connection: 1x1 conv if channel dimensions don't match + self.skip_connection = nn.Sequential() + if in_channels != out_channels: + self.skip_connection = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), + nn.BatchNorm2d(out_channels) + ) + + def forward(self, x): + identity = self.skip_connection(x) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + # Add skip connection + out += identity + out = self.relu(out) + + return out + +class UNet(nn.Module): + def __init__(self, in_channels, out_channels): + super(UNet, self).__init__() + # Encoder layers with residual blocks + self.encoder1 = ResidualBlock(in_channels, 64) + self.encoder2 = ResidualBlock(64, 128) + self.encoder3 = ResidualBlock(128, 256) + self.encoder4 = ResidualBlock(256, 512) + self.bottleneck = ResidualBlock(512, 1024) + + # Decoder layers with residual blocks + self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2) + self.decoder4 = ResidualBlock(1024, 512) + self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2) + self.decoder3 = ResidualBlock(512, 256) + self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) + self.decoder2 = ResidualBlock(256, 128) + self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) + self.decoder1 = ResidualBlock(128, 64) + + self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1) + + def encode(self, x): + """ + Encoder path: progressively downsample and increase feature channels. + Returns the bottleneck features and skip connections. + """ + # Encoder with skip connections + enc1 = self.encoder1(x) + enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2)) + enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2)) + enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2)) + bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2)) + + # Return bottleneck and skip connections + return bottleneck, (enc1, enc2, enc3, enc4) + + def decode(self, bottleneck, skip_connections): + """ + Decoder path: progressively upsample and combine with skip connections. + Takes bottleneck features and skip connections from encoder. + """ + enc1, enc2, enc3, enc4 = skip_connections + + # Decoder with skip connections + dec4 = self.upconv4(bottleneck) + dec4 = torch.cat((enc4, dec4), dim=1) + dec4 = self.decoder4(dec4) + + dec3 = self.upconv3(dec4) + dec3 = torch.cat((enc3, dec3), dim=1) + dec3 = self.decoder3(dec3) + + dec2 = self.upconv2(dec3) + dec2 = torch.cat((enc2, dec2), dim=1) + dec2 = self.decoder2(dec2) + + dec1 = self.upconv1(dec2) + dec1 = torch.cat((enc1, dec1), dim=1) + dec1 = self.decoder1(dec1) + + return dec1 + + def forward(self, x): + """ + Forward pass through the complete U-Net architecture. 
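+
+        Shape notes: x is expected to be [B, in_channels, H, W]; the output is a raw
+        logits tensor of shape [B, out_channels, H, W] (no softmax is applied here).
+        H and W should be divisible by 16 so the four pool/up-conv stages line up
+        with the skip connections.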
+ """ + bottleneck, skip_connections = self.encode(x) + decoded = self.decode(bottleneck, skip_connections) + return self.final_conv(decoded) + + \ No newline at end of file diff --git a/outputs/prediction_000.png b/outputs/prediction_000.png new file mode 100644 index 000000000..31fc1570c Binary files /dev/null and b/outputs/prediction_000.png differ diff --git a/outputs/prediction_001.png b/outputs/prediction_001.png new file mode 100644 index 000000000..f2871daa0 Binary files /dev/null and b/outputs/prediction_001.png differ diff --git a/outputs/prediction_002.png b/outputs/prediction_002.png new file mode 100644 index 000000000..2bca9d269 Binary files /dev/null and b/outputs/prediction_002.png differ diff --git a/outputs/prediction_003.png b/outputs/prediction_003.png new file mode 100644 index 000000000..57685b247 Binary files /dev/null and b/outputs/prediction_003.png differ diff --git a/outputs/prediction_004.png b/outputs/prediction_004.png new file mode 100644 index 000000000..a5b2e6ac1 Binary files /dev/null and b/outputs/prediction_004.png differ diff --git a/outputs/training_curves.png b/outputs/training_curves.png new file mode 100644 index 000000000..7392639a6 Binary files /dev/null and b/outputs/training_curves.png differ diff --git a/predict.py b/predict.py new file mode 100644 index 000000000..8b8a3e459 --- /dev/null +++ b/predict.py @@ -0,0 +1,269 @@ +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from tqdm import tqdm +import numpy as np +import os +import matplotlib.pyplot as plt + +from modules import UNet +from dataset import DataSegmenter, get_image_path + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f'Predict Using device: {device}') + + +def compute_dice_score(logits: torch.Tensor, targets: torch.Tensor) -> float: + """ + Calculate Dice score for multi-class segmentation. + + Args: + logits: Raw model outputs [B, C, H, W] (before softmax) + targets: One-hot encoded ground truth [B, C, H, W] + + Returns: + Mean Dice score across all classes (float) + """ + with torch.no_grad(): + # Get predicted classes + preds = torch.argmax(logits, dim=1) # [B, H, W] + targets_class = torch.argmax(targets, dim=1) # [B, H, W] + + # Calculate Dice per class and average + num_classes = logits.shape[1] + dice_scores = [] + + for c in range(num_classes): + pred_c = (preds == c).float() + target_c = (targets_class == c).float() + + intersection = (pred_c * target_c).sum() + union = pred_c.sum() + target_c.sum() + + if union == 0: + dice_scores.append(1.0 if intersection == 0 else 0.0) + else: + dice = (2.0 * intersection + 1e-6) / (union + 1e-6) + dice_scores.append(dice.item()) + + # Return mean Dice across all classes + return sum(dice_scores) / len(dice_scores) + + +def evaluate_model(model_path, test_loader): + """ + Loads a saved model and evaluates Dice score on the test dataset. 
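+
+    Accepts either checkpoint format written by train.py: a full checkpoint dict
+    (containing 'model_state_dict', training history and best_dice) or a bare
+    state_dict.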
+ + Args: + model_path (str): Path to the saved model .pth file + test_loader (DataLoader): Test data loader + + Returns: + float: Mean Dice score across test set + """ + print("="*70) + print("MODEL EVALUATION") + print("="*70) + print(f"Loading model from: {model_path}") + + # Load model with 6 output channels for multi-class segmentation + model = UNet(in_channels=1, out_channels=6) + checkpoint = torch.load(model_path, map_location=device) + + # Handle different checkpoint formats + if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: + model.load_state_dict(checkpoint['model_state_dict']) + print(f"Model loaded from checkpoint (epoch {checkpoint.get('history', {}).get('epochs_completed', 'unknown')})") + if 'best_dice' in checkpoint: + print(f"Best validation Dice from training: {checkpoint['best_dice']:.4f}") + else: + model.load_state_dict(checkpoint) + print("Model loaded successfully") + + model.to(device) + model.eval() + + # Evaluate + dice_scores = [] + + print("\nEvaluating on test set...") + with torch.no_grad(): + for images, masks in tqdm(test_loader, desc="Evaluating"): + images, masks = images.to(device), masks.to(device) + + # Forward pass (logits for multi-class) + logits = model(images) + + # Calculate Dice for this batch + dice = compute_dice_score(logits, masks) + dice_scores.append(dice) + + # Calculate statistics + mean_dice = np.mean(dice_scores) + std_dice = np.std(dice_scores) + min_dice = np.min(dice_scores) + max_dice = np.max(dice_scores) + + # Report results + print("\n" + "="*70) + print("EVALUATION RESULTS") + print("="*70) + print(f"Number of batches: {len(dice_scores)}") + print(f"Mean Dice Score: {mean_dice:.4f} ± {std_dice:.4f}") + print(f"Min Dice Score: {min_dice:.4f}") + print(f"Max Dice Score: {max_dice:.4f}") + print("="*70 + "\n") + + return mean_dice + + +def save_prediction_visualisations(model, test_loader, output_dir='./outputs/predictions', max_samples=10): + """ + Save prediction visualizations comparing input, ground truth, and predictions. 
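+
+    Each saved figure has three panels (input slice, ground truth, prediction),
+    with the two masks colour-coded using the fixed 6-class palette defined in the
+    function body.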
+ + Args: + model: Trained model + test_loader: DataLoader for test set + output_dir: Directory to save visualization images + max_samples: Maximum number of samples to visualize + """ + os.makedirs(output_dir, exist_ok=True) + + # Define colors for each class + class_colors = [ + (0.0, 0.0, 0.0), # class 0 - black (background) + (1.0, 0.0, 0.0), # class 1 - red + (0.0, 1.0, 0.0), # class 2 - green + (0.0, 0.0, 1.0), # class 3 - blue + (1.0, 1.0, 0.0), # class 4 - yellow + (1.0, 0.0, 1.0), # class 5 - magenta + ] + + def colorize_mask(mask_idx): + """Convert class indices to RGB image.""" + h, w = mask_idx.shape + rgb = np.zeros((h, w, 3), dtype=np.float32) + for c, color in enumerate(class_colors): + rgb[mask_idx == c] = color + return rgb + + model.eval() + sample_count = 0 + + print(f"\nSaving visualizations to {output_dir}...") + + with torch.no_grad(): + for batch_idx, (images, masks) in enumerate(test_loader): + if sample_count >= max_samples: + break + + images, masks = images.to(device), masks.to(device) + logits = model(images) + preds = torch.argmax(logits, dim=1) # [B, H, W] + + # Process each image in batch + for i in range(images.shape[0]): + if sample_count >= max_samples: + break + + # Get single sample + img = images[i, 0].cpu().numpy() # [H, W] + gt_mask = torch.argmax(masks[i], dim=0).cpu().numpy() # [H, W] + pred_mask = preds[i].cpu().numpy() # [H, W] + + # Normalize image for display + if img.max() - img.min() > 1e-6: + img_display = (img - img.min()) / (img.max() - img.min()) + else: + img_display = img + + # Colorize masks + gt_colored = colorize_mask(gt_mask) + pred_colored = colorize_mask(pred_mask) + + # Create figure + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + + axes[0].imshow(img_display, cmap='gray') + axes[0].set_title('Input Image', fontsize=12, fontweight='bold') + axes[0].axis('off') + + axes[1].imshow(gt_colored) + axes[1].set_title('Ground Truth', fontsize=12, fontweight='bold') + axes[1].axis('off') + + axes[2].imshow(pred_colored) + axes[2].set_title('Prediction', fontsize=12, fontweight='bold') + axes[2].axis('off') + + plt.tight_layout() + + # Save figure + output_path = os.path.join(output_dir, f'prediction_{sample_count:03d}.png') + plt.savefig(output_path, dpi=150, bbox_inches='tight') + plt.close(fig) + + sample_count += 1 + + print(f"✓ Saved {sample_count} visualization images to {output_dir}") + + +def main(): + """Main execution function.""" + print("="*70) + print("LOADING TEST DATASET") + print("="*70) + + # Get file paths for test set + test_img_path = get_image_path('test', 'image') + test_mask_path = get_image_path('test', 'mask') + + print(f'Test images: {test_img_path}') + print(f'Test masks: {test_mask_path}') + + # Create test dataset + test_dataset = DataSegmenter(test_img_path, test_mask_path, augment=False) + + # Create test dataloader + test_loader = DataLoader( + test_dataset, + batch_size=16, + shuffle=False, + num_workers=0, + pin_memory=False + ) + + print(f'\n✓ Test data loaded successfully') + print(f' Test batches: {len(test_loader)}') + print() + + # Evaluate the model + model_path = './outputs/best_model.pth' + + # Load model for visualidation + print("="*70) + print("LOADING MODEL FOR VISUALIZATION") + print("="*70) + model = UNet(in_channels=1, out_channels=6) + checkpoint = torch.load(model_path, map_location=device) + + if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: + model.load_state_dict(checkpoint['model_state_dict']) + else: + model.load_state_dict(checkpoint) + + model.to(device) + 
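+    # eval() switches the BatchNorm layers in the residual blocks to their running
+    # statistics, so predictions are deterministic during inference.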
model.eval() + print("Model loaded successfully\n") + + save_prediction_visualisations(model, test_loader, output_dir='./outputs/predictions', max_samples=5) + + # Evaluate Dice scores + mean_dice = evaluate_model(model_path, test_loader) + + print(f"\nFinal Test Dice Score: {mean_dice:.4f}") + print(f"Prediction images saved to: ./outputs/predictions/") + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 000000000..a25adf343 --- /dev/null +++ b/train.py @@ -0,0 +1,380 @@ +import os +import torch +import torch.nn as nn +import torch.optim as optim +from pathlib import Path +import matplotlib.pyplot as plt +from modules import UNet +from dataset import DataSegmenter, get_image_path +from torch.utils.data import DataLoader + + + +# Configuration +CONFIG = { + 'batch_size': 16, + 'epochs': 20, + 'learning_rate': 1e-4, + 'image_size': (128, 128), + 'checkpoint_dir': './outputs', + 'device': 'cuda' if torch.cuda.is_available() else 'cpu' +} + + +class DiceLoss(nn.Module): + """Dice Loss for multi-class segmentation. + + Computes the Dice loss across all classes and averages them. + """ + def __init__(self, smooth=1e-6): + super(DiceLoss, self).__init__() + self.smooth = smooth + + def forward(self, logits, targets): + """ + Args: + logits: Raw model outputs [B, C, H, W] (before softmax) + targets: One-hot encoded ground truth [B, C, H, W] + + Returns: + torch.Tensor: Dice loss value (1 - Dice Coefficient) + """ + # Apply softmax to get probabilities + probs = torch.softmax(logits, dim=1) + + # Flatten spatial dimensions + probs = probs.view(probs.size(0), probs.size(1), -1) # [B, C, H*W] + targets = targets.view(targets.size(0), targets.size(1), -1) # [B, C, H*W] + + # Calculate Dice per class + intersection = (probs * targets).sum(dim=2) # [B, C] + union = probs.sum(dim=2) + targets.sum(dim=2) # [B, C] + + dice_per_class = (2.0 * intersection + self.smooth) / (union + self.smooth) # [B, C] + + # Average across classes and batch + dice_coeff = dice_per_class.mean() + + # Return Dice Loss (1 - Dice Coefficient) + return 1.0 - dice_coeff + + +class ModelTrainer: + """Training manager for U-Net segmentation.""" + + def __init__(self, config: dict): + self.cfg = config + self.device = config['device'] + + # Initialize model with 6 output channels for multi-class segmentation + self.model = UNet(in_channels=1, out_channels=6).to(self.device) + + # Loss and optimizer - using combined Dice + CrossEntropy for better performance + self.dice_loss = DiceLoss() + self.ce_loss = nn.CrossEntropyLoss() + self.optimizer = optim.Adam(self.model.parameters(), lr=config['learning_rate']) + + # History tracking + self.history = { + 'train_loss': [], + 'val_loss': [], + 'train_dice': [], + 'val_dice': [], + 'epochs_completed': 0 + } + + self.best_dice = 0.0 + + # Setup save directory + Path(config['checkpoint_dir']).mkdir(parents=True, exist_ok=True) + + def combined_loss(self, logits, targets): + """Combine Dice Loss and CrossEntropy Loss for robust training.""" + dice = self.dice_loss(logits, targets) + ce = self.ce_loss(logits, targets) + # Weight the losses: 60% Dice, 40% CrossEntropy + return 0.6 * dice + 0.4 * ce + + def compute_dice_metric(self, logits: torch.Tensor, targets: torch.Tensor) -> float: + """Calculate Dice score for multi-class segmentation.""" + with torch.no_grad(): + # Get predicted classes + preds = torch.argmax(logits, dim=1) # [B, H, W] + targets_class = torch.argmax(targets, dim=1) # [B, H, W] + + # Calculate 
Dice per class and average + num_classes = logits.shape[1] + dice_scores = [] + + for c in range(num_classes): + pred_c = (preds == c).float() + target_c = (targets_class == c).float() + + intersection = (pred_c * target_c).sum() + union = pred_c.sum() + target_c.sum() + + if union == 0: + dice_scores.append(1.0 if intersection == 0 else 0.0) + else: + dice = (2.0 * intersection + 1e-6) / (union + 1e-6) + dice_scores.append(dice.item()) + + # Return mean Dice across all classes + return sum(dice_scores) / len(dice_scores) + + def run_training_epoch(self, dataloader) -> tuple: + """Execute one training epoch.""" + self.model.train() + loss_accumulator = 0.0 + dice_accumulator = 0.0 + batch_count = 0 + + for batch_num, (imgs, msks) in enumerate(dataloader, start=1): + imgs, msks = imgs.to(self.device), msks.to(self.device) + + # Forward and backward + self.optimizer.zero_grad() + outputs = self.model(imgs) + loss = self.combined_loss(outputs, msks) + loss.backward() + self.optimizer.step() + + # Track metrics + loss_accumulator += loss.item() + dice_score = self.compute_dice_metric(outputs, msks) + dice_accumulator += dice_score + batch_count += 1 + + if batch_num % 10 == 0: + print(f' [Batch {batch_num}/{len(dataloader)}] Loss: {loss.item():.4f}, Dice: {dice_score:.4f}') + + avg_loss = loss_accumulator / batch_count + avg_dice = dice_accumulator / batch_count + return avg_loss, avg_dice + + def run_validation_epoch(self, dataloader) -> tuple: + """Execute one validation epoch.""" + self.model.eval() + loss_accumulator = 0.0 + dice_accumulator = 0.0 + batch_count = 0 + + with torch.no_grad(): + for imgs, msks in dataloader: + imgs, msks = imgs.to(self.device), msks.to(self.device) + + outputs = self.model(imgs) + loss = self.combined_loss(outputs, msks) + + loss_accumulator += loss.item() + dice_score = self.compute_dice_metric(outputs, msks) + dice_accumulator += dice_score + batch_count += 1 + + avg_loss = loss_accumulator / batch_count + avg_dice = dice_accumulator / batch_count + return avg_loss, avg_dice + + def save_model(self, filename: str, is_best: bool = False): + """Save model checkpoint.""" + save_path = Path(self.cfg['checkpoint_dir']) / filename + + checkpoint_data = { + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'history': self.history, + 'best_dice': self.best_dice, + 'config': self.cfg + } + + torch.save(checkpoint_data, save_path) + + if is_best: + best_path = Path(self.cfg['checkpoint_dir']) / 'best_model.pth' + torch.save(checkpoint_data, best_path) + print(f' ✓ Best model saved (Dice: {self.best_dice:.4f})') + + def fit(self, train_loader, val_loader): + """Main training loop.""" + print('\n' + '='*70) + print('HIP MRI MULTI-CLASS SEGMENTATION TRAINING') + print('='*70) + print(f'Device: {self.device}') + print(f'Epochs: {self.cfg["epochs"]}') + print(f'Batch Size: {self.cfg["batch_size"]}') + print(f'Learning Rate: {self.cfg["learning_rate"]}') + print('='*70 + '\n') + + for epoch in range(1, self.cfg['epochs'] + 1): + print(f'Epoch {epoch}/{self.cfg["epochs"]}') + print('-' * 50) + + # Train phase + print(' Training...') + train_loss, train_dice = self.run_training_epoch(train_loader) + + # Validation phase + print(' Validating...') + val_loss, val_dice = self.run_validation_epoch(val_loader) + + # Update history + self.history['train_loss'].append(train_loss) + self.history['val_loss'].append(val_loss) + self.history['train_dice'].append(train_dice) + self.history['val_dice'].append(val_dice) + 
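+            # Record how many epochs have completed; this is stored in checkpoints
+            # and reported by predict.py when a model is loaded.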
self.history['epochs_completed'] = epoch + + # Print epoch summary + print(f'\n Summary:') + print(f' Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}') + print(f' Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}') + + # Save best model + is_best = val_dice > self.best_dice + if is_best: + self.best_dice = val_dice + self.save_model('checkpoint_latest.pth', is_best=True) + else: + print(f' Best Dice: {self.best_dice:.4f}') + + print() + + # Training complete + print('='*70) + print('TRAINING COMPLETED') + print('='*70) + print(f'Best Validation Dice: {self.best_dice:.4f}') + print('='*70 + '\n') + + +def evaluate_on_test(trainer: ModelTrainer, test_loader): + """Final evaluation on test set.""" + print('='*70) + print('TEST SET EVALUATION') + print('='*70) + + test_dice = trainer.run_validation_epoch(test_loader) + + print(f'\nTest Set Dice Coefficient: {test_dice:.4f}') + print('='*70 + '\n') + + return test_dice + +def plot_training_curves(history, save_dir='./outputs'): + """Plot and save training/validation loss and Dice curves.""" + Path(save_dir).mkdir(parents=True, exist_ok=True) + + epochs = range(1, len(history['train_loss']) + 1) + + # Plot Loss + plt.figure(figsize=(10, 5)) + plt.subplot(1, 2, 1) + plt.plot(epochs, history['train_loss'], 'b-o', label='Train Loss', linewidth=2, markersize=6) + plt.plot(epochs, history['val_loss'], 'r-s', label='Val Loss', linewidth=2, markersize=6) + plt.title('Training and Validation Loss', fontsize=14, fontweight='bold') + plt.xlabel('Epoch', fontsize=12) + plt.ylabel('Loss', fontsize=12) + plt.legend(fontsize=10) + plt.grid(True, alpha=0.3) + + # Plot Dice Score + plt.subplot(1, 2, 2) + plt.plot(epochs, history['train_dice'], 'b-o', label='Train Dice', linewidth=2, markersize=6) + plt.plot(epochs, history['val_dice'], 'r-s', label='Val Dice', linewidth=2, markersize=6) + plt.title('Training and Validation Dice Score', fontsize=14, fontweight='bold') + plt.xlabel('Epoch', fontsize=12) + plt.ylabel('Dice Score', fontsize=12) + plt.legend(fontsize=10) + plt.grid(True, alpha=0.3) + + plt.tight_layout() + save_path = Path(save_dir) / 'training_curves.png' + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f'Training curves saved to: {save_path}') + plt.close() + +def main(): + """Main execution function.""" + print('='*70) + print('LOADING DATASETS') + print('='*70) + + # Get file paths for train, validation, and test sets + train_img_path = get_image_path('train', 'image') + train_mask_path = get_image_path('train', 'mask') + val_img_path = get_image_path('validate', 'image') + val_mask_path = get_image_path('validate', 'mask') + test_img_path = get_image_path('test', 'image') + test_mask_path = get_image_path('test', 'mask') + + print(f'Train images: {train_img_path}') + print(f'Train masks: {train_mask_path}') + print(f'Val images: {val_img_path}') + print(f'Val masks: {val_mask_path}') + print(f'Test images: {test_img_path}') + print(f'Test masks: {test_mask_path}') + print() + + # Create datasets + train_dataset = DataSegmenter(train_img_path, train_mask_path, augment=True) + val_dataset = DataSegmenter(val_img_path, val_mask_path, augment=False) + test_dataset = DataSegmenter(test_img_path, test_mask_path, augment=False) + + # Create dataloaders + train_loader = DataLoader( + train_dataset, + batch_size=CONFIG['batch_size'], + shuffle=True, + num_workers=0, + pin_memory=False + ) + val_loader = DataLoader( + val_dataset, + batch_size=CONFIG['batch_size'], + shuffle=False, + num_workers=0, + pin_memory=False + ) + 
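+    # The test loader mirrors the validation loader: no augmentation and no
+    # shuffling, so the final evaluation is reproducible.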
test_loader = DataLoader( + test_dataset, + batch_size=CONFIG['batch_size'], + shuffle=False, + num_workers=0, + pin_memory=False + ) + + print(f'✓ Data loaded successfully') + print(f' Train batches: {len(train_loader)}') + print(f' Val batches: {len(val_loader)}') + print(f' Test batches: {len(test_loader)}') + print() + + # Initialize trainer + trainer = ModelTrainer(CONFIG) + + # Train model + trainer.fit(train_loader, val_loader) + + # Evaluate on test set + print('='*70) + print('TEST SET EVALUATION') + print('='*70) + test_loss, test_dice = trainer.run_validation_epoch(test_loader) + print(f'\nTest Loss: {test_loss:.4f}') + print(f'Test Dice: {test_dice:.4f}') + print('='*70 + '\n') + + # Plot and save training curves + plot_training_curves(trainer.history, save_dir=CONFIG['checkpoint_dir']) + + # Final summary + print('FINAL SUMMARY') + print('='*70) + print(f'Best Validation Dice: {trainer.best_dice:.4f}') + print(f'Test Set Dice: {test_dice:.4f}') + print(f'Model saved to: {CONFIG["checkpoint_dir"]}/best_model.pth') + print(f'Plots saved to: {CONFIG["checkpoint_dir"]}/training_curves.png') + print('='*70) + +if __name__ == '__main__': + main() \ No newline at end of file