Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions recognition/README.md

This file was deleted.

21 changes: 21 additions & 0 deletions recognition/s48027522-HipMRI-2DUnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Prostate MRI Segmentation with 2D U-Net

## Project Overview

This project implements a **2D U-Net** convolutional neural network for segmenting the prostate from MRI scans. The workflow is designed to handle **pre-processed 2D slices of prostate MRIs**, enabling efficient training and evaluation of segmentation models.

The goal is to accurately delineate the prostate from surrounding tissues, which is critical for clinical applications such as **radiotherapy planning, disease diagnosis, and progression monitoring**.

---

## Features

- **2D Slice-Based Training**: Uses individual 2D slices extracted from 3D MRI volumes, enabling faster training and lower memory requirements.
- **U-Net Architecture**: Lightweight **2D U-Net** with encoder-decoder structure and skip connections for high-resolution segmentation.
- **Binary Dice Loss**: Implements a **Dice similarity coefficient loss** for robust training in binary segmentation (prostate vs. background).
- **Data Augmentation and Normalization**: Supports **random flips and random 90° rotations** for better generalization, together with **z-score intensity normalization** of every image.
- **Early Stopping**: Training halts automatically when the **Dice coefficient of the prostate exceeds 0.75**, avoiding overfitting.
- **Visualization Tools**: Includes slice-level visualizations of input images, ground truth masks, and model predictions.
- **Modular Design**: Easily replaceable components for experimenting with different models, loss functions, and transforms.

---
90 changes: 90 additions & 0 deletions recognition/s48027522-HipMRI-2DUnet/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

import numpy as np
import os
import nibabel as nib

def zScoreNormalize(image):
    """Return ``image`` shifted to zero mean and scaled to unit variance.

    The mean and standard deviation are computed over the whole array.
    When the standard deviation is zero (a constant image) only the mean
    is subtracted, so no division by zero takes place.
    """
    centered = image - image.mean()
    spread = image.std()
    # A constant image has no spread to scale by; return it centred only.
    return centered / spread if spread > 0 else centered

def RandomFlip(image, mask):
    """Randomly flip an image/mask pair, keeping the two aligned.

    Each of the leading axes (at most the first three) is flipped
    independently with probability 0.5. The same flips are applied to
    ``image`` and ``mask``.

    Args:
        image: NumPy array holding the image (2D slice or 3D volume).
        mask: NumPy array with the same number of axes as ``image``.

    Returns:
        Tuple ``(image, mask)``. ``np.flip`` returns views, so the results
        may be non-contiguous views of the inputs.
    """
    # Only iterate over axes that actually exist: 2D slices have two axes,
    # and asking np.flip for axis 2 on them raises an AxisError.
    num_axes = min(np.ndim(image), 3)
    for axis in range(num_axes):
        if np.random.rand() > 0.5:
            image = np.flip(image, axis=axis)
            mask = np.flip(mask, axis=axis)
    return image, mask

def RandomRotate_90(image, mask):
    """Rotate an image/mask pair by a random multiple of 90 degrees.

    The rotation is applied in the plane of the last two axes, so it
    works for 2D slices ``(H, W)`` as well as 3D volumes ``(D, H, W)``.
    The same rotation is applied to ``image`` and ``mask``.

    Args:
        image: NumPy array with at least two axes.
        mask: NumPy array with the same number of axes as ``image``.

    Returns:
        Tuple ``(image, mask)``. For a non-square plane, an odd number of
        quarter turns swaps the lengths of the last two axes.
    """
    k = np.random.randint(0, 4)  # 0, 90, 180, 270 degrees
    # Rotate in-plane using the trailing two axes; a fixed (1, 2) pair is
    # out of range for 2D input.
    ndim = np.ndim(image)
    axes = (ndim - 2, ndim - 1)
    image = np.rot90(image, k, axes)
    mask = np.rot90(mask, k, axes)
    return image, mask

