diff --git a/recognition/README.md b/recognition/README.md deleted file mode 100644 index 32c99e899..000000000 --- a/recognition/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Recognition Tasks -Various recognition tasks solved in deep learning frameworks. - -Tasks may include: -* Image Segmentation -* Object detection -* Graph node classification -* Image super resolution -* Disease classification -* Generative modelling with StyleGAN and Stable Diffusion \ No newline at end of file diff --git a/recognition/s48027522-HipMRI-2DUnet/README.md b/recognition/s48027522-HipMRI-2DUnet/README.md new file mode 100644 index 000000000..5ce8584aa --- /dev/null +++ b/recognition/s48027522-HipMRI-2DUnet/README.md @@ -0,0 +1,21 @@ +# Prostate MRI Segmentation with 2D U-Net + +## Project Overview + +This project implements a 2D UNet CNN architecture for segmenting prostates from MRI scans using a **2D U-Net** architecture. The workflow is designed to handle **pre-processed 2D slices of prostate MRIs**, enabling efficient training and evaluation of segmentation models. + +The goal is to accurately delineate the prostate from surrounding tissues, which is critical for clinical applications such as **radiotherapy planning, disease diagnosis, and progression monitoring**. + +--- + +## Features + +- **2D Slice-Based Training**: Uses individual 2D slices extracted from 3D MRI volumes, enabling faster training and lower memory requirements. +- **U-Net Architecture**: Lightweight **2D U-Net** with encoder-decoder structure and skip connections for high-resolution segmentation. +- **Binary Dice Loss**: Implements a **Dice similarity coefficient loss** for robust training in binary segmentation (prostate vs. background). +- **Data Augmentation**: Supports **random flips, rotations, and z-score normalization** for better generalization. +- **Early Stopping**: Training halts automatically when the **Dice coefficient of the prostate exceeds 0.75**, avoiding overfitting. +- **Visualization Tools**: Includes slice-level visualizations of input images, ground truth masks, and model predictions. +- **Modular Design**: Easily replaceable components for experimenting with different models, loss functions, and transforms. + +--- \ No newline at end of file diff --git a/recognition/s48027522-HipMRI-2DUnet/dataset.py b/recognition/s48027522-HipMRI-2DUnet/dataset.py new file mode 100644 index 000000000..78be0b200 --- /dev/null +++ b/recognition/s48027522-HipMRI-2DUnet/dataset.py @@ -0,0 +1,90 @@ +import torch +from torch.utils.data import Dataset +import torch.nn.functional as F + +import numpy as np +import os +import nibabel as nib + +def zScoreNormalize(image): + mean = image.mean() + std = image.std() + if std > 0: + image = (image - mean) / std + else: + image = image - mean + return image + +def RandomFlip(image, mask): + axes = [0, 1, 2] # D, H, W axes + for axis in axes: + if np.random.rand() > 0.5: + image = np.flip(image, axis=axis) + mask = np.flip(mask, axis=axis) + return image, mask + +def RandomRotate_90(image, mask): + k = np.random.randint(0, 4) # 0, 90, 180, 270 degrees + axes = (1, 2) # rotate in-plane (H, W) + image = np.rot90(image, k, axes) + mask = np.rot90(mask, k, axes) + return image, mask + +def TrainingTransform(image, mask): + image, mask = RandomFlip(image, mask) + image, mask = RandomRotate_90(image, mask) + image = zScoreNormalize(image) + + return image, mask + +def TestTransform(image, mask): + image = zScoreNormalize(image) + + return image, mask + +def Resize3dTensor(img_tensor, target_shape=(128,128,128), mode_type='trilinear'): + """ + img_tensor: torch tensor of shape (C, D, H, W) + """ + img_tensor = img_tensor.unsqueeze(0) # add batch dim + img_resized = F.interpolate(img_tensor, size=target_shape, mode=mode_type) + return img_resized.squeeze(0) + +def to_channels(label_slice, dtype=np.float32): + """Convert 2D label slice to one-hot channels.""" + num_classes = int(label_slice.max()) + 1 + out = np.zeros((num_classes,) + label_slice.shape, dtype=dtype) + for c in range(num_classes): + out[c] = (label_slice == c) + return out + +class HipMriDataset2D(Dataset): + """Dataset for pre-saved 2D slices.""" + + def __init__(self, image_path, mask_path, transform=None): + self.image_paths = sorted([os.path.join(image_path, f) for f in os.listdir(image_path)]) + self.mask_paths = sorted([os.path.join(mask_path, f) for f in os.listdir(mask_path)]) + self.transform = transform + + assert len(self.image_paths) == len(self.mask_paths), "Number of images and masks must match" + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + # Load 2D slice + image = nib.load(self.image_paths[idx]).get_fdata().astype(np.float32) + mask = nib.load(self.mask_paths[idx]).get_fdata().astype(np.uint8) + + # Apply transforms + if self.transform: + image, mask = self.transform(image, mask) + + # Convert mask to one-hot channels + mask = to_channels(mask, dtype=np.uint8) + + # Convert to tensors + image_tensor = torch.from_numpy(image).unsqueeze(0).float() # [1, H, W] + mask_tensor = torch.from_numpy(mask).float() # [C, H, W] + + return image_tensor, mask_tensor \ No newline at end of file diff --git a/recognition/s48027522-HipMRI-2DUnet/module.py b/recognition/s48027522-HipMRI-2DUnet/module.py new file mode 100644 index 000000000..61c2ea400 --- /dev/null +++ b/recognition/s48027522-HipMRI-2DUnet/module.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# ----------------------------- +# 2D Convolutional Block +# ----------------------------- +class ConvBlock2D(nn.Module): + def __init__(self, in_channels, out_channels, dropout=0.0): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(out_channels) + self.act = nn.ReLU(inplace=True) + self.dropout = nn.Dropout2d(dropout) + + def forward(self, x): + x = self.act(self.bn1(self.conv1(x))) + x = self.dropout(x) + x = self.act(self.bn2(self.conv2(x))) + return x + +# ----------------------------- +# 2D U-Net +# ----------------------------- +class UNet2D(nn.Module): + def __init__(self, in_channels=1, out_channels=1, base_filters=64, dropout=0.0): + super().__init__() + + # Encoder + self.enc1 = ConvBlock2D(in_channels, base_filters, dropout) + self.enc2 = ConvBlock2D(base_filters, base_filters*2, dropout) + self.enc3 = ConvBlock2D(base_filters*2, base_filters*4, dropout) + self.enc4 = ConvBlock2D(base_filters*4, base_filters*8, dropout) + + self.pool = nn.MaxPool2d(2) + + # Bottleneck + self.bottleneck = ConvBlock2D(base_filters*8, base_filters*16, dropout) + + # Decoder + self.up4 = nn.ConvTranspose2d(base_filters*16, base_filters*8, kernel_size=2, stride=2) + self.dec4 = ConvBlock2D(base_filters*16, base_filters*8, dropout) + + self.up3 = nn.ConvTranspose2d(base_filters*8, base_filters*4, kernel_size=2, stride=2) + self.dec3 = ConvBlock2D(base_filters*8, base_filters*4, dropout) + + self.up2 = nn.ConvTranspose2d(base_filters*4, base_filters*2, kernel_size=2, stride=2) + self.dec2 = ConvBlock2D(base_filters*4, base_filters*2, dropout) + + self.up1 = nn.ConvTranspose2d(base_filters*2, base_filters, kernel_size=2, stride=2) + self.dec1 = ConvBlock2D(base_filters*2, base_filters, dropout) + + # Output + self.segmentation_head = nn.Conv2d(base_filters, out_channels, kernel_size=1) + + def forward(self, x): + # Encoder + e1 = self.enc1(x) + e2 = self.enc2(self.pool(e1)) + e3 = self.enc3(self.pool(e2)) + e4 = self.enc4(self.pool(e3)) + + # Bottleneck + b = self.bottleneck(self.pool(e4)) + + # Decoder + d4 = self.up4(b) + d4 = self.dec4(torch.cat([d4, e4], dim=1)) + + d3 = self.up3(d4) + d3 = self.dec3(torch.cat([d3, e3], dim=1)) + + d2 = self.up2(d3) + d2 = self.dec2(torch.cat([d2, e2], dim=1)) + + d1 = self.up1(d2) + d1 = self.dec1(torch.cat([d1, e1], dim=1)) + + out = self.segmentation_head(d1) # logits + return out + +class BinaryDiceLoss(nn.Module): + """ + Dice Loss for binary segmentation. + + Dice Loss = 1 - (2 * |X ∩ Y| + smooth) / (|X| + |Y| + smooth) + + Works for 2D or 3D tensors: + predictions: [B, 1, H, W] or [B, 1, D, H, W] (logits) + targets: [B, 1, H, W] or [B, 1, D, H, W] (binary 0/1) + """ + def __init__(self, smooth=1e-6): + super().__init__() + self.smooth = smooth + + def forward(self, predictions, targets): + """ + Args: + predictions (torch.Tensor): logits from the model [B, 1, ...] + targets (torch.Tensor): binary ground truth [B, 1, ...] + """ + # Apply sigmoid to convert logits to probabilities + probs = torch.sigmoid(predictions) + targets = targets.float() + + # Flatten spatial dimensions per batch + dims = tuple(range(1, predictions.ndim)) # flatten everything except batch + + intersection = torch.sum(probs * targets, dims) + pred_sum = torch.sum(probs, dims) + target_sum = torch.sum(targets, dims) + + dice_score = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth) + dice_loss = 1.0 - dice_score.mean() # average over batch + + return dice_loss diff --git a/recognition/s48027522-HipMRI-2DUnet/predict.py b/recognition/s48027522-HipMRI-2DUnet/predict.py new file mode 100644 index 000000000..bd8732682 --- /dev/null +++ b/recognition/s48027522-HipMRI-2DUnet/predict.py @@ -0,0 +1,49 @@ +import dataset +import matplotlib.pyplot as plt +import torch +import torch.optim as optim +from torch.optim.lr_scheduler import ReduceLROnPlateau +import numpy as np +import torch.nn.functional as F +import random +import train + +from module import UNet2D, BinaryDiceLoss # your 2D U-Net implementation +from dataset import HipMriDataset2D # 2D dataset + DiceLoss + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f'Using device: {device}') + +if __name__ == "__main__": + test_dataset = HipMriDataset2D( + image_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_test', + mask_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_seg_test', + transform=dataset.TrainingTransform2D, + train=True + ) + + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False) + + model = UNet2D(in_channels=1, out_channels=1, base_filters=16, dropout=0.1) + + losses = [] + model.eval() + total_dice = 0.0 + with torch.no_grad(): + for images, masks in test_loader: + images, masks = images.to(device), masks.to(device) + outputs = model(images) + probs = torch.sigmoid(outputs) + pred_mask = (probs > 0.5).float() + + # Dice coefficient for the prostate class (assumes class=1) + intersection = (pred_mask * masks).sum(dim=(1,2,3)) + union = pred_mask.sum(dim=(1,2,3)) + masks.sum(dim=(1,2,3)) + dice_score = ((2*intersection + 1e-6) / (union + 1e-6)).mean().item() + total_dice += dice_score + + avg_dice = total_dice / len(test_loader) + + + print("Training complete!") + train.plot_loss(losses) \ No newline at end of file diff --git a/recognition/s48027522-HipMRI-2DUnet/train.py b/recognition/s48027522-HipMRI-2DUnet/train.py new file mode 100644 index 000000000..289d4d296 --- /dev/null +++ b/recognition/s48027522-HipMRI-2DUnet/train.py @@ -0,0 +1,140 @@ +import dataset +import matplotlib.pyplot as plt +import torch +import torch.optim as optim +from torch.optim.lr_scheduler import ReduceLROnPlateau +import numpy as np +import torch.nn.functional as F +import random + +from module import UNet2D, BinaryDiceLoss # your 2D U-Net implementation +from dataset import HipMriDataset2D # 2D dataset + DiceLoss + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f'Using device: {device}') + +# ----------------------------- +# Quick visualization of loss +# ----------------------------- +def plot_loss(losses): + plt.figure(figsize=(8, 4)) + plt.plot(losses, 'bo-', linewidth=2, markersize=8) + plt.title('🔥 Training Loss (Dice)', fontsize=14, fontweight='bold') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.grid(True, alpha=0.3) + plt.show() + +def visualize_slice(image, mask_pred, mask_gt): + """Visualize a single 2D slice with prediction and ground truth""" + _, axes = plt.subplots(1, 3, figsize=(12, 4)) + axes[0].imshow(image[0], cmap='gray') + axes[0].set_title("Input Image") + axes[1].imshow(mask_gt[0], cmap='Reds') + axes[1].set_title("Ground Truth") + axes[2].imshow(mask_pred[0], cmap='Reds') + axes[2].set_title("Predicted Mask") + for ax in axes: + ax.axis('off') + plt.show() + +# ----------------------------- +# Training function +# ----------------------------- +def train(model, train_loader, test_loader, epochs=50, lr=1e-3, dice_threshold=0.75): + model.to(device) + criterion = BinaryDiceLoss() + optimizer = optim.Adam(model.parameters(), lr=lr) + scheduler = ReduceLROnPlateau( + optimizer, + mode='min', + factor=0.5, + patience=5, + min_lr=1e-6 + ) + + losses = [] + + print("Starting 2D slice training...") + + for epoch in range(epochs): + model.train() + epoch_loss = 0.0 + + for images, masks in train_loader: + images, masks = images.to(device), masks.to(device) + + optimizer.zero_grad() + outputs = model(images) # logits + loss = criterion(outputs, masks) + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + + avg_loss = epoch_loss / len(train_loader) + scheduler.step(avg_loss) + losses.append(avg_loss) + print(f"Epoch {epoch+1}/{epochs} - Avg Loss: {avg_loss:.4f}") + + # ----------------------------- + # Validation & early stopping + # ----------------------------- + model.eval() + total_dice = 0.0 + with torch.no_grad(): + for images, masks in test_loader: + images, masks = images.to(device), masks.to(device) + outputs = model(images) + probs = torch.sigmoid(outputs) + pred_mask = (probs > 0.5).float() + + # Dice coefficient for the prostate class (assumes class=1) + intersection = (pred_mask * masks).sum(dim=(1,2,3)) + union = pred_mask.sum(dim=(1,2,3)) + masks.sum(dim=(1,2,3)) + dice_score = ((2*intersection + 1e-6) / (union + 1e-6)).mean().item() + total_dice += dice_score + + avg_dice = total_dice / len(test_loader) + print(f"Validation Dice (Prostate) after Epoch {epoch+1}: {avg_dice:.4f}") + + # Early stopping if prostate Dice exceeds threshold + if avg_dice >= dice_threshold: + print(f"Early stopping: Prostate Dice {avg_dice:.4f} >= {dice_threshold}") + # visualize a random slice + sample_img, sample_mask = next(iter(test_loader)) + sample_img = sample_img.to(device) + sample_mask = sample_mask.to(device) + sample_output = model(sample_img) + pred_mask = (torch.sigmoid(sample_output) > 0.5).float() + visualize_slice(sample_img[0].cpu().numpy(), pred_mask[0].cpu().numpy(), sample_mask[0].cpu().numpy()) + break + + print("🎯 Training complete!") + plot_loss(losses) + return losses + +# ----------------------------- +# Main +# ----------------------------- +if __name__ == "__main__": + train_dataset = HipMriDataset2D( + image_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_train', + mask_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_seg_train', + transform=dataset.TrainingTransform2D, + train=True + ) + + test_dataset = HipMriDataset2D( + image_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_validate', + mask_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_seg_validate', + transform=dataset.TestTransform2D, + train=False + ) + + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True) + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False) + + model = UNet2D(in_channels=1, out_channels=1, base_filters=16, dropout=0.1) + + train(model, train_loader, test_loader, epochs=50, lr=1e-3, dice_threshold=0.75) \ No newline at end of file