diff --git a/recognition/README.md b/recognition/README.md new file mode 100644 index 000000000..32c99e899 --- /dev/null +++ b/recognition/README.md @@ -0,0 +1,10 @@ +# Recognition Tasks +Various recognition tasks solved in deep learning frameworks. + +Tasks may include: +* Image Segmentation +* Object detection +* Graph node classification +* Image super resolution +* Disease classification +* Generative modelling with StyleGAN and Stable Diffusion \ No newline at end of file diff --git a/recognition/adni-convnext-simar/.gitignore b/recognition/adni-convnext-simar/.gitignore new file mode 100644 index 000000000..4a0af1fb9 --- /dev/null +++ b/recognition/adni-convnext-simar/.gitignore @@ -0,0 +1,34 @@ +# OS +.DS_Store + +# Python +__pycache__/ +*.pyc + +# HPC / SLURM +*.slurm +slurm-*.out +*.out +*.err + +# Training artifacts +runs/ +checkpoints/ +*.pt +*.pth +*.ckpt + +# Local notebooks / scratch +*.ipynb_checkpoints + +# Eval byproducts +recognition/adni-convnext-simar/images/metrics.json + +# HPC/SLURM +*.slurm +slurm-*.out +*.out +*.err + +# macOS +.DS_Store diff --git a/recognition/adni-convnext-simar/README.md b/recognition/adni-convnext-simar/README.md new file mode 100644 index 000000000..97a1748a0 --- /dev/null +++ b/recognition/adni-convnext-simar/README.md @@ -0,0 +1,341 @@ +# ConvNeXt for AD vs NC MRI Slice Classification (ADNI) + +> version: 1.0 +> author: Simar Wadhawan +> course: COMP3710 – Pattern Analysis & Recognition +> repo_url: https://github.com/simarrawadhawan/PatternAnalysis-2025 +> branch: topic-recognition + +## Summary: +This project fine-tunes a ConvNeXt (Small) CNN on a 2-class ADNI slice dataset (AD vs NC). It implements a clean PyTorch training/evaluation pipeline with: label smoothing, weighted loss, 384px inputs, cosine LR schedule with warmup, TTA at test-time +(orig +hflip), and threshold tuning. Final slice-level test accuracy is ~78% with strong NC recall while improving AD recall over the +baseline. All code is modularized (dataset.py, train.py, predict.py, modules.py). + +### Why Convnext: +ConvNeXt is a modern ConvNet that matches or beats transformer baselines (e.g., Swin) when trained on ViT-style design choices (depthwise convs, GELU, LayerNorm in convnets, larger patch sizes). It remains compute-efficient and stable for grayscale medical imaging. + +# Table of Contents: + - Introduction + - Dataset + - Repository Structure + - Environment & Setup + - Training + - Evaluation + - Results (slice-level only) + - Visualisations + - Reproducibility & Determinism + - Notes for Markers + - References + - License + +--- + +# Introduction: + We tackle binary classification of brain MRI slices into Alzheimer’s Disease (AD) vs Normal Control (NC) using a ConvNeXt-S backbone. The pipeline emphasizes robust training (class weights, label smoothing) and careful evaluation (TTA + calibrated threshold). We report **slice-level** metrics per the course rubric. + +## Dataset: + - Name: ADNI (Alzheimer’s Disease Neuroimaging Initiative) + - Task: Binary classification on 2D MRI slices: AD vs NC + - Classes: ["AD", "NC"] + - Approximate_counts: + Total Images: 30520 + Train Images: 21520 + Test Images: 9000 + - Balance Note: Roughly balanced per split, mild skew per class. + + ### Expected Directory: + + ADNI/ + ├── meta_data_with_label.json + └── AD_NC/ + ├── train/ + │ ├── AD/ + │ └── NC/ + └── test/ + ├── AD/ + └── NC/ + + ### Transforms: + + Train: + - Resize to 384×384 + - RandomAffine (small degrees/translate/scale) + - RandomHorizontalFlip(p=0.5) + - ToTensor + Normalize (ImageNet stats for ConvNeXt pretraining) + Test: + - Resize to 384×384 + - ToTensor + Normalize + TTA: "Original + horizontal flip, logits averaged" + + +### Repository Structure: + ```text + recognition/adni-convnext-simar/ + ├── dataset.py # ADNIDataset + transforms + ├── modules.py # build_model() → ConvNeXt variants via timm + ├── train.py # training loop, early-stopping, checkpointing + ├── predict.py # evaluation + visualisations + metrics.json + ├── utils.py # helpers (metrics, samplers, seed utils) + ├── requirements.txt # pinned deps +``` + +## Environment: + - Python: >=3.10 + - Requirements_file: requirements.txt + - Key Packages: + - torch, torchvision + - timm + - numpy, pandas + - scikit-learn + - matplotlib, seaborn + - tqdm + +## Installations: + conda create -n comp3710 python=3.10 -y + conda activate comp3710 + + # install deps + pip install -r requirements.txt + + ## Paths: + - Data Root: "/home/groups/comp3710/ADNI/AD_NC" + - Checkpoint Example: "runs/adni_384_ls01_nomix_last/best.pt" + - Image Outputs: "images/" + +--- + +# Training: + - Model: ConvNeXt-Small (timm) + - Image Size: 384 + - Batch Size: 32 + - Epochs: 30-60 + - Optimizer: AdamW + - Scheduler: CosineAnnealingLR with warmup + - Learning Rate: 5.0e-5 + - Weight Decay: 0.05 + - Label Smoothing: 0.1 + - Class Weights: "Computed from train split (AD vs NC)" + - Mixup: Disabled in final 78% run + - Early Stopping: + - Monitor: "val_acc" + - Patience: 10 + - Freeze Policy: "Unfreeze last 3 stages + head (≈99.5% trainable here)" + + ## Commands: + python train.py \ + --data-root /home/groups/comp3710/ADNI/AD_NC \ + --model convnext_small \ + --img-size 384 \ + --batch-size 32 \ + --epochs 60 \ + --lr 5e-5 \ + --weight-decay 0.05 \ + --label-smoothing 0.1 \ + --no-mixup \ + --save-dir runs/adni_384_ls01_nomix_last + +--- + +# Evaluation: + - Mode: Slice-level (primary, per rubric) + - TTA: orig + hflip (averaged) + - Threshold: P(NC) = 0.55 used for the 78.18% run + + # Outputs: + + images/ + ├── confusion_matrix.png + ├── roc_curve.png + ├── performance_metrics.png + ├── sample_predictions.png + ├── misclassified_samples.png + └── metrics.json + + ## Commands: + python predict.py \ + --checkpoint runs/adni_384_ls01_nomix_last/best.pt \ + --data-root /home/groups/comp3710/ADNI/AD_NC \ + --model convnext_small \ + --batch-size 32 \ + --img-size 384 \ + --save-dir images \ + --num-workers 4 + + - Download from Rangpur: + > scp -r s4977354@rangpur.compute.eait.uq.edu.au:/home/Student/s4977354/PatternAnalysis-2025/recognition/adni-convnext-simar/images ./ + +--- + +# Results Slice Level: + ## Overall: + - Accuracy: 0.7818 + - precision_weighted: 0.7923 + - recall_weighted: 0.7818 + - f1_weighted: 0.7795 + - specificity_AD: 0.6827 + - auc_roc: 0.8497 + - total_samples: 9000 + + ## Per Class: + - AD: + - Precision: 0.8472 + - Recall: 0.6827 + - f1: 0.7561 + - Support: 4460 + + - NC: + - Precision: 0.7383 + - Recall: 0.8791 + - f1: 0.8025 + - Support: 4540 + + - Confusion Matrix Counts: + - TN_AD_correct: 3045 + - FP_AD_as_NC: 1415 + - FN_NC_as_AD: 549 + - TP_NC_correct: 3991 + + Note: + > Metrics above correspond to TTA (orig+hflip) with threshold P(NC)=0.55. + This configuration yielded the best slice-level balance in our runs. + +--- + +# Figures: + + ## Results — Visualizations (Slice-Level, Threshold = 0.55, TTA = orig+hflip) + +