def TrainingTransform(image, mask):
    """Apply the training pipeline to an image/mask pair.

    The pair is randomly flipped, then randomly rotated by a multiple of
    90 degrees, and finally the image (not the mask) is z-score
    normalized.

    Returns:
        Tuple ``(image, mask)`` after augmentation and normalization.
    """
    flipped = RandomFlip(image, mask)
    aug_image, aug_mask = RandomRotate_90(*flipped)
    # Only intensities are normalized; label values must stay untouched.
    return zScoreNormalize(aug_image), aug_mask

def TestTransform(image, mask):
    """Apply the evaluation pipeline: normalize the image, keep the mask.

    No random augmentation is performed, so the output is deterministic.

    Returns:
        Tuple ``(image, mask)`` with the image z-score normalized and the
        mask passed through unchanged.
    """
    return zScoreNormalize(image), mask

def Resize3dTensor(img_tensor, target_shape=(128,128,128), mode_type='trilinear'):
    """Resample a single volume to ``target_shape``.

    Args:
        img_tensor: torch tensor of shape (C, D, H, W).
        target_shape: output spatial size passed to ``F.interpolate``.
        mode_type: interpolation mode passed to ``F.interpolate``.

    Returns:
        The resized tensor with the temporary batch dimension removed.
    """
    # F.interpolate expects a leading batch axis, so add one and drop it
    # again afterwards.
    batched = img_tensor[None]
    resized = F.interpolate(batched, size=target_shape, mode=mode_type)
    return resized[0]

def to_channels(label_slice, dtype=np.float32, num_classes=None):
    """Convert a label array to one-hot channels.

    Args:
        label_slice: NumPy array of class labels (e.g. a 2D slice).
        dtype: dtype of the returned array.
        num_classes: number of output channels. When ``None`` (default) it
            is inferred as ``int(label_slice.max()) + 1``, which makes the
            channel count depend on the labels present in this particular
            array. Pass a fixed value to get the same number of channels
            for every sample; labels ``>= num_classes`` then end up in no
            channel.

    Returns:
        Array of shape ``(num_classes,) + label_slice.shape`` in which
        channel ``c`` is 1 where ``label_slice == c`` and 0 elsewhere.
    """
    if num_classes is None:
        # Data-dependent fallback kept for backward compatibility.
        num_classes = int(label_slice.max()) + 1
    out = np.zeros((num_classes,) + label_slice.shape, dtype=dtype)
    for c in range(num_classes):
        out[c] = (label_slice == c)
    return out

