diff --git a/recognition/AttUNetOASIS/README.md b/recognition/AttUNetOASIS/README.md
new file mode 100644
index 000000000..0119955d7
--- /dev/null
+++ b/recognition/AttUNetOASIS/README.md
@@ -0,0 +1,26 @@
+# Attention U-Net for OASIS Brain Tissue Segmentation
+
+## Algorithm Description
+This implementation uses an Attention U-Net architecture to perform semantic segmentation of brain MRI scans from the OASIS dataset. It solves the problem of automatic tissue classification: given a 2D MRI slice, the model predicts a label for every pixel indicating whether it belongs to the background, cerebrospinal fluid (CSF), gray matter (GM), or white matter (WM). This reduces reliance on traditional methods such as pixel counting and voxel-based morphometry, which are time-consuming and subject to inter-rater variability.
+
+## How It Works
+The Attention U-Net is an improved version of the traditional U-Net neural network. Like its predecessor, the Attention U-Net features an encoder-decoder architecture that uses convolution blocks for feature extraction, performs hierarchical downsampling and upsampling, and produces a pixelwise segmentation output, with the added benefit of attention mechanisms. On the OASIS dataset, the encoder progressively downsamples the input MRI slices, extracting hierarchical features from low-level edges to high-level anatomical patterns. The decoder symmetrically upsamples these features back to the original resolution. Critically, before encoder features are merged via the skip connections, attention gates dynamically highlight regions containing brain tissue while suppressing irrelevant background areas. This is particularly useful for brain segmentation, where tissue boundaries can be subtle and structures vary in size.
+
+The network is trained using a combined Dice and cross-entropy loss function, which directly optimises the segmentation quality metric while maintaining stable gradients. It also employs data augmentation via horizontal flipping to improve generalisation across the anatomical variability between patients. The AdamW optimiser with decoupled weight decay updates the parameters, and a ReduceLROnPlateau scheduler adapts the learning rate based on validation Dice.
+
+![Attention U-Net architecture](https://www.researchgate.net/publication/347344899/figure/fig6/AS:971357475069952@1608601077414/The-architecture-of-Attention-U-Net-Attention-gate-selects-features-by-using-the.png)
+*Figure 1: Attention U-Net architecture. Source: Hwang et al., 2020. Licensed under CC BY-NC 4.0.*
+
+## Dependencies
+Python 3.10.19 was used for this implementation. The packages used and their versions are listed below:
+| Package | Version |
+| :------ | :-----: |
+| torch+cu118 | 2.7.1 |
+| pillow | 12.0.0 |
+
+The code additionally imports `numpy` and `tqdm`; any recent versions should work.
+
+Results are highly reproducible for homogeneous data, i.e. MRI segmentation masks that assign one unique pixel value to each brain tissue class.
+
+## Training data
+Training used preprocessed 2D slices extracted from the 3D OASIS MRI volumes. Each input slice is paired with a segmentation mask that serves as the training target against which the loss on the model's prediction is measured.
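+
+As a minimal sketch of this preprocessing (assuming the raw mask values `[0, 85, 170, 255]` noted in `unet.py`; `detect_label_mapping` derives them at runtime rather than hard-coding), each mask is remapped to contiguous class indices before training:
+
+```python
+import numpy as np
+
+# Hypothetical 2x2 mask using the four raw OASIS PNG values
+raw_mask = np.array([[0, 85], [170, 255]], dtype=np.uint8)
+label_map = {0: 0, 85: 1, 170: 2, 255: 3}  # raw value -> class index
+
+mapped = np.zeros_like(raw_mask)
+for old_val, new_val in label_map.items():
+    mapped[raw_mask == old_val] = new_val
+print(mapped)  # [[0 1]
+               #  [2 3]]
+```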
+
+**SID** 48915768
diff --git a/recognition/AttUNetOASIS/unet.py b/recognition/AttUNetOASIS/unet.py
new file mode 100644
index 000000000..734d3374e
--- /dev/null
+++ b/recognition/AttUNetOASIS/unet.py
@@ -0,0 +1,441 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm  # Progress bars, used in train_epoch
+from PIL import Image
+import numpy as np
+import os
+import time
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Data paths
+DATA_ROOT = '/home/groups/comp3710/OASIS'
+TRAIN_IMG_DIR = f'{DATA_ROOT}/keras_png_slices_train'
+TRAIN_SEG_DIR = f'{DATA_ROOT}/keras_png_slices_seg_train'
+VAL_IMG_DIR = f'{DATA_ROOT}/keras_png_slices_validate'
+VAL_SEG_DIR = f'{DATA_ROOT}/keras_png_slices_seg_validate'
+TEST_IMG_DIR = f'{DATA_ROOT}/keras_png_slices_test'
+TEST_SEG_DIR = f'{DATA_ROOT}/keras_png_slices_seg_test'
+
+# Map raw PNG values to class indices. Values retrieved at runtime using PIL.
+# Expect 4 unique PNG values -> classes from OASIS dataset.
+LABEL_MAP = None
+
+# Hyperparameters
+NUM_CLASSES = None  # Number of unique pixel values -> brain tissue classes in MRI slices
+NUM_EPOCHS = 5  # Small for testing
+LEARNING_RATE = 1.0e-4  # Standard for AdamW in segmentation tasks
+BATCH_SIZE = 32  # Rangpur uses A100
+IMG_SIZE = None  # Extracted from data at runtime
+WEIGHT_DECAY = 1.0e-5  # Standard for AdamW
+
+def detect_label_mapping(seg_dir):
+    """Auto-detect label values, number of classes, and image size from first segmentation mask"""
+    seg_file = sorted([f for f in os.listdir(seg_dir) if f.endswith('.png')])[0]
+    seg = np.array(Image.open(os.path.join(seg_dir, seg_file)))
+
+    unique_values = sorted(np.unique(seg).tolist())
+    label_map = {val: idx for idx, val in enumerate(unique_values)}
+    num_classes = len(unique_values)
+    img_size = seg.shape[0]
+
+    return label_map, num_classes, img_size
+
+class OASISDataset(Dataset):
+    """Custom dataset for OASIS brain MRI slices and segmentation masks"""
+
+    def __init__(self, img_dir, seg_dir, label_map, augment=False):
+        """
+        Dataset constructor
+
+        @param img_dir Directory to load input images from
+        @param seg_dir Directory to load segmentation masks from
+        @param label_map Dictionary that maps unique PNG values to class indices
+        @param augment Whether or not to augment training data (default: False)
+        """
+        self.img_dir = img_dir
+        self.seg_dir = seg_dir
+        self.label_map = label_map
+        self.augment = augment
+
+        # Get all image filenames
+        self.img_files = sorted([f for f in os.listdir(img_dir) if f.endswith('.png')])
+
+    def __len__(self):
+        return len(self.img_files)
+
+    def __getitem__(self, idx):
+        # Load image
+        img_name = self.img_files[idx]
+        img_path = os.path.join(self.img_dir, img_name)
+        img = np.array(Image.open(img_path), dtype=np.float32)
+
+        # Load segmentation mask
+        seg_name = img_name.replace('case_', 'seg_')
+        seg_path = os.path.join(self.seg_dir, seg_name)
+        seg = np.array(Image.open(seg_path), dtype=np.uint8)
+
+        # Map labels: [0, 85, 170, 255] -> [0, 1, 2, 3]
+        seg_mapped = np.zeros_like(seg)
+        for old_val, new_val in self.label_map.items():
+            seg_mapped[seg == old_val] = new_val
+
+        # Normalize image to [0, 1]
+        img = img / 255.0
+
+        # Add channel dimension
+        img = np.expand_dims(img, axis=0)
+
+        # Convert to tensors
+        img = torch.from_numpy(img).float()
+        seg = torch.from_numpy(seg_mapped).long()
+
+        # Basic augmentation
+        if self.augment:
+            if torch.rand(1) > 0.5:
+                img = torch.flip(img, dims=[2])  # Horizontal flip
+                seg = torch.flip(seg, dims=[1])
+
+        return img, seg
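+
+# Usage sketch (comments only; the real wiring happens in main() below):
+#   label_map, n_classes, size = detect_label_mapping(TRAIN_SEG_DIR)
+#   ds = OASISDataset(TRAIN_IMG_DIR, TRAIN_SEG_DIR, label_map, augment=True)
+#   img, seg = ds[0]  # img: float32 (1, H, W) in [0, 1]; seg: int64 (H, W) of class indices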
+
+class ConvBlock(nn.Module):
+    """Double convolution block: Conv-BN-ReLU-Conv-BN-ReLU"""
+
+    def __init__(self, in_channels, out_channels):
+        super(ConvBlock, self).__init__()
+        self.conv = nn.Sequential(
+            # First convolution layer
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),  # Normalise activations to stabilize training
+            nn.ReLU(inplace=True),  # Nonlinearity for feature learning
+
+            # Second convolution layer
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.conv(x)  # Apply sequential block to input tensor
+
+class AttentionBlock(nn.Module):
+    """Attention gate for focusing on relevant spatial regions in U-Net skip connections"""
+
+    def __init__(self, F_g, F_l, F_int):
+        """
+        @param F_g Number of feature channels in the gating signal (from decoder)
+        @param F_l Number of feature channels in the skip connection (from encoder)
+        @param F_int Number of intermediate channels for computing attention
+        """
+        super(AttentionBlock, self).__init__()
+
+        # 1x1 convolution to project the gating signal to the intermediate feature space
+        self.W_g = nn.Sequential(
+            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.BatchNorm2d(F_int)
+        )
+
+        # 1x1 convolution to project encoder skip features to the same intermediate space
+        self.W_x = nn.Sequential(
+            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.BatchNorm2d(F_int)
+        )
+
+        # 1x1 convolution producing a single-channel attention map in [0, 1]
+        self.psi = nn.Sequential(
+            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.BatchNorm2d(1),
+            nn.Sigmoid()
+        )
+
+        self.relu = nn.ReLU(inplace=True)  # Nonlinearity for combining features
+
+    def forward(self, g, x):
+        g1 = self.W_g(g)  # Transform decoder input (context)
+        x1 = self.W_x(x)  # Transform encoder input (skip feature)
+
+        # Add projected tensors and apply ReLU
+        psi = self.relu(g1 + x1)
+        psi = self.psi(psi)  # Spatial attention weights of shape (batch, 1, H, W)
+
+        # Apply attention mask to suppress irrelevant regions
+        return x * psi  # Output tensor shape: (batch, F_l, H, W)
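+
+# Shape sketch for one attention gate (illustrative values, comments only):
+#   att = AttentionBlock(F_g=512, F_l=512, F_int=256)   # matches att4 with init_features=64
+#   g = torch.randn(2, 512, 32, 32)  # upsampled decoder feature (gating signal)
+#   x = torch.randn(2, 512, 32, 32)  # encoder skip feature at the same resolution
+#   out = att(g, x)                  # (2, 512, 32, 32): skip feature reweighted per pixel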
+
+class AttentionUNet(nn.Module):
+    """
+    Attention U-Net for semantic segmentation
+    Encoder-decoder with attention gates before skip connections
+    """
+
+    def __init__(self, in_channels=1, num_classes=4, init_features=64):
+        super(AttentionUNet, self).__init__()
+
+        features = init_features  # Base number of feature maps
+
+        # Encoder (downsampling path)
+        self.enc1 = ConvBlock(in_channels, features)
+        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # Downsample by 2x
+
+        self.enc2 = ConvBlock(features, features * 2)
+        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        self.enc3 = ConvBlock(features * 2, features * 4)
+        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        self.enc4 = ConvBlock(features * 4, features * 8)
+        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        # Bottleneck
+        self.bottleneck = ConvBlock(features * 8, features * 16)  # Deepest feature extraction
+
+        # Decoder (upsampling path) with Attention Gates
+        # At each stage, upsample -> apply attention on encoder feature -> concatenate -> refine
+        self.upconv4 = nn.ConvTranspose2d(features * 16, features * 8, kernel_size=2, stride=2)
+        self.att4 = AttentionBlock(F_g=features * 8, F_l=features * 8, F_int=features * 4)
+        self.dec4 = ConvBlock(features * 16, features * 8)
+
+        self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, kernel_size=2, stride=2)
+        self.att3 = AttentionBlock(F_g=features * 4, F_l=features * 4, F_int=features * 2)
+        self.dec3 = ConvBlock(features * 8, features * 4)
+
+        self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
+        self.att2 = AttentionBlock(F_g=features * 2, F_l=features * 2, F_int=features)
+        self.dec2 = ConvBlock(features * 4, features * 2)
+
+        self.upconv1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
+        self.att1 = AttentionBlock(F_g=features, F_l=features, F_int=features // 2)
+        self.dec1 = ConvBlock(features * 2, features)
+
+        # Output layer, 1x1 conv for per-pixel class scores
+        self.out = nn.Conv2d(features, num_classes, kernel_size=1)
+
+    def forward(self, x):
+        # Encoder
+        enc1 = self.enc1(x)
+        enc2 = self.enc2(self.pool1(enc1))
+        enc3 = self.enc3(self.pool2(enc2))
+        enc4 = self.enc4(self.pool3(enc3))
+
+        # Bottleneck
+        bottleneck = self.bottleneck(self.pool4(enc4))
+
+        # Decoder with Attention Gates
+        # At each stage, upsample bottleneck -> apply attention on encoder feature
+        # -> concatenate filtered encoder feature with decoder -> refine combined features
+        dec4 = self.upconv4(bottleneck)
+        enc4_att = self.att4(g=dec4, x=enc4)
+        dec4 = torch.cat((enc4_att, dec4), dim=1)
+        dec4 = self.dec4(dec4)
+
+        dec3 = self.upconv3(dec4)
+        enc3_att = self.att3(g=dec3, x=enc3)
+        dec3 = torch.cat((enc3_att, dec3), dim=1)
+        dec3 = self.dec3(dec3)
+
+        dec2 = self.upconv2(dec3)
+        enc2_att = self.att2(g=dec2, x=enc2)
+        dec2 = torch.cat((enc2_att, dec2), dim=1)
+        dec2 = self.dec2(dec2)
+
+        dec1 = self.upconv1(dec2)
+        enc1_att = self.att1(g=dec1, x=enc1)
+        dec1 = torch.cat((enc1_att, dec1), dim=1)
+        dec1 = self.dec1(dec1)
+
+        return self.out(dec1)  # Per-pixel class predictions
+
+class DiceLoss(nn.Module):
+    """Dice loss for segmentation"""
+
+    def __init__(self, smooth=1.0):
+        super(DiceLoss, self).__init__()
+        self.smooth = smooth  # Small constant to avoid division by zero
+
+    def forward(self, pred, target):
+        # Apply softmax and convert target labels to one-hot encoding
+        pred = F.softmax(pred, dim=1)
+        num_classes = pred.shape[1]  # Infer class count from the logits rather than a global
+        target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
+
+        # Compute per-class intersection over spatial dimensions
+        intersection = (pred * target_one_hot).sum(dim=(2, 3))
+        union = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))  # Compute per-class union
+
+        # Compute Dice coefficient per class and add smoothing
+        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
+        return 1.0 - dice.mean()
+
+
+class CombinedLoss(nn.Module):
+    """Combined Cross-Entropy + Dice loss"""
+
+    def __init__(self, weight=None):
+        super(CombinedLoss, self).__init__()
+        self.ce_loss = nn.CrossEntropyLoss(weight=weight)  # Standard CE loss
+        self.dice_loss = DiceLoss()  # Dice loss for overlap
+
+    def forward(self, pred, target):
+        # Compute cross-entropy and Dice losses
+        ce = self.ce_loss(pred, target)
+        dice = self.dice_loss(pred, target)
+
+        # Combine and return losses
+        return ce + dice
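+
+# Sanity-check sketch (comments only; mirrors how the loss is called during training):
+#   criterion = CombinedLoss()
+#   logits = torch.randn(2, 4, 64, 64)         # (B, C, H, W) raw class scores
+#   target = torch.randint(0, 4, (2, 64, 64))  # (B, H, W) integer class indices
+#   loss = criterion(logits, target)           # scalar: CE + (1 - mean per-class Dice)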
+
+def dice_coefficient(pred, target, num_classes):
+    """Calculate Dice coefficient per class for evaluation"""
+    dice_scores = []  # Dice scores for each class
+
+    # Convert model predictions to discrete class labels (argmax over channels)
+    pred = torch.argmax(pred, dim=1)
+
+    for cls in range(num_classes):
+        # Compute binary masks for current class
+        pred_cls = (pred == cls).float()  # prediction == class ? 1 : 0
+        target_cls = (target == cls).float()  # target == class ? 1 : 0
+
+        # Compute intersection and union for Dice
+        intersection = (pred_cls * target_cls).sum()
+        union = pred_cls.sum() + target_cls.sum()
+
+        # Handle the edge case where the class is absent from both masks
+        # (union == 0 implies intersection == 0, so this counts as perfect agreement)
+        if union == 0:
+            dice_scores.append(1.0)
+        else:
+            dice_scores.append(((2.0 * intersection) / union).item())
+
+    return dice_scores
+
+def train_epoch(model, loader, criterion, optimizer, device):
+    """Train for one epoch"""
+    model.train()
+    total_loss = 0  # Track batch losses
+
+    for images, masks in tqdm(loader, desc="Training", leave=False):
+        # Move input and target images to device
+        images = images.to(device)
+        masks = masks.to(device)
+
+        # Reset gradients and perform forward pass
+        optimizer.zero_grad()
+        outputs = model(images)
+        loss = criterion(outputs, masks)  # Compute loss
+
+        # Perform backward pass and update model params
+        loss.backward()
+        optimizer.step()
+
+        total_loss += loss.item()  # Accumulate loss for reporting
+
+    return total_loss / len(loader)  # Average loss over all batches
+
+def validate(model, loader, criterion, device, epoch=0, phase="Validation"):
+    """Validate on validation/test set"""
+    model.eval()
+    total_loss = 0
+    dice_per_class = [[] for _ in range(NUM_CLASSES)]
+    num_batches = len(loader)
+
+    with torch.no_grad():
+        for i, (images, masks) in enumerate(loader):
+            images = images.to(device)
+            masks = masks.to(device)
+
+            outputs = model(images)
+            loss = criterion(outputs, masks)
+            total_loss += loss.item()
+
+            # Calculate Dice per class
+            dice_scores = dice_coefficient(outputs, masks, NUM_CLASSES)
+            for cls in range(NUM_CLASSES):
+                dice_per_class[cls].append(dice_scores[cls])
+
+            # Print progress every 20% of validation
+            if (i + 1) % max(1, num_batches // 5) == 0:
+                progress = (i + 1) / num_batches
+                print(f"  Epoch {epoch} - {phase}: {progress:.1%} ({i+1}/{num_batches})")
+
+    avg_loss = total_loss / len(loader)
+    avg_dice_per_class = [np.mean(scores) for scores in dice_per_class]
+    avg_dice = np.mean(avg_dice_per_class)
+
+    return avg_loss, avg_dice, avg_dice_per_class
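+
+# Worked example for dice_coefficient (2 classes, 2x2 masks):
+#   argmax(pred) = [[0, 1],    target = [[0, 1],
+#                   [1, 1]]              [0, 1]]
+#   class 0: intersection = 1, union = 1 + 2 = 3 -> Dice = 2/3
+#   class 1: intersection = 2, union = 3 + 2 = 5 -> Dice = 4/5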
+
+def main():
+    global LABEL_MAP, NUM_CLASSES, IMG_SIZE
+
+    # Auto-detect label mapping, number of classes, and image size
+    LABEL_MAP, NUM_CLASSES, IMG_SIZE = detect_label_mapping(TRAIN_SEG_DIR)
+
+    # Create datasets (the label map is required by the dataset constructor)
+    train_dataset = OASISDataset(TRAIN_IMG_DIR, TRAIN_SEG_DIR, LABEL_MAP, augment=True)
+    val_dataset = OASISDataset(VAL_IMG_DIR, VAL_SEG_DIR, LABEL_MAP, augment=False)
+    test_dataset = OASISDataset(TEST_IMG_DIR, TEST_SEG_DIR, LABEL_MAP, augment=False)
+
+    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1, pin_memory=True)
+    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1, pin_memory=True)
+    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1, pin_memory=True)
+
+    # Print number of samples loaded
+    print(f"Train: {len(train_dataset)} samples")
+    print(f"Val: {len(val_dataset)} samples")
+    print(f"Test: {len(test_dataset)} samples")
+
+    # Create model
+    model = AttentionUNet(in_channels=1, num_classes=NUM_CLASSES, init_features=64)
+    model = model.to(device)
+
+    # Loss and optimizer
+    criterion = CombinedLoss()
+    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10)
+
+    # Training loop
+    best_dice = 0.0
+    patience_counter = 0
+
+    start_time = time.time()
+
+    for epoch in range(NUM_EPOCHS):
+        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
+        val_loss, val_dice, val_dice_per_class = validate(model, val_loader, criterion, device, epoch + 1)
+
+        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Dice={val_dice:.4f}")
+
+        scheduler.step(val_dice)
+
+        if val_dice > best_dice:
+            best_dice = val_dice
+            torch.save({'model_state_dict': model.state_dict(), 'best_dice': best_dice}, 'best_attention_unet.pth')
+            print(f"Saved (Dice: {best_dice:.4f})")
+            patience_counter = 0
+        else:
+            patience_counter += 1
+
+        if patience_counter >= 30:
+            print(f"Early stopping at epoch {epoch+1}")
+            break
+
+    elapsed = time.time() - start_time
+    print(f"Training done in {elapsed/60:.2f} min. Best Dice: {best_dice:.4f}")
+
+    # Test evaluation
+    print("\nTesting...")
+    checkpoint = torch.load('best_attention_unet.pth', weights_only=False)
+    model.load_state_dict(checkpoint['model_state_dict'])
+
+    test_loss, test_dice, test_dice_per_class = validate(model, test_loader, criterion, device, phase="Test")
+
+    # Class names assume the standard 4-label OASIS masks detected above
+    class_names = ['Background', 'CSF', 'Gray Matter', 'White Matter']
+    print(f"\nTest Loss: {test_loss:.4f}, Test Dice: {test_dice:.4f}")
+    for name, dice in zip(class_names, test_dice_per_class):
+        print(f"  {name}: {dice:.4f}")
+
+    min_dice = min(test_dice_per_class[1:])
+    print(
+        f"\nMin Dice (no bg): {min_dice:.4f} {'Dice threshold of 0.9 met' if min_dice >= 0.9 else 'Dice threshold of 0.9 not met'}"
+    )
+
+if __name__ == '__main__':
+    main()
diff --git a/recognition/AttUNetOASIS/unetrunner b/recognition/AttUNetOASIS/unetrunner
new file mode 100644
index 000000000..79df21f03
--- /dev/null
+++ b/recognition/AttUNetOASIS/unetrunner
@@ -0,0 +1,14 @@
+#!/bin/bash
+#SBATCH --time=00:20:00
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:1
+#SBATCH --job-name=unet
+#SBATCH -o unet.out
+
+module load cuda/11.4
+
+source ~/miniconda3/etc/profile.d/conda.sh
+conda activate unet
+
+python unet.py
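+
+# Submit from this directory with: sbatch unetrunner
+# (assumes a SLURM cluster such as Rangpur, with a conda env named "unet")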