diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..3a9161e7e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,16 @@
+recognition/oasis2d_unet_45807321/OASIS
+recognition/hipmri2d_unet_45807321/HipMRI_Study_open
+recognition/predicts_hip
+recognition/predicts_oasis
+recognition/final_outputs
+recognition/final_outputs_hip
+recognition/hip_outputs
+recognition/brain_outputs
+recognition/oasis2d_unet_45807321/v2
+recognition/oasis2d_unet_45807321/v3
+recognition/oasis2d_unet_45807321/outputs
+recognition/oasis2d_unet_45807321/__pycache__
+recognition/hipmri2d_unet_45807321/__pycache__
+settings.json
+recognition/hipmri2d_unet_45807321/README.pdf
+recognition/oasis2d_unet_45807321
\ No newline at end of file
diff --git a/recognition/hipmri2d_unet_45807321/README.md b/recognition/hipmri2d_unet_45807321/README.md
new file mode 100644
index 000000000..ac0943226
--- /dev/null
+++ b/recognition/hipmri2d_unet_45807321/README.md
@@ -0,0 +1,203 @@
+# HipMRI 2D Segmentation
+
+## About
+
+This repository contains the code for a 2D U-Net model that performs 6-class segmentation of HipMRI scans. The model uses a deep-supervision architecture, and the code provides a full pipeline for training, evaluation, and visual inference.
+
+The task is to build a model that can perform segmentation on the provided hip MRI dataset; more specifically, to outline the prostate gland so that professionals can more easily look for signs of prostate cancer. To do this, I have created an improved 2D U-Net model, an architecture designed specifically for image segmentation.
+
+How does this algorithm work? A standard 2D U-Net consists of an encoder, a decoder, and skip connections. The improved 2D U-Net keeps this structure but refines it further: for example, instead of Batch Normalization and ReLU it uses Instance Normalization, which allows for more stable training. The image below shows the components that make up the improved 2D U-Net model.
+
+![Network Architecture](./report_assets/report_network.png)[1](#ref-1)
+
+## 📊 Example Results
+
+Here are the training curves and example predictions from a 20-epoch run with the default settings.
+
+### Training Curves
+
+![Training Curves](./report_assets/curves.png)
+
+The plot on the left showcases the train_loss and validation_loss over the 20 epochs that were run. The plot on the right showcases the per-class Dice coefficient over the same 20 epochs.
+
+## Outputs
+
+The images below were created by running prediction with the trained model.
+
+It can be seen that the predictions aren't too far off from the ground truth, but there is still definitely room for improvement, as the model is still underestimating some regions.
+
+### Output 1
+
+![Example Predictions 1](./report_assets/sample_000_00_combined.png)
+
+### Output 2
+
+![Example Predictions 2](./report_assets/sample_000_01_combined.png)
+
+### Output 3
+
+![Example Predictions 3](./report_assets/sample_000_02_combined.png)
+
+### Dice coefficients
+
+Below are the final Dice coefficients I obtained after training the model.
+
+For the first 4 classes the Dice coefficients are excellent, probably because these features appear more overtly. For the last two classes the model still achieves decent Dice coefficients, although not as good as the first 4, probably because these features are a lot smaller and therefore harder to pinpoint.
+
+Per-class Dice: C0:0.982 C1:0.984 C2:0.942 C3:0.970 C4:0.876 C5:0.839
+Mean Dice: 0.932
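+
+These per-class values are computed by `predict.py` using `dice_per_class_from_logits` from `utils.py`. The snippet below is a minimal, self-contained sketch of that same soft-Dice computation (illustrative only; the random tensors simply stand in for real model output of shape `[B, 6, 256, 128]`):
+
+```python
+import torch
+import torch.nn.functional as F
+
+def dice_per_class(logits: torch.Tensor, labels: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    """Soft Dice per class from raw logits [B, C, H, W] and integer labels [B, H, W]."""
+    num_classes = logits.size(1)
+    probs = torch.softmax(logits, dim=1)  # [B, C, H, W]
+    onehot = F.one_hot(labels.long(), num_classes).permute(0, 3, 1, 2).float()
+    num = 2.0 * (probs * onehot).sum(dim=(0, 2, 3))
+    den = (probs * probs).sum(dim=(0, 2, 3)) + (onehot * onehot).sum(dim=(0, 2, 3)) + eps
+    return num / den  # [C]
+
+# Illustrative usage with random tensors in place of real predictions:
+scores = dice_per_class(torch.randn(2, 6, 256, 128), torch.randint(0, 6, (2, 256, 128)))
+print("Per-class Dice:", [round(s, 3) for s in scores.tolist()])
+print("Mean Dice:", round(scores.mean().item(), 3))
+```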
+
+## 🚀 How to Run
+
+Follow these steps to set up the environment, download the data, and run the code.
+
+### 1. Get the Code
+
+Clone the repository and navigate to the project directory:
+
+```bash
+git clone https://github.com/rike568/PatternAnalysis-2025.git
+cd PatternAnalysis-2025
+git checkout topic-recognition
+cd recognition/hipmri2d_unet_45807321
+```
+
+### 2. Set Up the Environment
+
+You will need Conda to replicate the environment.
+
+1. **Install Conda:** If you don't have it, please [install Miniconda](https://docs.conda.io/en/latest/miniconda.html) for your system.
+
+2. **Create the Environment:** Use the provided `environment.yml` file to create the conda environment. This will install all required packages.
+
+   ```bash
+   conda env create -f environment.yml
+   ```
+
+3. **Activate the Environment:** Before running any scripts, you must activate the new environment:
+
+   ```bash
+   conda activate comp3710
+   ```
+
+### 3. Download the Dataset
+
+The code requires the `HipMRI_Study_open` dataset.
+
+**Rangpur Location:**
+
+If you are on the `rangpur` server, you can copy the data directly from the group directory:
+
+```bash
+# This copies the dataset into your current folder
+cp -r /home/groups/comp3710/HipMRI_Study_open/keras_slices_data HipMRI_Study_open
+```
+
+Your final folder structure should look like this:
+
+```
+hipmri2d_unet_45807321/
+├── HipMRI_Study_open/
+│   ├── keras_slices_train/
+│   ├── keras_slices_seg_train/
+│   ├── keras_slices_validate/
+│   ├── keras_slices_seg_validate/
+│   ├── keras_slices_test/
+│   └── keras_slices_seg_test/
+├── train.py
+├── predict.py
+├── dataset.py
+├── modules.py
+├── utils.py
+└── environment.yml
+```
+
+### 4. Train the Model
+
+With the `comp3710` environment active, you can run the training script. Checkpoints and results will be saved to the `outputs/` folder.
+
+**To train with default settings:**
+This will use the defaults set in the script (e.g., seed=42, lr=0.0005).
+
+```bash
+python train.py
+```
+
+**To train with custom hyperparameters:**
+You can override the default settings by providing command-line arguments.
+
+- `--seed`: Set the random seed (e.g., `--seed 123`).
+- `--lr`: Set the learning rate (e.g., `--lr 0.001`).
+- `--weight_decay`: Set the Adam weight decay (e.g., `--weight_decay 1e-5`).
+- `--grad_clip_norm`: Set the gradient clipping norm (e.g., `--grad_clip_norm 1.0`).
+
+**Example of a custom run:**
+This command trains with a learning rate of 0.001 and a seed of 123.
+
+```bash
+python train.py --lr 0.001 --seed 123
+```
+
+### 5. Run Predictions
+
+After training, you can run inference on the test set. This will load the `best.pt` checkpoint from the `outputs/` folder, calculate final Dice scores, and save visual predictions to `outputs/predictions/`.
+
+**To run with the default seed (42):**
+
+```bash
+python predict.py
+```
+
+**To run with a custom seed:**
+You can specify a different seed for reproducibility.
+
+```bash
+python predict.py --seed 123
+```
+
+## 📦 Project Dependencies
+
+This section outlines the software environment and dependencies required for the `comp3710` project, as defined in the provided `environment.yml` file.
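+
+Once the environment has been created, a quick way to confirm that PyTorch resolved with CUDA support is the short check below (illustrative only, not part of the repository):
+
+```python
+import torch, torchvision, nibabel, matplotlib, tqdm  # all listed in environment.yml
+
+print("torch:", torch.__version__)            # a '+cu118' suffix indicates the CUDA 11.8 build
+print("CUDA available:", torch.cuda.is_available())
+if torch.cuda.is_available():
+    print("device:", torch.cuda.get_device_name(0))
+```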
+ +--- + +## 🔗 Configuration Channels + +The following channels are used to locate and download packages: + +- **`pytorch`**: Primary channel for PyTorch-related packages, especially those built with specific CUDA versions. +- **`defaults`**: The standard set of channels used by the package manager (e.g., Anaconda/Miniconda). + +--- + +## 🐍 Core Dependencies (Conda) + +These packages are managed directly by the environment tool (Conda, in this case). + +- **`python=3.10`**: Specifies the required Python version. +- **`pip`**: Ensures the `pip` package installer is available for managing secondary dependencies. + +--- + +## ⚙️ Python Packages (Pip) + +These packages are installed using `pip`, often with specific build configurations. + +| Package Name | Installation Source / Note | Description | +| :---------------- | :--------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------ | +| **`torch`** | `--index-url https://download.pytorch.org/whl/cu118` | The core **PyTorch** library, explicitly compiled for **CUDA 11.8** for GPU acceleration. | +| **`torchvision`** | `--index-url https://download.pytorch.org/whl/cu118` | A package for computer vision, providing datasets, models, and image transformations, also built for **CUDA 11.8**. | +| **`torchaudio`** | `--index-url https://download.pytorch.org/whl/cu118` | A package for audio data, including data loading and transformations, also built for **CUDA 11.8**. | +| **`matplotlib`** | Standard PyPI | A comprehensive library for creating static, animated, and interactive visualizations in Python. | +| **`nibabel`** | Standard PyPI | Provides read/write access to common neuroimaging file formats (e.g., NIfTI, DICOM). | +| **`tqdm`** | Standard PyPI | A fast, extensible progress bar for loops and iterables. | + +--- + +## 🚀 Environment Summary + +This environment is specifically configured for deep learning tasks involving PyTorch, with a strong focus on GPU acceleration (CUDA 11.8), and includes specialized libraries for handling neuroimaging data (`nibabel`) and providing utility (`tqdm`, `matplotlib`). + +## References + +1. Isensee, F., Kickingereder, P., Wick, W., Bendszus, M., & Maier-Hein, K. H. (2018). _Brain Tumor Segmentation and Radiomics Survival Prediction: Contribution to the BRATS 2017 Challenge_. arXiv:1802.10508. Available: [https://arxiv.org/abs/1802.10508](https://arxiv.org/abs/1802.10508) diff --git a/recognition/hipmri2d_unet_45807321/dataset.py b/recognition/hipmri2d_unet_45807321/dataset.py new file mode 100644 index 000000000..548bb9c03 --- /dev/null +++ b/recognition/hipmri2d_unet_45807321/dataset.py @@ -0,0 +1,417 @@ +# dataset.py +# HipMRI_Study_open 2D dataloader: +# - auto-locates ./HipMRI_Study_open next to this file +# - pairs image slices (case_...) with masks (seg_...) 
via a canonical key +# - loads 2D Nifti slices using z-score normalization +# - applies simple train-time geometric augmentation +# - returns PyTorch DataLoaders for train/val/test + +from __future__ import annotations +import random +from pathlib import Path +from typing import List, Tuple, Dict + +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision.transforms import functional as F +import nibabel as nib # Added for Nifti loading +from tqdm import tqdm # Import tqdm for progress bar + +# --------------------------- +# Defaults / config +# --------------------------- +THIS_DIR = Path(__file__).resolve().parent +DEFAULT_DATA_ROOT = THIS_DIR / "HipMRI_Study_open" +DEFAULT_IMG_SIZE = None +DEFAULT_BATCH_SIZE = 64 +DEFAULT_NUM_WORKERS = 1 + +# --------------------------- +# Simple helpers +# --------------------------- + + +def _canonical_key(p: Path) -> str: + """ + Standardizes a Nifti filename to create a canonical key for matching. + + This allows 'case_001_slice_0.nii.gz' and 'seg_001-slice_0.nii.gz' + to both map to the same key '001_slice_0'. + + Args: + p: The Path object to the file. + + Returns: + A standardized string key. + """ + name = p.stem.lower() + if ".nii" in name: # Handle double extensions like .nii.gz + name = Path(name).stem + + if name.startswith("case_"): + name = name[len("case_") :] + elif name.startswith("seg_"): + name = name[len("seg_") :] + name = name.replace("-", "_") # unify separators + return name + + +class RandomAugment2D: + """ + Apply identical random geometric augmentations to an image and mask tensor. + + This is a callable class. + """ + + def __init__( + self, max_rot_deg: float = 10.0, p_hflip: float = 0.5, p_vflip: float = 0.5 + ): + """ + Initializes the augmentation transform. + + Args: + max_rot_deg: Maximum angle (in degrees) for random rotation. + p_hflip: Probability of a horizontal flip. + p_vflip: Probability of a vertical flip. + """ + self.max_rot_deg = max_rot_deg + self.p_hflip = p_hflip + self.p_vflip = p_vflip + + def __call__( + self, img: torch.Tensor, mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Applies the configured random augmentations to an image and mask. + + Ensures that the same random transformation is applied to both inputs + and that 'NEAREST' interpolation is used for the mask. + + Args: + img: The image tensor, expected shape [1, H, W]. + mask: The mask tensor, expected shape [H, W]. + + Returns: + A tuple of (augmented_img, augmented_mask). + """ + # Add channel dim to mask for transforms [H,W] -> [1,H,W] + mask = mask[None, ...] + + # Horizontal flip + if random.random() < self.p_hflip: + img = F.hflip(img) + mask = F.hflip(mask) + # Vertical flip + if random.random() < self.p_vflip: + img = F.vflip(img) + mask = F.vflip(mask) + # Small random rotation; bilinear for image, nearest for mask (to preserve labels) + if self.max_rot_deg > 0: + angle = random.uniform(-self.max_rot_deg, self.max_rot_deg) + img = F.rotate( + img, angle, interpolation=F.InterpolationMode.BILINEAR, fill=0 + ) + mask = F.rotate( + mask, angle, interpolation=F.InterpolationMode.NEAREST, fill=0 + ) + + # Remove channel dim from mask [1,H,W] -> [H,W] + return img, mask.squeeze(0) + + +class HipMRI2DSegDataset(Dataset): + """ + A PyTorch Dataset for loading 2D Nifti slices from the HipMRI_Study_open dataset. 
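+
+    Each sample is returned by __getitem__ as a dict with the keys "image"
+    (float tensor of shape [1, 256, 128]), "mask" (long tensor of shape
+    [256, 128]), "image_path", and "mask_path".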
+ + Expected folder layout inside ./HipMRI_Study_open: + keras_slices_train/ + keras_slices_validate/ + keras_slices_test/ + keras_slices_seg_train/ + keras_slices_seg_validate/ + keras_slices_seg_test/ + """ + + def __init__( + self, + data_root: Path = DEFAULT_DATA_ROOT, + split: str = "train", + train_augment: bool = False, + max_rot_deg: float = 10.0, + ): + """ + Initializes the dataset. + + This method scans the data directories, matches image and mask + files based on their canonical keys, and sets up the augmentation + pipeline for the training split. + + Args: + data_root: The root directory of the 'HipMRI_Study_open' dataset. + split: The dataset split to load ("train", "validate", or "test"). + train_augment: Whether to apply augmentations (only used if split="train"). + max_rot_deg: Maximum rotation angle for augmentation. + """ + super().__init__() + self.data_root = Path(data_root) + assert split in {"train", "validate", "test"} # enforce valid split names + self.split = split + + # Locate image/mask directories by split + img_dir = ( + self.data_root + / f"keras_slices_{'validate' if split=='validate' else split}" + ) + seg_dir = ( + self.data_root + / f"keras_slices_seg_{'validate' if split=='validate' else split}" + ) + if not (img_dir.exists() and seg_dir.exists()): + raise FileNotFoundError( + f"Missing expected HipMRI_Study_open folders:\n{img_dir}\n{seg_dir}" + ) + + # List files and build a mask index keyed by canonical names + images = sorted( + [p for p in img_dir.iterdir() if p.is_file() and ".nii" in p.name] + ) + masks = sorted( + [p for p in seg_dir.iterdir() if p.is_file() and ".nii" in p.name] + ) + mask_index: Dict[str, Path] = {_canonical_key(m): m for m in masks} + + # Pair image with mask via canonical key + pairs: List[Tuple[Path, Path]] = [] + missing: List[str] = [] + for ip in images: + key = _canonical_key(ip) + mp = mask_index.get(key) + if mp is not None: + pairs.append((ip, mp)) + else: + missing.append(f"{ip.name} (key: {key})") + + if not pairs: + raise RuntimeError( + "No image/mask pairs found. Check prefixes or directory names.\n" + f"Example image: {images[0].name if images else 'None'}\n" + f"Example mask : {masks[0].name if masks else 'None'}" + ) + if missing: + print( + f"[HipMRI] {len(missing)} images had no mask match (showing first 5): {missing[:5]}" + ) + + self.pairs = pairs + self.augment = ( + RandomAugment2D(max_rot_deg=max_rot_deg) + if train_augment and split == "train" + else None + ) + + def __len__(self) -> int: + """Returns the total number of paired samples in this split.""" + return len(self.pairs) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor | str]: + """ + Loads and preprocesses a single image/mask pair. + + Steps: + 1. Loads Nifti image and mask data. + 2. Applies instance-wise z-score normalization to the image. + 3. Converts both to PyTorch tensors. + 4. Center-crops the pair to the target size (256, 128). + 5. Applies augmentations if this is the training set. + + Args: + idx: The index of the sample to retrieve. + + Returns: + A dictionary containing: + - "image": The preprocessed image tensor [1, 256, 128]. + - "mask": The preprocessed mask tensor [256, 128]. + - "image_path": String path to the original image. + - "mask_path": String path to the original mask. 
+ """ + img_path, mask_path = self.pairs[idx] + + # Load image + img = nib.load(img_path).get_fdata(caching="unchanged") + if len(img.shape) == 3: + img = img[:, :, 0] # Take first slice + img = img.astype(np.float32) + + # Load mask + mask = nib.load(mask_path).get_fdata(caching="unchanged") + if len(mask.shape) == 3: + mask = mask[:, :, 0] # Take first slice + + # Apply per-image z-score normalization + mean = img.mean() + std = img.std() + img = (img - mean) / (std + 1e-8) # Add epsilon for safety + + # Convert to Tensors + img_t = torch.from_numpy(img)[None, ...] # [1, H, W] + mask_t = torch.from_numpy(mask.astype(np.int64)) # [H, W] + + # Robust center-crop to (256, 128) + target_h, target_w = 256, 128 + _, current_h, current_w = img_t.shape + + if current_h == target_h and current_w > target_w: + # Image is correct height but too wide. Center-crop width. + crop_pixels = current_w - target_w + start_w = crop_pixels // 2 + end_w = start_w + target_w + + img_t = img_t[:, :, start_w:end_w] + mask_t = mask_t[:, start_w:end_w] + + # Apply augmentation (now on Tensors) + if self.augment: + img_t, mask_t = self.augment(img_t, mask_t) + + return { + "image": img_t, + "mask": mask_t, + "image_path": str(img_path), + "mask_path": str(mask_path), + } + + +def make_loaders( + data_root: Path = DEFAULT_DATA_ROOT, + batch_size: int = DEFAULT_BATCH_SIZE, + num_workers: int = DEFAULT_NUM_WORKERS, +) -> Tuple[DataLoader, DataLoader, DataLoader]: + """ + Creates and returns the train, validation, and test DataLoaders. + + Args: + data_root: The root directory of the 'HipMRI_Study_open' dataset. + batch_size: The batch size for all loaders. + num_workers: The number of worker processes for data loading. + + Returns: + A tuple of (train_loader, val_loader, test_loader). + """ + train_ds = HipMRI2DSegDataset(data_root, "train", train_augment=True) + val_ds = HipMRI2DSegDataset(data_root, "validate") + test_ds = HipMRI2DSegDataset(data_root, "test") + + train_loader = DataLoader( + train_ds, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + val_loader = DataLoader( + val_ds, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + test_loader = DataLoader( + test_ds, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + return train_loader, val_loader, test_loader + + +if __name__ == "__main__": + """ + Runs a verification script when the dataset is executed directly. + + This will scan all dataset splits and report on: + 1. Total number of samples found. + 2. Any image/mask shape mismatches. + 3. All unique image shapes (H, W) found after processing. + 4. All unique mask values (class labels) found. + """ + print(f"[HipMRI] Using data_root: {DEFAULT_DATA_ROOT}") + print("Running dataset shape verification...") + + all_shapes = set() + all_mask_values = set() + + def check_dataset_shapes(name: str, dataset: HipMRI2DSegDataset) -> Tuple[set, set]: + """ + Scans a dataset split, checking and reporting shapes and mask values. + + Args: + name: The name of the split (e.g., "train"). + dataset: The HipMRI2DSegDataset instance to check. + + Returns: + A tuple of (set_of_shapes, set_of_mask_values). 
+ """ + print(f"\nChecking dataset: {name} ({len(dataset)} samples)") + shapes = set() + mask_vals = set() + + for i in tqdm(range(len(dataset)), desc=f"Scanning {name}"): + try: + sample = dataset[i] + img_shape = tuple(sample["image"].shape[1:]) # (H, W) + mask_shape = tuple(sample["mask"].shape) # (H, W) + + mask_vals.update(torch.unique(sample["mask"]).numpy().tolist()) + current_shape = img_shape + + if img_shape != mask_shape: + print( + f" WARNING: Mismatch! Img {dataset.pairs[i][0].name} is {img_shape}, Mask {dataset.pairs[i][1].name} is {mask_shape}" + ) + + shapes.add(current_shape) + + except Exception as e: + print(f" ERROR loading sample {i} ({dataset.pairs[i][0].name}): {e}") + + print(f"-> Found unique (H, W) shapes for {name}: {shapes}") + print(f"-> Found unique mask values for {name}: {sorted(list(mask_vals))}") + return shapes, mask_vals + + try: + train_ds = HipMRI2DSegDataset(DEFAULT_DATA_ROOT, "train") + val_ds = HipMRI2DSegDataset(DEFAULT_DATA_ROOT, "validate") + test_ds = HipMRI2DSegDataset(DEFAULT_DATA_ROOT, "test") + + train_shapes, train_mask_vals = check_dataset_shapes("train", train_ds) + val_shapes, val_mask_vals = check_dataset_shapes("validate", val_ds) + test_shapes, test_mask_vals = check_dataset_shapes("test", test_ds) + + all_shapes.update(train_shapes) + all_shapes.update(val_shapes) + all_shapes.update(test_shapes) + + all_mask_values.update(train_mask_vals) + all_mask_values.update(val_mask_vals) + all_mask_values.update(test_mask_vals) + + print("\n========================================") + print(f"All unique (H, W) shapes found: {all_shapes}") + print(f"All unique mask values found: {sorted(list(all_mask_values))}") + + if len(all_shapes) == 1 and (256, 128) in all_shapes: + print("Confirmation: All images are 256x128.") + else: + print("WARNING: Not all images are 256x128 or multiple sizes found.") + + if all(v in [0, 1, 2, 3, 4, 5] for v in all_mask_values): + print("Confirmation: All mask values are valid (0, 1, 2, 3, 4, 5).") + else: + print("WARNING: Invalid mask values found! Check the list above.") + print("========================================") + + except Exception as e: + print(f"\nFailed to initialize dataset. Check paths and folder names.") + print(f"Error: {e}") diff --git a/recognition/hipmri2d_unet_45807321/environment.yml b/recognition/hipmri2d_unet_45807321/environment.yml new file mode 100644 index 000000000..c7984c897 --- /dev/null +++ b/recognition/hipmri2d_unet_45807321/environment.yml @@ -0,0 +1,15 @@ +name: comp3710 +channels: + - pytorch + - defaults +dependencies: + - python=3.10 + - pip + + - pip: + - torch --index-url https://download.pytorch.org/whl/cu118 + - torchvision --index-url https://download.pytorch.org/whl/cu118 + - torchaudio --index-url https://download.pytorch.org/whl/cu118 + - matplotlib + - nibabel + - tqdm diff --git a/recognition/hipmri2d_unet_45807321/modules.py b/recognition/hipmri2d_unet_45807321/modules.py new file mode 100644 index 000000000..00769ec8b --- /dev/null +++ b/recognition/hipmri2d_unet_45807321/modules.py @@ -0,0 +1,439 @@ +# modules.py +# U-Net 2D model based on the provided diagram for HipMRI slices (256x128). +# Returns per-pixel logits [B, C, H, W] (no softmax in the model). 
+ +from __future__ import annotations +import torch +import torch.nn as nn +import torch.nn.functional as F # Added for potential future use, though not strictly in this model for now + +__all__ = [ + "ImprovedUNet", # Renamed to keep consistent with train.py, but it's the new arch + "create_model", + "count_params", + "init_kaiming_normal_", # Retaining for good practice +] + + +# Helper function for weight initialization +def init_kaiming_normal_(m: nn.Module) -> None: + """ + Applies He (Kaiming) initialization to Conv2d and ConvTranspose2d layers + and initializes BatchNorm2d layers. + + Call with `model.apply(init_kaiming_normal_)`. + + Args: + m: The module to initialize. + """ + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if getattr(m, "bias", None) is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + +# --------------------------------------------- +# Core building blocks from the diagram +# --------------------------------------------- + + +class _ContextModule(nn.Module): + """ + The 'context module' from the diagram. + + A block consisting of two sequential 3x3 convolutions, each followed + by BatchNorm and ReLU. This block maintains the input resolution. + """ + + def __init__(self, in_channels: int, out_channels: int): + """ + Initializes the context module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + """ + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Runs the forward pass for the context module. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + return self.block(x) + + +class _DownsamplingModule(nn.Module): + """ + The 'downsampling module' from the diagram (3x3 stride 2 convolution). + + A 3x3 strided convolution block that halves the spatial dimensions (H, W) + and increases the channel count. + """ + + def __init__(self, in_channels: int, out_channels: int): + """ + Initializes the downsampling module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + """ + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Runs the forward pass for the downsampling module. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + return self.conv(x) + + +class _UpsamplingModule(nn.Module): + """ + The 'upsampling module' from the diagram. + + A 2x2 transposed convolution that doubles the spatial dimensions (H, W) + and halves the channel count. + """ + + def __init__(self, in_channels: int, out_channels: int): + """ + Initializes the upsampling module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels (typically in_channels // 2). 
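+
+        For example, the [B, 256, 16, 8] bottleneck feature map becomes
+        [B, 128, 32, 16] after the first upsampling module (channels halved,
+        H and W doubled).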
+ """ + super().__init__() + self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Runs the forward pass for the upsampling module. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + return self.up(x) + + +class _LocalizationModule(nn.Module): + """ + The 'localization module' from the diagram, used in the decoder. + + A block consisting of two sequential 3x3 convolutions, each followed + by BatchNorm and ReLU. This is structurally identical to the + _ContextModule but is used on the decoder path. + """ + + def __init__(self, in_channels: int, out_channels: int): + """ + Initializes the localization module. + + Args: + in_channels: Number of input channels (from concatenated skip + upsample). + out_channels: Number of output channels. + """ + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Runs the forward pass for the localization module. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + return self.block(x) + + +class _SegmentationLayer(nn.Module): + """ + The 'segmentation layer' from the diagram. + + A single 1x1 convolution used to map feature channels to the final + number of classes. + """ + + def __init__(self, in_channels: int, out_channels: int): + """ + Initializes the segmentation layer. + + Args: + in_channels: Number of input feature channels. + out_channels: Number of output classes. + """ + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Runs the forward pass for the segmentation layer. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + return self.conv(x) + + +# --------------------------------------------- +# The U-Net Model +# --------------------------------------------- + + +class ImprovedUNet(nn.Module): + """ + The complete U-Net style model for HipMRI segmentation, based on the + provided architecture diagram. + + This model features an encoder-decoder structure with skip connections + and deep supervision, where segmentation maps from multiple decoder + levels are upscaled and summed for the final output. + """ + + def __init__(self, in_channels: int = 1, num_classes: int = 6): + """ + Initializes the U-Net model. + + Args: + in_channels: Number of input image channels (e.g., 1 for grayscale). + num_classes: Number of output segmentation classes. 
+ """ + super().__init__() + + # --- Encoder Path --- + # Level 1 (256x128) + self.context1 = _ContextModule(in_channels, 16) + self.down1 = _DownsamplingModule(16, 32) + + # Level 2 (128x64) + self.context2 = _ContextModule(32, 32) + self.down2 = _DownsamplingModule(32, 64) + + # Level 3 (64x32) + self.context3 = _ContextModule(64, 64) + self.down3 = _DownsamplingModule(64, 128) + + # Level 4 (32x16) + self.context4 = _ContextModule(128, 128) + self.down4 = _DownsamplingModule(128, 256) + + # --- Bottleneck --- (16x8) + self.bottleneck = _ContextModule(256, 256) + + # --- Decoder Path --- + # Level 4 (32x16) + self.up4 = _UpsamplingModule(256, 128) + self.loc4 = _LocalizationModule( + 128 + 128, 128 + ) # Concatenates upsampled with context4 output + self.seg4 = _SegmentationLayer(128, num_classes) + + # Level 3 (64x32) + self.up3 = _UpsamplingModule(128, 64) + self.loc3 = _LocalizationModule( + 64 + 64, 64 + ) # Concatenates upsampled with context3 output + self.seg3 = _SegmentationLayer(64, num_classes) + + # Level 2 (128x64) + self.up2 = _UpsamplingModule(64, 32) + self.loc2 = _LocalizationModule( + 32 + 32, 32 + ) # Concatenates upsampled with context2 output + self.seg2 = _SegmentationLayer(32, num_classes) + + # Level 1 (256x128) + self.up1 = _UpsamplingModule(32, 16) + self.loc1 = _LocalizationModule( + 16 + 16, 16 + ) # Concatenates upsampled with context1 output + + # Final output segmentation layer + self.final_seg_layer = _SegmentationLayer(16, num_classes) + + # Apply Kaiming initialization + self.apply(init_kaiming_normal_) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Runs the forward pass of the U-Net. + + Args: + x: The input batch of images [B, C_in, H, W]. + + Returns: + The raw logits for each class [B, C_out, H, W]. 
+ """ + # Encoder + # Level 1 + x_c1 = self.context1(x) # [B, 16, H, W] + x_d1 = self.down1(x_c1) # [B, 32, H/2, W/2] + + # Level 2 + x_c2 = self.context2(x_d1) # [B, 32, H/2, W/2] + x_d2 = self.down2(x_c2) # [B, 64, H/4, W/4] + + # Level 3 + x_c3 = self.context3(x_d2) # [B, 64, H/4, W/4] + x_d3 = self.down3(x_c3) # [B, 128, H/8, W/8] + + # Level 4 + x_c4 = self.context4(x_d3) # [B, 128, H/8, W/8] + x_d4 = self.down4(x_c4) # [B, 256, H/16, W/16] + + # Bottleneck + x_bottleneck = self.bottleneck(x_d4) # [B, 256, H/16, W/16] + + # Decoder + # Level 4 (decoding from bottleneck) + x_up4 = self.up4(x_bottleneck) # [B, 128, H/8, W/8] + x_cat4 = torch.cat([x_up4, x_c4], dim=1) # [B, 256, H/8, W/8] + x_loc4 = self.loc4(x_cat4) # [B, 128, H/8, W/8] + s4 = self.seg4(x_loc4) # [B, num_classes, H/8, W/8] + + # Level 3 + x_up3 = self.up3(x_loc4) # [B, 64, H/4, W/4] + x_cat3 = torch.cat([x_up3, x_c3], dim=1) # [B, 128, H/4, W/4] + x_loc3 = self.loc3(x_cat3) # [B, 64, H/4, W/4] + s3 = self.seg3(x_loc3) # [B, num_classes, H/4, W/4] + + # Level 2 + x_up2 = self.up2(x_loc3) # [B, 32, H/2, W/2] + x_cat2 = torch.cat([x_up2, x_c2], dim=1) # [B, 64, H/2, W/2] + x_loc2 = self.loc2(x_cat2) # [B, 32, H/2, W/2] + s2 = self.seg2(x_loc2) # [B, num_classes, H/2, W/2] + + # Level 1 + x_up1 = self.up1(x_loc2) # [B, 16, H, W] + x_cat1 = torch.cat([x_up1, x_c1], dim=1) # [B, 32, H, W] + x_loc1 = self.loc1(x_cat1) # [B, 16, H, W] + + # Final segmentation layer + s1 = self.final_seg_layer(x_loc1) # [B, num_classes, H, W] + + # Element-wise sum of segmentation layers (after upscaling s4, s3, s2 to s1's size) + # Note: F.interpolate is used for upscaling + output = ( + s1 + + F.interpolate(s2, scale_factor=2, mode="bilinear", align_corners=False) + + F.interpolate(s3, scale_factor=4, mode="bilinear", align_corners=False) + + F.interpolate(s4, scale_factor=8, mode="bilinear", align_corners=False) + ) + + return output + + +# ----------------------------------------------------------------- +# --- Factory and Parameter Counter (for compatibility) --- +# ----------------------------------------------------------------- + + +def create_model( + in_channels: int = 1, + num_classes: int = 6, # Removed base and p_drop as they are not used by this architecture +) -> ImprovedUNet: + """ + Factory function for quick construction of the ImprovedUNet model. + + Args: + in_channels: Number of input image channels. + num_classes: Number of output segmentation classes. + + Returns: + An instance of the ImprovedUNet model. + """ + return ImprovedUNet(in_channels=in_channels, num_classes=num_classes) + + +def count_params(model: nn.Module) -> int: + """ + Calculates the total number of trainable parameters in a model. + + Args: + model: The PyTorch model (nn.Module). + + Returns: + The integer count of trainable parameters. + """ + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +# ----------------------------------------------------------------- +# --- Sanity Check --- +# ----------------------------------------------------------------- + +if __name__ == "__main__": + """ + Runs a sanity check when the script is executed directly. + + Creates a model, passes a dummy tensor, and asserts that the + output shape is correct. 
+ """ + print("Testing the U-Net model based on diagram (2D adaptation)...") + net = create_model(in_channels=1, num_classes=6) + + # Test with 256x128 input + x = torch.randn(2, 1, 256, 128) # Batch size 2, 1 channel, 256x128 + print(f"Input shape: {x.shape}") + y = net(x) + + # Expected output shape: (batch_size, num_classes, 256, 128) + print(f"Output shape: {tuple(y.shape)}") + print(f"Model params: {count_params(net):,}") + + expected_output_shape = (2, 6, 256, 128) + assert ( + tuple(y.shape) == expected_output_shape + ), f"Output shape mismatch! Expected {expected_output_shape}, got {tuple(y.shape)}" + print("Output shape is correct!") diff --git a/recognition/hipmri2d_unet_45807321/predict.py b/recognition/hipmri2d_unet_45807321/predict.py new file mode 100644 index 000000000..5b2f61c9e --- /dev/null +++ b/recognition/hipmri2d_unet_45807321/predict.py @@ -0,0 +1,311 @@ +# predict.py +# Inference + visualisation for HipMRI 2D segmentation (Improved U-Net) + +from __future__ import annotations +import os +from pathlib import Path +from typing import Tuple, List + +import torch +import matplotlib.pyplot as plt +import numpy as np + +from dataset import make_loaders +from modules import create_model +from utils import ( + # oasis_mask_to_class_ids, # No longer needed + dice_per_class_from_logits, + set_seed, + to_device, + load_checkpoint, +) + +# --------------------------- +# Config +# --------------------------- +SEED = 42 +NUM_CLASSES = 6 +IN_CHANNELS = 1 + +OUTDIR = Path("./outputs") +PRED_DIR = OUTDIR / "predictions" +CKPT_BEST = OUTDIR / "best.pt" +CKPT_LAST = OUTDIR / "last.pt" + +N_VIS = 8 # number of samples to visualise/save + + +# --------------------------- +# Small helpers +# --------------------------- + + +def _ensure_dir(p: Path) -> None: + """ + Ensures that a directory exists, creating it if necessary. + + Args: + p: The pathlib.Path of the directory to check/create. + """ + p.mkdir(parents=True, exist_ok=True) + + +def colorize(mask_ids: np.ndarray) -> np.ndarray: + """ + Maps a 2D array of class IDs to a 3D RGB color mask. + + Args: + mask_ids: A 2D numpy array of integer class labels [H, W]. + + Returns: + A 3D numpy array (RGB image) [H, W, 3] of type uint8. + """ + palette = np.array( + [ + [0, 0, 0], # 0: background - black + [0, 114, 189], # 1: blue + [217, 83, 25], # 2: orange + [237, 177, 32], # 3: yellow + [126, 47, 142], # 4: purple (NEW) + [119, 172, 48], # 5: green (NEW) + ], + dtype=np.uint8, + ) + mask_ids = np.clip(mask_ids, 0, len(palette) - 1) + return palette[mask_ids] + + +def tensor_to_uint8_img(x: torch.Tensor) -> np.ndarray: + """ + Converts a normalized image tensor to a uint8 grayscale image. + + Handles z-score normalized tensors by mapping a rough [-2, 2] range + to [0, 1] before scaling to [0, 255]. + + Args: + x: A [1, H, W] or [H, W] image tensor, typically z-score normalized. + + Returns: + A 2D numpy array (grayscale image) [H, W] of type uint8. + """ + x = x.detach().cpu().float() + if x.ndim == 3 and x.size(0) == 1: + x = x[0] # Squeeze channel dim + + # Check if tensor is z-score normalized (values outside [0, 1]) + x_min, x_max = float(x.min()), float(x.max()) + if x_min < -0.1 or x_max > 1.1: + # Assumes z-score norm, map roughly -2..2 to 0..1 + x = (x + 2.0) / 4.0 + + x = torch.clamp(x, 0.0, 1.0) # Clamp to [0, 1] range + return (x.numpy() * 255.0).astype(np.uint8) + + +def overlay( + img_gray_u8: np.ndarray, mask_rgb: np.ndarray, alpha: float = 0.5 +) -> np.ndarray: + """ + Overlays a color RGB mask onto a grayscale image. 
+ + Args: + img_gray_u8: The base grayscale image [H, W] as uint8. + mask_rgb: The color mask [H, W, 3] as uint8. + alpha: The opacity of the mask (0.0 = transparent, 1.0 = opaque). + + Returns: + A 3D numpy array (RGB image) [H, W, 3] of the blended overlay. + """ + # Convert grayscale to 3-channel RGB + img_rgb = np.stack([img_gray_u8] * 3, axis=-1) + # Blend + out = (img_rgb * (1 - alpha) + mask_rgb * alpha).astype(np.uint8) + return out + + +# --------------------------- +# Main +# --------------------------- + + +@torch.no_grad() +def main() -> None: + """ + Main function to run inference and visualization. + + - Loads the test dataset. + - Loads the best trained model checkpoint. + - Calculates and prints per-class Dice scores for the entire test set. + - Saves N_VIS sample visualizations (input, gt, pred, overlay) to disk. + - Creates a preview grid of the saved overlays. + """ + print("==> HipMRI 2D — Inference & Visualisation") + set_seed(SEED) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + # Data: only need test loader for predictions + _, _, test_loader = make_loaders() + + # Model + model = create_model( + in_channels=IN_CHANNELS, + num_classes=NUM_CLASSES, + ).to(device) + + # Load best checkpoint + ckpt_path = CKPT_BEST if CKPT_BEST.exists() else CKPT_LAST + if ckpt_path.exists(): + load_checkpoint( + ckpt_path.as_posix(), model, optimizer=None, map_location=device + ) + print(f"Loaded checkpoint: {ckpt_path}") + else: + print("⚠️ No checkpoint found — running with random-initialized weights.") + + model.eval() + + # Metrics over entire test set + dice_sum = torch.zeros(NUM_CLASSES, device=device) + n_batches = 0 + + _ensure_dir(PRED_DIR) + + vis_count = 0 + saved_paths: List[Path] = [] + + print(f"Running evaluation and saving {N_VIS} samples to {PRED_DIR}...") + for batch_idx, batch in enumerate(test_loader): + batch = to_device(batch, device) + x, y_ids = batch["image"], batch["mask"] + + # Get model prediction + logits = model(x) + + # --- Metric Calculation --- + dice_c = dice_per_class_from_logits(logits, y_ids) + dice_sum += dice_c + n_batches += 1 + + # --- Visualization Saving --- + if vis_count < N_VIS: + take = min( + N_VIS - vis_count, x.size(0) + ) # Num samples to take from this batch + for i in range(take): + # Convert tensors to numpy images + img_u8 = tensor_to_uint8_img(x[i]) + pred_ids = logits[i].argmax(dim=0).cpu().numpy().astype(np.int32) + gt_ids = y_ids[i].cpu().numpy().astype(np.int32) + + # Colorize masks + pred_rgb = colorize(pred_ids) + gt_rgb = colorize(gt_ids) + + # <--- MODIFIED SECTION START ---> + + # --- Create combined plot (like your example) --- + + # NOTE: Your example image shows 3 panels. + # If you also want the 'overlay' panel, change 'n_cols=3' to 'n_cols=4' + # and uncomment the 4th panel (axes[3]) plotting lines below. 
+ + # over_rgb = overlay(img_u8, pred_rgb, alpha=0.45) # Uncomment for 4 panels + + n_cols = 3 # Change to 4 if you want the overlay + fig, axes = plt.subplots( + nrows=1, + ncols=n_cols, + figsize=(n_cols * 5, 5.5), # 5x5 inch per panel + title space + ) + + # Ensure 'axes' is always an array for easy indexing + if n_cols == 1: + axes = np.array([axes]) + else: + axes = axes.flat + + # Panel 1: Original Image + axes[0].imshow(img_u8, cmap="gray") + axes[0].set_title("Original Image") + + # Panel 2: Ground Truth Mask + axes[1].imshow(gt_rgb) + axes[1].set_title("Ground Truth Mask") + + # Panel 3: Predicted Mask + axes[2].imshow(pred_rgb) + axes[2].set_title("Predicted Mask") + + # # Panel 4: Overlay (Optional - uncomment lines below) + # if n_cols >= 4: + # axes[3].imshow(over_rgb) + # axes[3].set_title("Overlay") + + # --- Clean up and Save --- + for ax in axes: + ax.axis("off") + + fig.tight_layout() + + base = Path(f"sample_{batch_idx:03d}_{i:02d}") + save_path = PRED_DIR / f"{base}_combined.png" + + fig.savefig(save_path, dpi=150, bbox_inches="tight") + plt.close(fig) # IMPORTANT: close fig in a loop + + # Add the path of the new combined image for the preview grid + saved_paths.append(save_path) + + # <--- MODIFIED SECTION END ---> + + vis_count += 1 + + # --- Report final metrics --- + print("\n==> Test Metrics:") + dice_mean_c = (dice_sum / max(n_batches, 1)).cpu().numpy() + dice_mean = float(dice_mean_c.mean()) + print( + "Per-class Dice:", + " ".join([f"C{c}:{dice_mean_c[c]:.3f}" for c in range(NUM_CLASSES)]), + ) + print(f"Mean Dice: {dice_mean:.3f}") + + # --- Create preview grid --- + if saved_paths: + print(f"\nCreating preview grid at {PRED_DIR / 'preview_overlays.png'}...") + + # <--- MODIFIED SECTION START ---> + # Adjust preview grid to better fit the new wide images + n_rows = min(N_VIS, 8) + + # Set a fixed width (e.g., 15 inches) and calculate row height based on + # the 3:1 aspect ratio of the new combined images (n_cols * 5, 5.5) + aspect_ratio = (n_cols * 5) / 5.5 + fig_width = 15.0 + row_height = fig_width / aspect_ratio + + fig, axes = plt.subplots( + nrows=n_rows, ncols=1, figsize=(fig_width, row_height * n_rows) + ) + # <--- MODIFIED SECTION END ---> + + if not isinstance(axes, np.ndarray): + axes = np.array([axes]) + + for ax, p in zip(axes.flat, saved_paths[:n_rows]): + ax.imshow(plt.imread(p)) + ax.set_title(p.name) + ax.axis("off") + + preview_path = PRED_DIR / "preview_overlays.png" + fig.tight_layout() + fig.savefig(preview_path, dpi=150) + plt.close(fig) + print(f"Saved {len(saved_paths)} combined samples to: {PRED_DIR}") + print(f"Preview grid saved: {preview_path}") + + +if __name__ == "__main__": + main() diff --git a/recognition/hipmri2d_unet_45807321/report_assets/curves.png b/recognition/hipmri2d_unet_45807321/report_assets/curves.png new file mode 100644 index 000000000..0b91425c5 Binary files /dev/null and b/recognition/hipmri2d_unet_45807321/report_assets/curves.png differ diff --git a/recognition/hipmri2d_unet_45807321/report_assets/report_network.png b/recognition/hipmri2d_unet_45807321/report_assets/report_network.png new file mode 100644 index 000000000..3112edd94 Binary files /dev/null and b/recognition/hipmri2d_unet_45807321/report_assets/report_network.png differ diff --git a/recognition/hipmri2d_unet_45807321/report_assets/sample_000_00_combined.png b/recognition/hipmri2d_unet_45807321/report_assets/sample_000_00_combined.png new file mode 100644 index 000000000..f293013b8 Binary files /dev/null and 
b/recognition/hipmri2d_unet_45807321/report_assets/sample_000_00_combined.png differ diff --git a/recognition/hipmri2d_unet_45807321/report_assets/sample_000_01_combined.png b/recognition/hipmri2d_unet_45807321/report_assets/sample_000_01_combined.png new file mode 100644 index 000000000..c548965ef Binary files /dev/null and b/recognition/hipmri2d_unet_45807321/report_assets/sample_000_01_combined.png differ diff --git a/recognition/hipmri2d_unet_45807321/report_assets/sample_000_02_combined.png b/recognition/hipmri2d_unet_45807321/report_assets/sample_000_02_combined.png new file mode 100644 index 000000000..85635af67 Binary files /dev/null and b/recognition/hipmri2d_unet_45807321/report_assets/sample_000_02_combined.png differ diff --git a/recognition/hipmri2d_unet_45807321/train.py b/recognition/hipmri2d_unet_45807321/train.py new file mode 100644 index 000000000..82f2e75c1 --- /dev/null +++ b/recognition/hipmri2d_unet_45807321/train.py @@ -0,0 +1,400 @@ +# train.py +# End-to-end training script for HipMRI 2D segmentation with Improved U-Net. +from __future__ import annotations +import os +from pathlib import Path +from typing import Tuple, Dict, List +import argparse +import csv + +import torch +import torch.nn as nn +import torch.optim as optim + +# --- NEW: headless plotting + csv logging --- +import matplotlib + +matplotlib.use("Agg") # safe for clusters / no display +import matplotlib.pyplot as plt + +from dataset import make_loaders, DEFAULT_BATCH_SIZE +from modules import create_model, count_params +from utils import ( + set_seed, + to_device, + CEDiceLoss, + dice_per_class_from_logits, + AvgMeter, + save_checkpoint, +) + +# --------------------------- +# Config (Default values) +# --------------------------- +SEED = 42 +NUM_CLASSES = 6 +IN_CHANNELS = 1 + +EPOCHS = 20 +LR = 0.0005 +WEIGHT_DECAY = 1e-4 +GRAD_CLIP_NORM = 1.0 + +AMP = True # mixed precision +OUTDIR = Path("./outputs") +OUTDIR.mkdir(parents=True, exist_ok=True) +CKPT_BEST = OUTDIR / "best.pt" +CKPT_LAST = OUTDIR / "last.pt" +HIST_CSV = OUTDIR / "history.csv" +CURVES_PNG = OUTDIR / "curves.png" + +LOG_EVERY = 50 # steps + +# --------------------------- +# Train / Val loops +# --------------------------- + + +def train_one_epoch( + model: nn.Module, + loader, + optimizer: optim.Optimizer, + criterion: nn.Module, + device: torch.device, + scaler: torch.amp.GradScaler | None, + epoch: int, + grad_clip_norm: float, +) -> float: + """ + Runs a single epoch of training. + + Args: + model: The segmentation model to train. + loader: DataLoader for the training set. + optimizer: The optimizer (e.g., Adam). + criterion: The loss function (e.g., CEDiceLoss). + device: The device to train on (e.g., 'cuda'). + scaler: Gradient scaler for mixed-precision training (AMP). + epoch: The current epoch number (for logging). + grad_clip_norm: The value for gradient clipping (0 to disable). + + Returns: + The average training loss for the epoch. 
+ """ + model.train() + loss_meter = AvgMeter() + + for step, batch in enumerate(loader, 1): + batch = to_device(batch, device) + x, y_ids = batch["image"], batch["mask"] + + optimizer.zero_grad(set_to_none=True) + + if scaler is not None: + # Mixed precision training + with torch.amp.autocast("cuda"): + logits = model(x) + loss = criterion(logits, y_ids) + scaler.scale(loss).backward() + if grad_clip_norm > 0: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_( + model.parameters(), grad_clip_norm + ) + scaler.step(optimizer) + scaler.update() + else: + # Standard precision training + logits = model(x) + loss = criterion(logits, y_ids) + loss.backward() + if grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_( + model.parameters(), grad_clip_norm + ) + optimizer.step() + + loss_meter.update(loss.item(), n=x.size(0)) + + if step % LOG_EVERY == 0: + print( + f"Epoch {epoch:03d} | step {step:05d}/{len(loader):05d} | loss {loss_meter.avg:.4f}" + ) + + return loss_meter.avg + + +@torch.no_grad() +def validate( + model: nn.Module, + loader, + criterion: nn.Module, + device: torch.device, +) -> Tuple[float, torch.Tensor]: + """ + Runs validation on the given data loader. + + Args: + model: The segmentation model to evaluate. + loader: DataLoader for the validation or test set. + criterion: The loss function. + device: The device to evaluate on. + + Returns: + A tuple containing: + - The average validation loss. + - A 1D tensor of per-class Dice scores. + """ + model.eval() + + loss_meter = AvgMeter() + dice_sum = None + n_batches = 0 + + for batch in loader: + batch = to_device(batch, device) + x, y_ids = batch["image"], batch["mask"] + + logits = model(x) + loss = criterion(logits, y_ids) + loss_meter.update(loss.item(), n=x.size(0)) + + # Calculate per-class dice + dice_c = dice_per_class_from_logits(logits, y_ids) # [C] + dice_sum = dice_c if dice_sum is None else (dice_sum + dice_c) + n_batches += 1 + + dice_mean_c = dice_sum / max(n_batches, 1) # [C] + return loss_meter.avg, dice_mean_c # val_loss, per-class dice + + +# --------------------------- +# Plotting / Logging helpers +# --------------------------- + + +def write_history_csv(rows: List[Dict], path: Path) -> None: + """ + Writes a list of metric dictionaries to a CSV file. + + Args: + rows: A list of dictionaries, where each dict is one epoch's metrics. + path: The pathlib.Path object to write the CSV to. + """ + if not rows: + return + keys = list(rows[0].keys()) + with open(path, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=keys) + w.writeheader() + for r in rows: + w.writerow(r) + + +def plot_curves(history: List[Dict], png_path: Path, num_classes: int) -> None: + """ + Plots training & validation loss and per-class Dice curves. + + Args: + history: A list of metric dictionaries (one per epoch). + png_path: The pathlib.Path to save the plot image to. + num_classes: The number of classes to plot Dice scores for. 
+ """ + epochs = [h["epoch"] for h in history] + tr = [h["train_loss"] for h in history] + vl = [h["val_loss"] for h in history] + dice_per_c = [] + for c in range(num_classes): + dice_per_c.append([h[f"dice_c{c}"] for h in history]) + + fig, axes = plt.subplots(1, 2, figsize=(12, 4)) + + # Loss curves + axes[0].plot(epochs, tr, label="train_loss") + axes[0].plot(epochs, vl, label="val_loss") + axes[0].set_xlabel("Epoch") + axes[0].set_ylabel("Loss") + axes[0].set_title("Loss Curves") + axes[0].legend() + axes[0].grid(True, alpha=0.3) + + # Dice curves per class + for c in range(num_classes): + axes[1].plot(epochs, dice_per_c[c], label=f"Dice C{c}") + axes[1].set_xlabel("Epoch") + axes[1].set_ylabel("Dice") + axes[1].set_title("Per-class Dice (Validation)") + axes[1].set_ylim(0.0, 1.0) + axes[1].legend() + axes[1].grid(True, alpha=0.3) + + fig.tight_layout() + fig.savefig(png_path, dpi=160) + plt.close(fig) + + +# --------------------------- +# Main +# --------------------------- + + +def main() -> None: + """ + Main function to orchestrate the end-to-end training and evaluation process. + + Parses arguments, sets up the model, data, optimizer, and scheduler, + runs the training loop, and performs final evaluation. + """ + # --- ADDED: Argument Parser --- + parser = argparse.ArgumentParser(description="HipMRI 2D U-Net Training") + parser.add_argument( + "--seed", type=int, default=SEED, help=f"Random seed (default: {SEED})" + ) + parser.add_argument( + "--lr", type=float, default=LR, help=f"Learning rate (default: {LR})" + ) + parser.add_argument( + "--weight_decay", + type=float, + default=WEIGHT_DECAY, + help=f"Adam weight decay (default: {WEIGHT_DECAY})", + ) + parser.add_argument( + "--grad_clip_norm", + type=float, + default=GRAD_CLIP_NORM, + help=f"Gradient clipping norm, 0 to disable (default: {GRAD_CLIP_NORM})", + ) + args = parser.parse_args() + # --------------------------------- + + print("==> HipMRI 2D — Improved U-Net training") + + # --- ADDED: Print settings --- + print("==> Settings:") + print(f" Seed: {args.seed}") + print(f" LR: {args.lr}") + print(f" Weight Decay: {args.weight_decay}") + print(f" Grad Clip Norm: {args.grad_clip_norm}") + print(f" Epochs: {EPOCHS}") + print(f" Batch Size: {DEFAULT_BATCH_SIZE}") + print(f" AMP: {AMP}") + print(f" Output Dir: {OUTDIR.as_posix()}") + print("-" * 30) + # ----------------------------- + + # Repro + set_seed(args.seed) + + # Device & AMP + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + scaler = torch.amp.GradScaler("cuda") if (AMP and device.type == "cuda") else None + print(f"Device: {device} | AMP: {scaler is not None}") + + # Data + train_loader, val_loader, test_loader = make_loaders( + batch_size=DEFAULT_BATCH_SIZE, + ) + print( + f"Train/Val/Test batches: {len(train_loader)}/{len(val_loader)}/{len(test_loader)}" + ) + + # Model + model = create_model( + in_channels=IN_CHANNELS, + num_classes=NUM_CLASSES, + ).to(device) + print(f"Model params: {count_params(model):,}") + + # Loss, Optim, Scheduler + criterion = CEDiceLoss(num_classes=NUM_CLASSES, alpha_ce=0.3, alpha_dice=0.7) + optimizer = optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=3 + ) + + best_val = float("inf") + + # --- History for plots/CSV --- + history: List[Dict] = [] + + for epoch in range(1, EPOCHS + 1): + train_loss = train_one_epoch( + model, + train_loader, + optimizer, + criterion, + device, + 
scaler, + epoch, + grad_clip_norm=args.grad_clip_norm, + ) + val_loss, dice_c = validate(model, val_loader, criterion, device) + + # Logging + dice_str = " ".join([f"C{ci}:{d.item():.3f}" for ci, d in enumerate(dice_c)]) + print( + f"[Epoch {epoch:03d}] train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | dice({NUM_CLASSES}): {dice_str}" + ) + + # Record LR (first param group) + curr_lr = next(iter(optimizer.param_groups))["lr"] + + # Save to history + rec = { + "epoch": epoch, + "train_loss": float(train_loss), + "val_loss": float(val_loss), + "lr": float(curr_lr), + } + for ci, d in enumerate(dice_c): + rec[f"dice_c{ci}"] = float(d.item()) + history.append(rec) + + # Scheduler on val loss + scheduler.step(val_loss) + + # Checkpoints + save_checkpoint( + CKPT_LAST.as_posix(), model, optimizer, epoch, extra={"val_loss": val_loss} + ) + if val_loss < best_val: + best_val = val_loss + save_checkpoint( + CKPT_BEST.as_posix(), + model, + optimizer, + epoch, + extra={"val_loss": val_loss}, + ) + print(f" ↳ New best! Saved to {CKPT_BEST}") + + # Update plots & CSV each epoch (so you can watch mid-run) + write_history_csv(history, HIST_CSV) + plot_curves(history, CURVES_PNG, NUM_CLASSES) + + # Final test evaluation (optional, after best/last) + print("==> Evaluating on test split (using last epoch weights)...") + test_loss, test_dice_c = validate(model, test_loader, criterion, device) + test_dice_str = " ".join( + [f"C{ci}:{d.item():.3f}" for ci, d in enumerate(test_dice_c)] + ) + print(f"[Test] loss={test_loss:.4f} | dice({NUM_CLASSES}): {test_dice_str}") + + # Append final test row to CSV (without plotting new points) + final_rec = { + "epoch": EPOCHS + 1, + "train_loss": float("nan"), + "val_loss": float(test_loss), + "lr": float(curr_lr), + } + for ci, d in enumerate(test_dice_c): + final_rec[f"dice_c{ci}"] = float(d.item()) + history.append(final_rec) + write_history_csv(history, HIST_CSV) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/recognition/hipmri2d_unet_45807321/utils.py b/recognition/hipmri2d_unet_45807321/utils.py new file mode 100644 index 000000000..0773e03f5 --- /dev/null +++ b/recognition/hipmri2d_unet_45807321/utils.py @@ -0,0 +1,287 @@ +# utils.py +# Small utilities for training: one-hot, Dice metrics/loss, meters, seeding, checkpoints. + +from __future__ import annotations +import os +import random +from dataclasses import dataclass +from typing import Dict, Iterable, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + + +# --------------------------- +# Reproducibility +# --------------------------- + + +def set_seed(seed: int = 42) -> None: + """ + Sets the random seed for Python, NumPy, and PyTorch for reproducible results. + + Args: + seed: The integer seed value to use. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # no-op if CUDA not available + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +# --------------------------- +# Tensor helpers +# --------------------------- + + +def to_device( + batch: Dict[str, torch.Tensor], device: torch.device +) -> Dict[str, torch.Tensor]: + """ + Moves a dictionary of tensors to the specified device (e.g., 'cuda' or 'cpu'). + + Args: + batch: A dictionary where keys are strings and values are tensors + (or other data types which will be ignored). + device: The target torch.device to move tensors to. 
+ + Returns: + A new dictionary with all tensor values moved to the specified device. + """ + out = {} + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + out[k] = v.to(device, non_blocking=True) + else: + out[k] = v # Keep non-tensor items (like paths) as-is + return out + + +def labels_to_onehot(y: torch.Tensor, num_classes: int) -> torch.Tensor: + """ + Converts a batch of integer label masks to a one-hot encoded format. + + Args: + y: A tensor of integer labels with shape [B, H, W]. + num_classes: The total number of classes (C). + + Returns: + A one-hot encoded tensor with shape [B, C, H, W]. + """ + # y expected long dtype; ensure safety. + y = y.long() + # F.one_hot creates [B, H, W, C], permute to [B, C, H, W] + return F.one_hot(y, num_classes=num_classes).permute(0, 3, 1, 2).float() + + +# --------------------------- +# Dice metrics / Dice loss +# --------------------------- + + +@torch.no_grad() +def dice_per_class_from_logits( + logits: torch.Tensor, y_true: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + """ + Calculates the Dice score for each class from raw logits. + + Args: + logits: The model's raw output tensor [B, C, H, W]. + y_true: The ground truth integer labels [B, H, W]. + eps: A small epsilon value to prevent division by zero. + + Returns: + A 1D tensor of shape [C] containing the Dice score for each class. + """ + num_classes = logits.size(1) + probs = torch.softmax(logits, dim=1) # [B, C, H, W] + y_1h = labels_to_onehot(y_true, num_classes) # [B, C, H, W] + + # Sum over batch (0) and spatial dims (2, 3) + num = 2.0 * (probs * y_1h).sum(dim=(0, 2, 3)) + den = (probs * probs).sum(dim=(0, 2, 3)) + (y_1h * y_1h).sum(dim=(0, 2, 3)) + eps + + return num / den + + +def dice_loss_from_logits( + logits: torch.Tensor, y_true_1h: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + """ + Calculates the soft Dice loss from raw logits. + This function expects the target to already be one-hot encoded. + + Args: + logits: The model's raw output tensor [B, C, H, W]. + y_true_1h: The ground truth one-hot encoded labels [B, C, H, W]. + eps: A small epsilon value to prevent division by zero. + + Returns: + A scalar tensor representing the mean Dice loss (1.0 - mean_dice). + """ + probs = torch.softmax(logits, dim=1) + + # Sum over batch (0) and spatial dims (2, 3) + num = 2.0 * (probs * y_true_1h).sum(dim=(0, 2, 3)) + den = ( + (probs * probs).sum(dim=(0, 2, 3)) + + (y_true_1h * y_true_1h).sum(dim=(0, 2, 3)) + + eps + ) + + dice_per_class = num / den + return 1.0 - dice_per_class.mean() # Mean loss across classes + + +# --------------------------- +# Combined loss (CE + Dice) +# --------------------------- + + +class CEDiceLoss(torch.nn.Module): + """ + A combined loss function that is a weighted sum of Cross-Entropy + and Dice loss. + + This is useful for segmentation tasks to balance pixel-wise accuracy (CE) + with spatial overlap (Dice). + """ + + def __init__( + self, + num_classes: int, + alpha_ce: float = 0.5, + alpha_dice: float = 0.5, + ignore_index: int | None = None, + ): + """ + Initializes the CEDiceLoss module. + + Args: + num_classes: The number of classes for one-hot encoding. + alpha_ce: The weight for the Cross-Entropy loss component. + alpha_dice: The weight for the Dice loss component. + ignore_index: An optional class index to ignore in CE loss. 
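+
+        The combined loss is alpha_ce * CE + alpha_dice * Dice loss;
+        train.py instantiates this with alpha_ce=0.3 and alpha_dice=0.7.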
+ """ + super().__init__() + self.num_classes = num_classes + self.alpha_ce = alpha_ce + self.alpha_dice = alpha_dice + self.ignore_index = ignore_index + + def forward(self, logits: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + """ + Calculates the combined loss. + + Args: + logits: The model's raw output tensor [B, C, H, W]. + y_true: The ground truth integer labels [B, H, W]. + + Returns: + A scalar tensor representing the final combined loss. + """ + # Calculate Cross-Entropy loss + if self.ignore_index is None: + loss_ce = F.cross_entropy(logits, y_true) + else: + loss_ce = F.cross_entropy(logits, y_true, ignore_index=self.ignore_index) + + # Calculate Dice loss + y_1h = labels_to_onehot(y_true, num_classes=self.num_classes) + loss_dice = dice_loss_from_logits(logits, y_1h) + + # Combine losses + return self.alpha_ce * loss_ce + self.alpha_dice * loss_dice + + +# --------------------------- +# Running meters +# --------------------------- + + +@dataclass +class AvgMeter: + """ + A simple data class to track the running average of a metric. + """ + + total: float = 0.0 + count: int = 0 + + def update(self, val: float, n: int = 1) -> None: + """ + Updates the meter with a new value and count. + + Args: + val: The value to add (e.g., loss for a batch). + n: The number of items this value represents (e.g., batch size). + """ + self.total += float(val) * n + self.count += n + + @property + def avg(self) -> float: + """Calculates the current average.""" + return self.total / max(self.count, 1) # Avoid division by zero + + +# --------------------------- +# Checkpoint helpers +# --------------------------- + + +def save_checkpoint( + path: str, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer | None = None, + epoch: int | None = None, + extra: Dict | None = None, +) -> None: + """ + Saves a model checkpoint to a .pt file. + + Args: + path: The full path to save the checkpoint file (e.g., 'outputs/best.pt'). + model: The model to save (state_dict will be saved). + optimizer: An optional optimizer to save (state_dict will be saved). + epoch: An optional epoch number to save. + extra: An optional dictionary of extra data to save (e.g., val_loss). + """ + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + state = {"model": model.state_dict()} + if optimizer is not None: + state["optimizer"] = optimizer.state_dict() + if epoch is not None: + state["epoch"] = epoch + if extra is not None: + state["extra"] = extra + torch.save(state, path) + + +def load_checkpoint( + path: str, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer | None = None, + map_location: str | torch.device = "cpu", +) -> Dict: + """ + Loads a model checkpoint from a .pt file. + + Args: + path: The full path to the checkpoint file. + model: The model instance to load the state_dict into. + optimizer: An optional optimizer instance to load the state_dict into. + map_location: The device to load the checkpoint onto. + + Returns: + The full checkpoint dictionary that was loaded. + """ + ckpt = torch.load(path, map_location=map_location) + model.load_state_dict(ckpt["model"]) + if optimizer is not None and "optimizer" in ckpt: + optimizer.load_state_dict(ckpt["optimizer"]) + return ckpt