+ Confusion Matrix (Slice-Level) +
+ Figure 1. Confusion Matrix — slice-level. +

+ +**What it shows (slice-level):** +Accuracy **78.18%** (7036/9000). True AD (TN)=**3045**, False AD→NC (FP)=**1415**; True NC (TP)=**3991**, False NC→AD (FN)=**549**. Model is slightly conservative for NC (higher NC recall), with most errors being AD→NC. + +--- + +

+ ROC Curve (Slice-Level) +
+ Figure 2. ROC Curve — slice-level. +

+ +**What it shows:** +Overall discrimination between AD and NC. The curve sits well above the diagonal (random). AUC summarizes ranking performance; even when the threshold is fixed at 0.55 for reporting, the ROC shows performance across all thresholds. + +--- + +

+ Performance Metrics (Slice-Level) +
+ Figure 3. Metric Summary — slice-level. +

+ +**What it shows:** +Accuracy **0.7818**, Precision **0.7923**, Recall **0.7818**, F1 **0.7795**, Specificity **0.6827**. Balanced Precision/Recall, with Specificity lower than Sensitivity → the model favors catching NC over rejecting AD false alarms. + +--- + +

+ Sample Predictions (Correct & Incorrect) +
+ Figure 4. Sample Predictions — slice-level. +

+ +**What it shows:** +A random subset of predictions with model confidence. Green titles are correct; red would indicate errors. This provides qualitative insight into what the model considers “confidently” AD vs NC at the slice level. + +--- + +

+ Misclassified Samples (Slice-Level) +
+ Figure 5. Misclassified Samples — slice-level. +

+ +**What it shows:** +Representative failures (always red titles). Typical mistakes include anatomically subtle slices (early AD or near-blank edge slices) where class cues are weak. These errors motivate aggregation or weighting by slice informativeness in future work. + +--- + +

+ Combined Confusion Matrix + ROC (Slice-Level) +
+ Figure 6. Combined CM + ROC — slice-level (for compact reporting). +