class HipMriDataset2D(Dataset):
    """Dataset for pre-saved 2D slices.

    Images and masks are paired by position after sorting the file names
    of each directory, so both directories are expected to list their
    files in matching order.
    """

    def __init__(self, image_path, mask_path, transform=None):
        """Index the image and mask files.

        Args:
            image_path: directory containing the image files.
            mask_path: directory containing the mask files.
            transform: optional callable ``(image, mask) -> (image, mask)``
                applied to the NumPy arrays of every sample.

        Raises:
            AssertionError: if the two directories hold a different
                number of files.
        """
        self.image_paths = sorted(os.path.join(image_path, f) for f in os.listdir(image_path))
        self.mask_paths = sorted(os.path.join(mask_path, f) for f in os.listdir(mask_path))
        self.transform = transform

        # Raised explicitly (instead of via `assert`) so the check still
        # runs under `python -O`; exception type and message are unchanged.
        if len(self.image_paths) != len(self.mask_paths):
            raise AssertionError("Number of images and masks must match")

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, transform and tensorize the sample at position ``idx``.

        Returns:
            Tuple ``(image_tensor, mask_tensor)`` of float tensors. The
            image gets a leading channel axis; the mask is one-hot encoded
            by ``to_channels``.
        """
        # Load 2D slice (assumes each file stores one 2D slice — TODO confirm)
        image = nib.load(self.image_paths[idx]).get_fdata().astype(np.float32)
        mask = nib.load(self.mask_paths[idx]).get_fdata().astype(np.uint8)

        # Apply transforms
        if self.transform:
            image, mask = self.transform(image, mask)

        # Convert mask to one-hot channels
        mask = to_channels(mask, dtype=np.uint8)

        # Flip/rot90-style transforms can return negative-stride views,
        # which torch.from_numpy rejects; force a contiguous layout first.
        image = np.ascontiguousarray(image)

        # Convert to tensors
        image_tensor = torch.from_numpy(image).unsqueeze(0).float()  # [1, H, W]
        mask_tensor = torch.from_numpy(mask).float()  # [C, H, W]

        return image_tensor, mask_tensor
118 changes: 118 additions & 0 deletions recognition/s48027522-HipMRI-2DUnet/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------
# 2D Convolutional Block
# -----------------------------
class ConvBlock2D(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and ReLU.

    Spatial dropout (``nn.Dropout2d``) is applied between the two
    convolution stages. ``padding=1`` keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        # First stage: in_channels -> out_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Second stage: out_channels -> out_channels
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(dropout)

    def forward(self, x):
        """Run conv-bn-relu, dropout, then conv-bn-relu on ``x``."""
        first = self.conv1(x)
        first = self.bn1(first)
        first = self.act(first)

        dropped = self.dropout(first)

        second = self.conv2(dropped)
        second = self.bn2(second)
        return self.act(second)

# -----------------------------
# 2D U-Net
# -----------------------------
class UNet2D(nn.Module):
    """2D U-Net with four encoder levels, a bottleneck and four decoder levels.

    Each encoder level is a ``ConvBlock2D`` followed by 2x2 max pooling;
    each decoder level upsamples with a stride-2 transposed convolution,
    concatenates the matching encoder output (skip connection) and applies
    a ``ConvBlock2D``. A final 1x1 convolution produces the logits.

    Args:
        in_channels: number of input image channels.
        out_channels: number of output (logit) channels.
        base_filters: channel width of the first level; doubled per level.
        dropout: dropout probability forwarded to every ``ConvBlock2D``.
    """

    def __init__(self, in_channels=1, out_channels=1, base_filters=64, dropout=0.0):
        super().__init__()

        # Encoder
        self.enc1 = ConvBlock2D(in_channels, base_filters, dropout)
        self.enc2 = ConvBlock2D(base_filters, base_filters*2, dropout)
        self.enc3 = ConvBlock2D(base_filters*2, base_filters*4, dropout)
        self.enc4 = ConvBlock2D(base_filters*4, base_filters*8, dropout)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = ConvBlock2D(base_filters*8, base_filters*16, dropout)

        # Decoder
        self.up4 = nn.ConvTranspose2d(base_filters*16, base_filters*8, kernel_size=2, stride=2)
        self.dec4 = ConvBlock2D(base_filters*16, base_filters*8, dropout)

        self.up3 = nn.ConvTranspose2d(base_filters*8, base_filters*4, kernel_size=2, stride=2)
        self.dec3 = ConvBlock2D(base_filters*8, base_filters*4, dropout)

        self.up2 = nn.ConvTranspose2d(base_filters*4, base_filters*2, kernel_size=2, stride=2)
        self.dec2 = ConvBlock2D(base_filters*4, base_filters*2, dropout)

        self.up1 = nn.ConvTranspose2d(base_filters*2, base_filters, kernel_size=2, stride=2)
        self.dec1 = ConvBlock2D(base_filters*2, base_filters, dropout)

        # Output
        self.segmentation_head = nn.Conv2d(base_filters, out_channels, kernel_size=1)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its last two dims equal those of ``skip``.

        Max pooling floors odd sizes, so after upsampling a feature map
        can be one pixel smaller than the encoder map it is concatenated
        with. Returns ``upsampled`` untouched when the sizes already agree.
        """
        diff_h = skip.shape[-2] - upsampled.shape[-2]
        diff_w = skip.shape[-1] - upsampled.shape[-1]
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Return segmentation logits with the same spatial size as ``x``.

        Inputs whose height or width is not divisible by 16 are supported:
        upsampled maps are padded to the skip-connection size before
        concatenation.
        """
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoder
        d4 = self._match_size(self.up4(b), e4)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))

        d3 = self._match_size(self.up3(d4), e3)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))

        d2 = self._match_size(self.up2(d3), e2)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))

        d1 = self._match_size(self.up1(d2), e1)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        out = self.segmentation_head(d1)  # logits
        return out

class BinaryDiceLoss(nn.Module):
    """
    Soft Dice loss for binary segmentation, computed from logits.

        loss = 1 - mean_over_batch((2 * |X ∩ Y| + smooth) / (|X| + |Y| + smooth))

    The number of spatial dimensions is not fixed, e.g.:
        predictions: [B, 1, H, W] or [B, 1, D, H, W] (logits)
        targets:     [B, 1, H, W] or [B, 1, D, H, W] (binary 0/1)
    """
    def __init__(self, smooth=1e-6):
        super().__init__()
        # Added to numerator and denominator to avoid 0/0 on empty masks.
        self.smooth = smooth

    def forward(self, predictions, targets):
        """
        Args:
            predictions (torch.Tensor): logits from the model [B, 1, ...]
            targets (torch.Tensor): binary ground truth [B, 1, ...]

        Returns:
            torch.Tensor: scalar loss, one minus the batch-mean Dice score.
        """
        # Logits -> probabilities; targets as float for the products below.
        probabilities = torch.sigmoid(predictions)
        truth = targets.float()

        # Reduce over every dimension except the batch dimension.
        reduce_dims = tuple(range(1, predictions.ndim))

        overlap = torch.sum(probabilities * truth, reduce_dims)
        total = torch.sum(probabilities, reduce_dims) + torch.sum(truth, reduce_dims)

        per_sample_dice = (2.0 * overlap + self.smooth) / (total + self.smooth)
        return 1.0 - per_sample_dice.mean()
49 changes: 49 additions & 0 deletions recognition/s48027522-HipMRI-2DUnet/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import dataset
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import torch.nn.functional as F
import random
import train

from module import UNet2D, BinaryDiceLoss # your 2D U-Net implementation
from dataset import HipMriDataset2D # 2D dataset + DiceLoss

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

if __name__ == "__main__":
    # Evaluation script: computes the mean Dice score of the model over the
    # test slices.
    #
    # Use the deterministic test transform (normalization only). The dataset
    # module defines TrainingTransform / TestTransform; random augmentation
    # would make the reported score vary between runs.
    test_dataset = HipMriDataset2D(
        image_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_test',
        mask_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_seg_test',
        transform=dataset.TestTransform,
    )

    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    # The model must live on the same device as the batches moved below.
    # NOTE(review): no trained weights are loaded here, so the model is
    # evaluated with its random initialization — load the checkpoint saved
    # by training before trusting the score.
    model = UNet2D(in_channels=1, out_channels=1, base_filters=16, dropout=0.1).to(device)

    losses = []
    model.eval()
    total_dice = 0.0
    with torch.no_grad():
        for images, masks in test_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            probs = torch.sigmoid(outputs)
            pred_mask = (probs > 0.5).float()

            # Dice coefficient for the prostate class (assumes class=1)
            # NOTE(review): the dataset returns one-hot masks [B, C, H, W]
            # while pred_mask is [B, 1, H, W]; the product broadcasts over
            # all C channels. Confirm which channel is the prostate and
            # select it before computing the score.
            intersection = (pred_mask * masks).sum(dim=(1,2,3))
            union = pred_mask.sum(dim=(1,2,3)) + masks.sum(dim=(1,2,3))
            dice_score = ((2*intersection + 1e-6) / (union + 1e-6)).mean().item()
            total_dice += dice_score

    # Guard against an empty test set to avoid a ZeroDivisionError.
    num_batches = len(test_loader)
    avg_dice = total_dice / num_batches if num_batches else 0.0
    print(f"Average Dice: {avg_dice:.4f}")

    print("Evaluation complete!")
    # NOTE(review): `losses` is never filled in this script, so this call
    # receives an empty list — confirm what should be plotted here.
    train.plot_loss(losses)
Loading