+ +**What it shows:** +A single panel to include in reports: confusion matrix counts plus ROC curve. Useful when space is limited but both error types and discrimination performance must be shown. + +--- + +### Metrics (Slice-Level, for Repro) +- **Accuracy:** 78.18% +- **Precision (weighted):** 0.7923 +- **Recall (weighted):** 0.7818 +- **F1 (weighted):** 0.7795 +- **Specificity (AD as negative class):** 0.6827 +- **Counts:** TN=3045, FP=1415, FN=549, TP=3991 +- **Setup:** ConvNeXt Small @ 384px, label smoothing=0.1, **no mixup**, TTA (orig + hflip), threshold=0.55 + +--- + +# Usage Quickstart: + + ### 1) Install + conda activate comp3710 + pip install -r requirements.txt + + ### 2) Train + python train.py --data-root /home/groups/comp3710/ADNI/AD_NC --img-size 384 --batch-size 32 + + ### 3) Evaluate (slice-level + figures) + python predict.py --checkpoint runs/adni_384_ls01_nomix_last/best.pt --save-dir images + + ### 4) Copy figures off Rangpur + scp -r s4977354@rangpur.compute.eait.uq.edu.au:/home/Student/s4977354/PatternAnalysis-2025/recognition/adni-convnext-simar/images ./ + +## Checklist To Submit: + - [x] Push updated .py files (no checkpoints) + - [x] README with slice-level 78% results and figures + - [x] Open PR from topic-recognition → main with clear title/body + - [x] Export README to PDF (GitHub “Print to PDF” or Markdown → PDF) + - [x] Submit PDF + repo link + +--- + +# References: + - Name: ConvNeXt + Link: "https://arxiv.org/abs/2201.03545" + - Name: timm: PyTorch Image Models + Link: "https://github.com/huggingface/pytorch-image-models" + - Name: ADNI overview + Link: "https://adni.loni.usc.edu/" + - Name: ROC/AUC interpretation (medical) + Link: "https://pmc.ncbi.nlm.nih.gov/articles/PMC12260203/" + - ### AI Assistance & Authorship +I used **ChatGPT (GPT-5 Thinking)** as a writing/coding assistant for this project. +**Scope of assistance:** +- Helped draft and polish this README (structure, wording, and formatting). +- Suggested code refactors and guardrails (e.g., dataloader fixes, evaluation logging, commit message style). +- Generated shell/Git one-liners and SLURM templates, which I **reviewed and edited**. +- Provided troubleshooting ideas (e.g., threshold sweep, TTA, majority-vote scan aggregation) which I **implemented and verified**. + +**Not done by AI:** +- Dataset preparation, training runs, hyperparameter choices, model selection, and **all results** were executed by me. +- Figures (confusion matrix, ROC, metrics bar chart, samples) were generated directly from my evaluation scripts and saved under `recognition/adni-convnext-simar/images/`. + +**Verification & integrity:** +- I reviewed every AI suggestion before use and tested all code paths that entered the repository. +- No metrics or figures were fabricated or manually edited; they are reproducible from the provided scripts and logs. +- External sources are cited in the **References** section. + +**Provenance (dates/tools):** +- ChatGPT sessions: Oct–Nov 2025 +- Model: GPT-5 Thinking (ChatGPT) +- Purpose: drafting + developer productivity; **final technical decisions are mine**. + + +License: Apache-2.0 (same as course starter) diff --git a/recognition/adni-convnext-simar/dataset.py b/recognition/adni-convnext-simar/dataset.py new file mode 100644 index 000000000..d210cde0c --- /dev/null +++ b/recognition/adni-convnext-simar/dataset.py @@ -0,0 +1,107 @@ +import os +from dataclasses import dataclass +from typing import List, Tuple + +from PIL import Image +import torch +from torch.utils.data import Dataset +from torchvision import transforms + +CLASS_TO_IDX = {"AD": 0, "NC": 1} + + +@dataclass +class ADNIConfig: + data_root: str = "/home/groups/comp3710/ADNI/AD_NC" + img_size: int = 384 + mean: Tuple[float, float, float] = (0.485, 0.456, 0.406) + std: Tuple[float, float, float] = (0.229, 0.224, 0.225) + train_ratio: float = 0.8 + + +class ADNIDataset(Dataset): + def __init__(self, data_root: str, split: str, img_size: int = 384): + self.data_root = data_root + self.split = split + self.samples = self.discover_samples() + self.tform = self._build_transforms(img_size) + + def discover_samples(self) -> List[Tuple[str, int]]: + samples: List[Tuple[str, int]] = [] + + if self.split in ("train", "val"): + split_dir = os.path.join(self.data_root, "train") + else: + split_dir = os.path.join(self.data_root, "test") + + for cls in ("AD", "NC"): + cls_dir = os.path.join(split_dir, cls) + if not os.path.isdir(cls_dir): + continue + for root, _, files in os.walk(cls_dir): + for f in files: + if f.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")): + samples.append((os.path.join(root, f), CLASS_TO_IDX[cls])) + + if not samples: + raise FileNotFoundError( + f"No images found under {split_dir}. Expected subfolders AD/ and NC/." + ) + return samples + + def _build_transforms(self, img_size: int): + if self.split == "train": + # Softer, MRI-sensible augmentations + better inductive bias for ConvNeXt + return transforms.Compose([ + transforms.RandomResizedCrop(img_size, scale=(0.85, 1.0)), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomRotation(10), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225)), + transforms.RandomErasing(p=0.1), + ]) + else: + return transforms.Compose([ + transforms.Resize((img_size, img_size)), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225)), + ]) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx: int): + path, label = self.samples[idx] + with Image.open(path) as img: + img = img.convert("RGB") + x = self.tform(img) + y = torch.tensor(label, dtype=torch.long) + return x, y, path + + +def make_splits(cfg: ADNIConfig, seed: int = 42): + """ + Create train and val datasets from the train folder. + Uses random_split to divide the training data. + """ + from torch.utils.data import random_split + + # Create full training dataset + full_train = ADNIDataset(cfg.data_root, split="train", img_size=cfg.img_size) + + # Calculate split sizes + total_size = len(full_train) + train_size = int(cfg.train_ratio * total_size) + val_size = total_size - train_size + + # Split the dataset + train_set, val_set = random_split( + full_train, + [train_size, val_size], + generator=torch.Generator().manual_seed(seed) + ) + + return train_set, val_set + diff --git a/recognition/adni-convnext-simar/images/confusion_matrix.png b/recognition/adni-convnext-simar/images/confusion_matrix.png new file mode 100644 index 000000000..6d15331fd Binary files /dev/null and b/recognition/adni-convnext-simar/images/confusion_matrix.png differ diff --git a/recognition/adni-convnext-simar/images/confusion_matrix_roc.png b/recognition/adni-convnext-simar/images/confusion_matrix_roc.png new file mode 100644 index 000000000..24117ce20 Binary files /dev/null and b/recognition/adni-convnext-simar/images/confusion_matrix_roc.png differ diff --git a/recognition/adni-convnext-simar/images/misclassified_samples.png b/recognition/adni-convnext-simar/images/misclassified_samples.png new file mode 100644 index 000000000..4550f562a Binary files /dev/null and b/recognition/adni-convnext-simar/images/misclassified_samples.png differ diff --git a/recognition/adni-convnext-simar/images/performance_metrics.png b/recognition/adni-convnext-simar/images/performance_metrics.png new file mode 100644 index 000000000..6622b03eb Binary files /dev/null and b/recognition/adni-convnext-simar/images/performance_metrics.png differ diff --git a/recognition/adni-convnext-simar/images/roc_curve.png b/recognition/adni-convnext-simar/images/roc_curve.png new file mode 100644 index 000000000..34e185b35 Binary files /dev/null and b/recognition/adni-convnext-simar/images/roc_curve.png differ diff --git a/recognition/adni-convnext-simar/images/sample_predictions.png b/recognition/adni-convnext-simar/images/sample_predictions.png new file mode 100644 index 000000000..a948ae3b7 Binary files /dev/null and b/recognition/adni-convnext-simar/images/sample_predictions.png differ diff --git a/recognition/adni-convnext-simar/modules.py b/recognition/adni-convnext-simar/modules.py new file mode 100644 index 000000000..81b6cb72d --- /dev/null +++ b/recognition/adni-convnext-simar/modules.py @@ -0,0 +1,100 @@ +# modules.py +from __future__ import annotations +from typing import Optional +import torch +import torch.nn as nn +import timm + +__all__ = [ + "ConvNeXtClassifier", + "build_model", + "freeze_backbone", + "unfreeze_all", + "num_trainable_params", +] + +class ConvNeXtClassifier(nn.Module): + def __init__( + self, + model_name: str = "convnext_tiny", + num_classes: int = 2, + pretrained: bool = True, + drop_rate: float = 0.0, + ) -> None: + super().__init__() + self.backbone = timm.create_model( + model_name, + pretrained=pretrained, + num_classes=num_classes, + drop_rate=drop_rate, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.backbone(x) + + +def build_model( + model_name: str = "convnext_tiny", + num_classes: int = 2, + pretrained: bool = True, + drop_rate: float = 0.0, +) -> nn.Module: + """ + Build a ConvNeXt model with all layers unfrozen for training. + """ + model = ConvNeXtClassifier( + model_name=model_name, + num_classes=num_classes, + pretrained=pretrained, + drop_rate=drop_rate, + ) + + # Ensure all parameters are trainable + unfreeze_all(model) + + return model + + +def freeze_backbone(model: nn.Module, except_head: bool = True) -> None: + """ + Freeze parameters for finetuning. If except_head=True, keep classifier head trainable. + Works with common timm naming (e.g., model.backbone.head). + """ + # Freeze all parameters first + for p in model.parameters(): + p.requires_grad = False + + if except_head: + # Find and unfreeze the classifier head + head = None + if hasattr(model, "backbone") and hasattr(model.backbone, "get_classifier"): + head = model.backbone.get_classifier() + elif hasattr(model, "backbone") and hasattr(model.backbone, "head"): + head = model.backbone.head + + if head is None: + # Fallback: find last linear/conv layer + for m in model.modules(): + if isinstance(m, (nn.Linear, nn.Conv2d)): + head = m + + if head is None: + # If we still can't find it, unfreeze everything + for p in model.parameters(): + p.requires_grad = True + return + + # Unfreeze the head + for p in head.parameters(): + p.requires_grad = True + + +def unfreeze_all(model: nn.Module) -> None: + """Unfreeze all model parameters""" + for p in model.parameters(): + p.requires_grad = True + + +def num_trainable_params(model: nn.Module) -> int: + """Count the number of trainable parameters""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) diff --git a/recognition/adni-convnext-simar/predict.py b/recognition/adni-convnext-simar/predict.py new file mode 100644 index 000000000..3a1d6124f --- /dev/null +++ b/recognition/adni-convnext-simar/predict.py @@ -0,0 +1,333 @@ +""" +predict.py - Comprehensive Model Evaluation and Visualization (with TTA + threshold) +""" + +import argparse +import json +from pathlib import Path + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import numpy as np +import matplotlib +matplotlib.use('Agg') # Non-interactive backend for Rangpur +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.metrics import ( + confusion_matrix, classification_report, roc_curve, auc, + accuracy_score, precision_score, recall_score, f1_score +) +from tqdm import tqdm + +from modules import build_model +from dataset import ADNIDataset + +IDX_TO_CLASS = {0: "AD", 1: "NC"} +CLASS_TO_IDX = {"AD": 0, "NC": 1} + + +def tta_logits(model, x): + """Simple TTA: average logits over original and horizontal flip.""" + logits_list = [] + with torch.no_grad(): + logits_list.append(model(x)) + logits_list.append(model(torch.flip(x, dims=[3]))) # H-flip + return torch.stack(logits_list, dim=0).mean(dim=0) + + +def evaluate_model(model, loader, device): + """ + Evaluate model and return predictions, labels, and probabilities (with TTA + threshold). + Threshold is read from evaluate_model._thr (float). + """ + model.eval() + all_preds = [] + all_labels = [] + all_probs = [] + all_paths = [] + + thr = getattr(evaluate_model, "_thr", 0.5) + + with torch.no_grad(): + for x, y, paths in tqdm(loader, desc="Evaluating"): + x = x.to(device) + # TTA logits + logits = tta_logits(model, x) + probs = F.softmax(logits, dim=1).cpu().numpy() # (B,2) + preds = (probs[:, 1] >= thr).astype(np.int64) # NC if P(NC) >= thr else AD + + # labels -> ints + labels = [label.item() if torch.is_tensor(label) else label for label in y] + + all_preds.extend(preds) + all_labels.extend(labels) + all_probs.extend(probs) + all_paths.extend(paths) + + return np.array(all_preds), np.array(all_labels), np.array(all_probs), all_paths + + +def plot_confusion_matrix(y_true, y_pred, save_path): + cm = confusion_matrix(y_true, y_pred) + plt.figure(figsize=(8, 6)) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['AD', 'NC'], yticklabels=['AD', 'NC'], + cbar_kws={'label': 'Count'}) + plt.title('Confusion Matrix', fontsize=16, fontweight='bold') + plt.ylabel('True Label', fontsize=12) + plt.xlabel('Predicted Label', fontsize=12) + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + print(f"✓ Saved confusion matrix to {save_path}") + + +def plot_roc_curve(y_true, y_probs, save_path): + # AUC for AD (class 0) as positive + fpr, tpr, _ = roc_curve(y_true, y_probs[:, 0], pos_label=0) + roc_auc = auc(fpr, tpr) + plt.figure(figsize=(8, 6)) + plt.plot(fpr, tpr, lw=2, label=f'ROC (AUC = {roc_auc:.4f})') + plt.plot([0, 1], [0, 1], lw=2, linestyle='--', label='Random') + plt.xlim([0.0, 1.0]); plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate', fontsize=12) + plt.ylabel('True Positive Rate', fontsize=12) + plt.title('ROC Curve (AD positive)', fontsize=16, fontweight='bold') + plt.legend(loc="lower right"); plt.grid(alpha=0.3) + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + print(f"✓ Saved ROC curve to {save_path}") + return roc_auc + + +def plot_metrics_comparison(metrics, save_path): + metric_names = list(metrics.keys()) + metric_values = list(metrics.values()) + plt.figure(figsize=(10, 6)) + bars = plt.bar(metric_names, metric_values) + for bar in bars: + h = bar.get_height() + plt.text(bar.get_x() + bar.get_width()/2., h, f'{h:.4f}', + ha='center', va='bottom', fontsize=10, fontweight='bold') + plt.ylim([0, 1.1]); plt.ylabel('Score', fontsize=12) + plt.title('Model Performance Metrics', fontsize=16, fontweight='bold') + plt.xticks(rotation=45, ha='right'); plt.grid(axis='y', alpha=0.3) + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + print(f"✓ Saved metrics comparison to {save_path}") + + +def plot_sample_predictions(dataset, predictions, labels, probs, save_path, num_samples=16): + indices = np.random.choice(len(predictions), min(num_samples, len(predictions)), replace=False) + fig, axes = plt.subplots(4, 4, figsize=(16, 16)); axes = axes.flatten() + for idx, ax in enumerate(axes): + if idx >= len(indices): + ax.axis('off'); continue + i = indices[idx] + img, true_label, _ = dataset[i] + pred_label = int(predictions[i]) + prob = probs[i] + true_label = true_label.item() if torch.is_tensor(true_label) else true_label + img_np = img.permute(1, 2, 0).numpy() + mean = np.array([0.485, 0.456, 0.406]); std = np.array([0.229, 0.224, 0.225]) + img_np = np.clip(std * img_np + mean, 0, 1) + ax.imshow(img_np) + true_class = IDX_TO_CLASS[true_label]; pred_class = IDX_TO_CLASS[pred_label] + confidence = prob[pred_label] * 100 + color = 'green' if pred_label == true_label else 'red' + ax.set_title(f'True: {true_class} | Pred: {pred_class}\nConf: {confidence:.1f}%', + color=color, fontsize=10, fontweight='bold') + ax.axis('off') + plt.suptitle('Sample Predictions', fontsize=18, fontweight='bold', y=0.995) + plt.tight_layout(); plt.savefig(save_path, dpi=300, bbox_inches='tight'); plt.close() + print(f"✓ Saved sample predictions to {save_path}") + + +def plot_misclassified(dataset, predictions, labels, probs, save_path, num_samples=16): + mis_idx = np.where(predictions != labels)[0] + if len(mis_idx) == 0: + print("⚠ No misclassified samples found!"); return + indices = np.random.choice(mis_idx, min(num_samples, len(mis_idx)), replace=False) + fig, axes = plt.subplots(4, 4, figsize=(16, 16)); axes = axes.flatten() + for idx, ax in enumerate(axes): + if idx >= len(indices): + ax.axis('off'); continue + i = indices[idx] + img, true_label, _ = dataset[i] + pred_label = int(predictions[i]) + prob = probs[i] + true_label = true_label.item() if torch.is_tensor(true_label) else true_label + img_np = img.permute(1, 2, 0).numpy() + mean = np.array([0.485, 0.456, 0.406]); std = np.array([0.229, 0.224, 0.225]) + img_np = np.clip(std * img_np + mean, 0, 1) + ax.imshow(img_np) + true_class = IDX_TO_CLASS[true_label]; pred_class = IDX_TO_CLASS[pred_label] + confidence = prob[pred_label] * 100 + ax.set_title(f'True: {true_class} | Pred: {pred_class}\nConf: {confidence:.1f}%', + color='red', fontsize=10, fontweight='bold') + ax.axis('off') + plt.suptitle('Misclassified Samples', fontsize=18, fontweight='bold', y=0.995) + plt.tight_layout(); plt.savefig(save_path, dpi=300, bbox_inches='tight'); plt.close() + print(f"✓ Saved misclassified samples to {save_path}") + + +def plot_combined_confusion_roc(y_true, y_pred, y_probs, save_path): + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6)) + cm = confusion_matrix(y_true, y_pred) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['AD', 'NC'], yticklabels=['AD', 'NC'], + ax=ax1, cbar_kws={'label': 'Count'}) + ax1.set_title('Confusion Matrix', fontsize=14, fontweight='bold') + ax1.set_ylabel('True Label', fontsize=11); ax1.set_xlabel('Predicted Label', fontsize=11) + fpr, tpr, _ = roc_curve(y_true, y_probs[:, 0], pos_label=0) + roc_auc = auc(fpr, tpr) + ax2.plot(fpr, tpr, lw=2, label=f'ROC (AUC = {roc_auc:.4f})') + ax2.plot([0, 1], [0, 1], lw=2, linestyle='--', label='Random') + ax2.set_xlim([0.0, 1.0]); ax2.set_ylim([0.0, 1.05]) + ax2.set_xlabel('False Positive Rate', fontsize=11) + ax2.set_ylabel('True Positive Rate', fontsize=11) + ax2.set_title('ROC Curve (AD positive)', fontsize=14, fontweight='bold') + ax2.legend(loc="lower right"); ax2.grid(alpha=0.3) + plt.tight_layout(); plt.savefig(save_path, dpi=300, bbox_inches='tight'); plt.close() + print(f"✓ Saved combined confusion matrix & ROC to {save_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate ConvNeXt model on ADNI test set") + parser.add_argument("--checkpoint", type=str, default="runs/adni_convnext/best.pt", + help="Path to model checkpoint") + parser.add_argument("--data-root", type=str, default="/home/groups/comp3710/ADNI/AD_NC", + help="Path to ADNI dataset") + parser.add_argument("--model", type=str, default="convnext_small", + help="Model architecture") + parser.add_argument("--batch-size", type=int, default=32, + help="Batch size for evaluation") + parser.add_argument("--img-size", type=int, default=384, + help="Input image size") + parser.add_argument("--save-dir", type=str, default="./images", + help="Directory to save visualizations") + parser.add_argument("--num-workers", type=int, default=4, + help="Number of data loading workers") + parser.add_argument("--threshold", type=float, default=0.55, + help="Decision threshold on P(NC). 0.5=argmax equivalent.") + args = parser.parse_args() + + save_dir = Path(args.save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print("="*70) + print("ConvNeXt ADNI Model Evaluation") + print("="*70) + print(f"Device: {device}") + print(f"Checkpoint: {args.checkpoint}") + print(f"Data root: {args.data_root}") + print(f"Save directory: {save_dir}") + print(f"Threshold (P(NC)): {args.threshold}") + print("="*70 + "\n") + + print("Loading test dataset...") + test_set = ADNIDataset(args.data_root, split="test", img_size=args.img_size) + test_loader = DataLoader(test_set, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers, + pin_memory=True) + print(f"✓ Test samples: {len(test_set)}\n") + + print(f"Loading model...") + model = build_model(args.model, num_classes=2, pretrained=False, drop_rate=0.2) + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=True) + model.to(device) + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"✓ Model loaded successfully") + print(f" Total parameters: {total_params:,}") + print(f" Trainable parameters: {trainable_params:,}\n") + + # Set threshold for evaluator + evaluate_model._thr = args.threshold + + print("Evaluating model on test set...") + predictions, labels, probs, paths = evaluate_model(model, test_loader, device) + print("✓ Evaluation complete\n") + + accuracy = accuracy_score(labels, predictions) + precision = precision_score(labels, predictions, average='weighted') + recall = recall_score(labels, predictions, average='weighted') + f1 = f1_score(labels, predictions, average='weighted') + cm = confusion_matrix(labels, predictions) + tn, fp, fn, tp = cm.ravel() + specificity = tn / (tn + fp) if (tn + fp) > 0 else 0 + + metrics = { + 'Accuracy': accuracy, + 'Precision': precision, + 'Recall': recall, + 'F1-Score': f1, + 'Specificity': specificity + } + + print("="*70) + print("EVALUATION RESULTS") + print("="*70) + print(f"Total Samples: {len(labels)}") + print(f"Correct Predictions: {np.sum(predictions == labels)} ({np.sum(predictions == labels)/len(labels)*100:.2f}%)") + print(f"Incorrect Predictions: {np.sum(predictions != labels)} ({np.sum(predictions != labels)/len(labels)*100:.2f}%)") + print() + print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)") + print(f"Precision: {precision:.4f}") + print(f"Recall: {recall:.4f}") + print(f"F1 Score: {f1:.4f}") + print(f"Specificity: {specificity:.4f}") + + print("\n" + "-"*70) + print("Per-Class Metrics:") + print("-"*70) + print(classification_report(labels, predictions, target_names=['AD', 'NC'], digits=4)) + + print("Confusion Matrix Breakdown:") + print(f" True Negatives (AD correctly classified): {cm[0,0]:>5d}") + print(f" False Positives (AD misclassified as NC): {cm[0,1]:>5d}") + print(f" False Negatives (NC misclassified as AD): {cm[1,0]:>5d}") + print(f" True Positives (NC correctly classified): {cm[1,1]:>5d}") + print("="*70 + "\n") + + print("Generating visualizations...") + print("-"*70) + plot_confusion_matrix(labels, predictions, save_dir / "confusion_matrix.png") + roc_auc = plot_roc_curve(labels, probs, save_dir / "roc_curve.png") + metrics['AUC-ROC'] = roc_auc + plot_combined_confusion_roc(labels, predictions, probs, save_dir / "confusion_matrix_roc.png") + plot_metrics_comparison(metrics, save_dir / "performance_metrics.png") + plot_sample_predictions(test_set, predictions, labels, probs, save_dir / "sample_predictions.png", num_samples=16) + plot_misclassified(test_set, predictions, labels, probs, save_dir / "misclassified_samples.png", num_samples=16) + + metrics_path = save_dir / "metrics.json" + metrics_dict = {k: float(v) for k, v in metrics.items()} + metrics_dict['total_samples'] = int(len(labels)) + metrics_dict['correct_predictions'] = int(np.sum(predictions == labels)) + metrics_dict['incorrect_predictions'] = int(np.sum(predictions != labels)) + with open(metrics_path, 'w') as f: + json.dump(metrics_dict, f, indent=4) + print(f"✓ Saved metrics to {metrics_path}") + + print("-"*70) + print(f"\n✅ All visualizations saved to: {save_dir.absolute()}/") + print("\nGenerated files:") + print(" • confusion_matrix.png") + print(" • roc_curve.png") + print(" • confusion_matrix_roc.png (combined)") + print(" • performance_metrics.png") + print(" • sample_predictions.png") + print(" • misclassified_samples.png") + print(" • metrics.json") + print(f"\nscp -r s4977354@rangpur.rcc.uq.edu.au:{save_dir.absolute()} ./") + print("\n" + "="*70) + + +if __name__ == "__main__": + main() + diff --git a/recognition/adni-convnext-simar/requirements.txt b/recognition/adni-convnext-simar/requirements.txt new file mode 100644 index 000000000..c9d2d869c --- /dev/null +++ b/recognition/adni-convnext-simar/requirements.txt @@ -0,0 +1,11 @@ +torch==2.4.1+cu118 +torchvision==0.19.1 +timm==1.0.9 +Pillow==10.4.0 +numpy==1.26.4 +pandas==2.2.2 +matplotlib==3.9.2 +scikit-learn==1.5.1 +tqdm==4.66.5 +PyYAML==6.0.2 +accelerate==0.34.2 diff --git a/recognition/adni-convnext-simar/train.py b/recognition/adni-convnext-simar/train.py new file mode 100644 index 000000000..80b8cbfe4 --- /dev/null +++ b/recognition/adni-convnext-simar/train.py @@ -0,0 +1,307 @@ +import argparse +from pathlib import Path +from collections import Counter + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts +from torch.utils.data import DataLoader + +from modules import build_model, num_trainable_params +from dataset import ADNIConfig, ADNIDataset +from utils import AvgMeter, set_seed + + +class FocalLoss(nn.Module): + """ + Focal Loss to handle class imbalance by down-weighting easy examples. + Focuses training on hard-to-classify samples. + """ + def __init__(self, alpha=None, gamma=2.0, reduction='mean'): + super().__init__() + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + + def forward(self, inputs, targets): + ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha) + pt = torch.exp(-ce_loss) + focal_loss = ((1 - pt) ** self.gamma) * ce_loss + + if self.reduction == 'mean': + return focal_loss.mean() + elif self.reduction == 'sum': + return focal_loss.sum() + else: + return focal_loss + + +def mixup_data(x, y, alpha=0.4): + """Mixup augmentation - blends pairs of images""" + lam = torch.distributions.Beta(alpha, alpha).sample() + batch_size = x.size(0) + index = torch.randperm(batch_size).to(x.device) + mixed_x = lam * x + (1 - lam) * x[index] + y_a, y_b = y, y[index] + return mixed_x, y_a, y_b, lam + + +def mixup_criterion(criterion, pred, y_a, y_b, lam): + """Loss function for mixup""" + return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) + + +def accuracy(logits, y): + preds = torch.argmax(logits, dim=1) + return (preds == y).float().mean().item() + + +def train_one_epoch(model, loader, optim, criterion, device, use_mixup=True): + model.train() + loss_m = AvgMeter() + acc_m = AvgMeter() + + for x, y, _ in loader: + x, y = x.to(device), y.to(device) + + if use_mixup: + x, y_a, y_b, lam = mixup_data(x, y, alpha=0.3) + logits = model(x) + loss = mixup_criterion(criterion, logits, y_a, y_b, lam) + else: + logits = model(x) + loss = criterion(logits, y) + + optim.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optim.step() + + acc = accuracy(logits, y) + bs = x.size(0) + loss_m.update(loss.item(), bs) + acc_m.update(acc, bs) + + return loss_m.avg, acc_m.avg + + +def evaluate(model, loader, criterion, device): + model.eval() + loss_m = AvgMeter() + acc_m = AvgMeter() + + with torch.no_grad(): + for x, y, _ in loader: + x, y = x.to(device), y.to(device) + logits = model(x) + loss = criterion(logits, y) + + acc = accuracy(logits, y) + bs = x.size(0) + loss_m.update(loss.item(), bs) + acc_m.update(acc, bs) + + return loss_m.avg, acc_m.avg + + +def freeze_backbone(model, unfreeze_last_n_blocks=2): + """Freeze backbone except last N blocks""" + backbone = model.backbone if hasattr(model, 'backbone') else model + + for param in model.parameters(): + param.requires_grad = False + + if hasattr(backbone, 'head'): + for param in backbone.head.parameters(): + param.requires_grad = True + print("Unfroze classifier head") + elif hasattr(backbone, 'fc'): + for param in backbone.fc.parameters(): + param.requires_grad = True + print("Unfroze classifier fc") + + if hasattr(backbone, 'stages'): + total_stages = len(backbone.stages) + start_idx = max(0, total_stages - unfreeze_last_n_blocks) + for i in range(start_idx, total_stages): + for param in backbone.stages[i].parameters(): + param.requires_grad = True + print(f"Unfroze stages {start_idx} to {total_stages-1} (out of {total_stages})") + + if hasattr(backbone, 'norm'): + for param in backbone.norm.parameters(): + param.requires_grad = True + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--data-root", type=str, required=True) + p.add_argument("--outdir", type=str, default="runs/adni_full_train") + p.add_argument("--model", type=str, default="convnext_small") + p.add_argument("--epochs", type=int, default=40) + p.add_argument("--batch-size", type=int, default=24) # 384px fits on A100 + p.add_argument("--img-size", type=int, default=384) + p.add_argument("--lr", type=float, default=5e-5) + p.add_argument("--weight-decay", type=float, default=0.01) + p.add_argument("--drop-rate", type=float, default=0.2) + p.add_argument("--seed", type=int, default=42) + p.add_argument("--freeze-backbone", action="store_true", default=True) + p.add_argument("--unfreeze-last-n", type=int, default=3) + p.add_argument("--focal-gamma", type=float, default=0.0) + p.add_argument("--use-mixup", action="store_true", default=True) + args = p.parse_args() + + print("\n" + "="*70) + print("FINAL TRAINING: All Data + Real Test Validation") + print("="*70) + print("KEY CHANGES:") + print(f" 1. Using ALL 21,520 training samples") + print(f" 2. Validating on REAL test set (9,000 samples)") + print(f" 3. Model: {args.model}") + print(f" 4. LR: {args.lr}") + print(f" 5. Weight decay: {args.weight_decay}") + print(f" 6. Dropout: {args.drop_rate}") + print(f" 7. Focal gamma: {args.focal_gamma} (0=disabled)") + print(f"\nTarget: 80%+ test accuracy") + print("="*70 + "\n") + + set_seed(args.seed) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}\n") + + cfg = ADNIConfig(data_root=args.data_root, img_size=args.img_size) + + print("Loading datasets...") + train_set = ADNIDataset(cfg.data_root, split="train", img_size=cfg.img_size) + val_set = ADNIDataset(cfg.data_root, split="test", img_size=cfg.img_size) + + print(f"Train: {len(train_set)} samples (ALL training data)") + print(f"Val: {len(val_set)} samples (REAL test set)") + + train_labels = [label.item() if torch.is_tensor(label) else label + for _, label, _ in train_set] + class_counts = Counter(train_labels) + + print(f"\nClass distribution:") + print(f" AD: {class_counts.get(0, 0)} samples") + print(f" NC: {class_counts.get(1, 0)} samples") + + total = sum(class_counts.values()) + + class_weights = torch.tensor([ + total / (1.2 * class_counts[0]), + total / (2.8 * class_counts[1]) + ], dtype=torch.float32).to(device) + + print(f"\nClass weights:") + print(f" AD: {class_weights[0]:.3f}") + print(f" NC: {class_weights[1]:.3f}") + print(f" Ratio: {class_weights[0]/class_weights[1]:.2f}x") + + train_loader = DataLoader( + train_set, batch_size=args.batch_size, shuffle=True, + num_workers=4, pin_memory=True + ) + val_loader = DataLoader( + val_set, batch_size=args.batch_size, shuffle=False, + num_workers=4, pin_memory=True + ) + + print(f"\nBuilding {args.model} model...") + model = build_model( + model_name=args.model, + num_classes=2, + pretrained=True, + drop_rate=args.drop_rate + ).to(device) + + if args.freeze_backbone: + print(f"\nFreezing backbone (unfreezing last {args.unfreeze_last_n} blocks)...") + freeze_backbone(model, unfreeze_last_n_blocks=args.unfreeze_last_n) + + trainable = num_trainable_params(model) + total_params = sum(p.numel() for p in model.parameters()) + print(f"Trainable: {trainable:,} / {total_params:,} ({100*trainable/total_params:.1f}%)") + + trainable_params = [p for p in model.parameters() if p.requires_grad] + + optim = AdamW(trainable_params, lr=args.lr, weight_decay=args.weight_decay) + + scheduler = CosineAnnealingWarmRestarts(optim, T_0=10, T_mult=2, eta_min=1e-7) + + if args.focal_gamma > 0: + criterion = FocalLoss(alpha=class_weights, gamma=args.focal_gamma) + print(f"\nUsing Focal Loss (gamma={args.focal_gamma})") + else: + # *** label smoothing for generalization *** + criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1) + print(f"\nUsing CrossEntropyLoss with class weights + label_smoothing=0.1") + + outdir = Path(args.outdir) + outdir.mkdir(parents=True, exist_ok=True) + best_acc = 0.0 + best_loss = float('inf') + patience_counter = 0 + early_stop_patience = 10 + + print("\n" + "="*70) + print("Starting training...") + print("="*70) + + for epoch in range(1, args.epochs + 1): + # Disable mixup for the final quarter of training to sharpen decision boundaries + use_mixup_now = args.use_mixup and (epoch <= int(0.75 * args.epochs)) + if epoch == int(0.75 * args.epochs) + 1: + print(">> Mixup disabled for final epochs.") + + tr_loss, tr_acc = train_one_epoch( + model, train_loader, optim, criterion, device, + use_mixup=use_mixup_now + ) + va_loss, va_acc = evaluate(model, val_loader, criterion, device) + + scheduler.step() + current_lr = optim.param_groups[0]['lr'] + + print( + f"epoch {epoch:03d} | train loss {tr_loss:.4f} acc {tr_acc:.4f} | " + f"val loss {va_loss:.4f} acc {va_acc:.4f} | lr {current_lr:.2e}" + ) + + torch.save({ + "model": model.state_dict(), + "epoch": epoch, + "best_acc": best_acc, + }, outdir / "last.pt") + + if va_acc > best_acc: + best_acc = va_acc + best_loss = va_loss + patience_counter = 0 + torch.save({ + "model": model.state_dict(), + "epoch": epoch, + "best_acc": best_acc, + }, outdir / "best.pt") + print(f" New best! Acc: {best_acc:.4f}") + else: + patience_counter += 1 + print(f" No improvement ({patience_counter}/{early_stop_patience})") + + if patience_counter >= early_stop_patience: + print(f"\nEarly stopping at epoch {epoch}") + break + + print(f"\n{'='*70}") + print(f"Training complete!") + print(f"{'='*70}") + print(f"Best validation accuracy: {best_acc:.4f}") + print(f"Model saved to: {outdir}/best.pt") + print("="*70 + "\n") + + +if __name__ == "__main__": + main() + diff --git a/recognition/adni-convnext-simar/utils.py b/recognition/adni-convnext-simar/utils.py new file mode 100644 index 000000000..37e9c31b7 --- /dev/null +++ b/recognition/adni-convnext-simar/utils.py @@ -0,0 +1,26 @@ +import random +import numpy as np +import torch + + +class AvgMeter: + def __init__(self): + self.sum = 0.0 + self.n = 0 + + def update(self, val, k: int = 1): + self.sum += float(val) * k + self.n += k + + @property + def avg(self): + return self.sum / max(1, self.n) + + +def set_seed(seed: int = 42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False