diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..7373f0930 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +data/ +checkpoints/ +.venv/ +__pycache__/ +runs/ \ No newline at end of file diff --git a/recognition/README.md b/recognition/README.md deleted file mode 100644 index 32c99e899..000000000 --- a/recognition/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Recognition Tasks -Various recognition tasks solved in deep learning frameworks. - -Tasks may include: -* Image Segmentation -* Object detection -* Graph node classification -* Image super resolution -* Disease classification -* Generative modelling with StyleGAN and Stable Diffusion \ No newline at end of file diff --git a/recognition/adni_convnext_47068591/README.md b/recognition/adni_convnext_47068591/README.md new file mode 100644 index 000000000..139f22a97 --- /dev/null +++ b/recognition/adni_convnext_47068591/README.md @@ -0,0 +1,135 @@ +# Alzheimer’s Disease Classification using ConvNeXt on ADNI Dataset + +*Author: Man Hin Lai (s47068591)* +*Course: COMP3710 Pattern Analysis* + +## 1. Project Introduction + +This project implements a convolutional neural network to classify **Alzheimer’s Disease (AD)** vs **Normal Control (NC)** from grayscale MRI slices . The task is framed as **binary image classification** using a custom, 1-channel variant of **ConvNeXt-Tiny** with training on 2D JPEG slices and subject-level aggregation at evaluation. The pipeline includes robust data augmentation, mixed-precision training, MixUp, weight-decoupled AdamW, warmup+cosine learning-rate scheduling, and early stopping on subject-level accuracy. Inference scripts produce both **slice-level** and **subject-level** metrics and figures suitable for report inclusion. + +## 2. Overview + +**Data → Model → Metrics** + +1. **Dataset (`dataset.py`)** + * Loads grayscale JPEGs from `ADNI/AD_NC/{train,test}/{AD,NC}/`. + * Optional cap on slices per subject (`LIMIT_SLICES_PER_SUBJECT`). + * **Train transforms** : resized crop, flip, affine, perspective, color jitter, Gaussian blur, RandomErasing, normalize. + * **Eval transforms** : resize + normalize only. +2. **Model (`modules.py`)** + * **ConvNeXtTiny1C** : ConvNeXt-like blocks (DWConv-7×7 → LN (channels-last) → MLP(×4) → GELU → MLP → LayerScale → DropPath + residual). + * Stages with depths `[3,3,9,3]` and dims `[96,192,384,768]`. + * Head: GAP → LayerNorm → Dropout → Linear(1). + * Trained with **BCE-with-logits** for binary classification. +3. **Training (`train.py`)** + * **MixUp** in-batch augmentation. + * **AdamW** (decoupled weight decay), **warmup + cosine** LR schedule. + * **AMP** (mixed precision) on CUDA. + * **Early stopping** on *subject-level* accuracy. + * Saves best checkpoint to `runs/best_model.pt` and training curves (`loss_curve.png`, `acc_curve.png`, `subject_acc_curve.png`). +4. **Inference / Report Artifacts (`predict.py`)** + * Loads the test split and the trained checkpoint. + * Computes **slice-level** and **subject-level** metrics (Acc, AUC, Sensitivity, Specificity). + * Saves: confusion matrix, ROC curve, subject example grid, and a metrics JSON. + +--- + +## 3. Pre-Processing & Splits + +**Pre-processing:** + +* All images are **grayscale** and resized to `IMAGE_SIZE=224`. +* **Normalization** : mean=0.5, std=0.5 (single channel). +* **Train-time augmentations** (dataset-level): random resized crop, horizontal flip, small affine transforms, perspective jitter, brightness/contrast jitter, Gaussian blur, and RandomErasing. 
These are standard modern augmentations for CNNs to improve generalization on limited medical imaging datasets.
+* **Train-time MixUp** (model-level): linear combination of two samples and labels within a batch (α=0.2), which smooths decision boundaries and reduces overfitting.
+
+**Split justification:**
+
+* The **provided folder structure** separates `train/` and `test/`. We strictly train on `train/` and treat `test/` as **held-out evaluation** (used for validation during development and for final reporting).
+* Subject leakage is mitigated by dataset naming (the subject ID is parsed from the filename prefix) and by the optional `LIMIT_SLICES_PER_SUBJECT` cap, which prevents any single subject from dominating a batch with many slices. At evaluation, we compute metrics at **slice level** and **subject level** (by averaging slice probabilities per subject), with the **subject-level** metric used for model selection.
+
+---
+
+## 4. Results
+
+Training was performed on a local GPU with the default config. The best checkpoint (selected by subject-level accuracy) achieved approximately:
+
+* **Slice-level accuracy**: **0.741**
+* **Subject-level accuracy (best)**: **0.802**
+
+Slice-level accuracy, loss, and subject-level accuracy curves are saved in `runs/` after running `train.py`.
+
+**Examples of these curves:**
+
+Subject-level accuracy:
+
+![1762166201886](image/README/1762166201886.png)
+
+Slice-level accuracy:
+![1762166347838](image/README/1762166347838.png)
+
+---
+
+## 5. Visualizations
+
+Predictions on the test split, made with the trained model, are visualized as:
+
+* **Confusion Matrix (Subject-Level)**
+* **ROC Curve (Slice-Level)**
+* **Subject Examples Grid** (representative slice per subject with subject-level p(AD))
+
+These figures are saved in `runs/` after running `predict.py`.
+
+**Examples of these visualizations:**
+
+Subject-level confusion matrix:
+![1762166385881](image/README/1762166385881.png)
+
+Slice-level ROC curve:
+![1762166411274](image/README/1762166411274.png)
+
+Examples of: Input | True label | Model prediction (with predicted p(AD)):
+![1762166491952](image/README/1762166491952.png)
+
+---
+
+## 6. Dependencies and Environment
+
+| Package      | Version      | Purpose                 |
+| ------------ | ------------ | ----------------------- |
+| Python       | 3.11.9       | Environment             |
+| PyTorch      | 2.5.1+cu124  | Deep learning (GPU)     |
+| Torchvision  | 0.20.1+cu124 | Model zoo / transforms  |
+| Scikit-learn | 1.4+         | Metrics / preprocessing |
+| Matplotlib   | 3.8+         | Plotting                |
+| Pillow       | 10.3+        | Image utilities         |
+| NumPy        | 1.26+        | Numerical computation   |
+
+These dependencies are listed in `requirements.txt`.
+
+---
+
+## 7. Reproducibility
+
+All results were obtained on Windows 10 + RTX 4070 Ti (CUDA 12.4 build).
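+
+For reference, `train.py` calls a small `set_seed` helper (with `SEED=42` in the default config) before building the datasets and model. A minimal sketch of what it does:
+
+```python
+import random
+
+import numpy as np
+import torch
+
+
+def set_seed(seed: int = 42):
+    # Seed Python, NumPy and PyTorch (CPU and all CUDA devices).
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+```
+
+Even with fixed seeds, cuDNN kernel selection and dataloader worker ordering can introduce small run-to-run differences, so the reported numbers are approximate.
+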
+To replicate:
+
+```bash
+python -m venv .venv && .venv\Scripts\activate
+pip install -r requirements.txt
+python recognition\adni_convnext_47068591\train.py --root \home\groups\comp3710\ADNI\AD_NC
+python recognition\adni_convnext_47068591\predict.py --root \home\groups\comp3710\ADNI\AD_NC --ckpt runs/best_model.pt
+```
+
+The following config values can also be changed in `train.py`:
+
+* `ROOT`, `EPOCHS`, `BATCH`, `LR`, `WEIGHT_DECAY`
+* `IMAGE_SIZE`, `LIMIT_SLICES_PER_SUBJECT`
+* `DROP_PATH_RATE`, `HEAD_DROP`, `WARMUP_EPOCHS`
+* `CLIP_NORM`, `MIXUP_ALPHA`
+* `EARLY_STOP_PATIENCE`
diff --git a/recognition/adni_convnext_47068591/dataset.py b/recognition/adni_convnext_47068591/dataset.py
new file mode 100644
index 000000000..f45c30e40
--- /dev/null
+++ b/recognition/adni_convnext_47068591/dataset.py
@@ -0,0 +1,128 @@
+"""
+dataset.py
+-----------
+Loads grayscale JPEG slices for AD vs NC classification (ADNI dataset).
+
+"""
+
+import os
+import glob
+from typing import List, Tuple, Optional
+from collections import defaultdict
+
+import torch
+from torch.utils.data import Dataset
+from PIL import Image
+import torchvision.transforms as T
+from torchvision.transforms import InterpolationMode as IM
+
+
+def _parse_subject_id(filename: str) -> str:
+    """
+    Extract subject ID from filenames like '123456_78.jpeg' → '123456'.
+    """
+    base = os.path.basename(filename)
+    stem, _ = os.path.splitext(base)
+    return stem.split("_")[0]
+
+
+class ADNIJPEGSlicesDataset(Dataset):
+    """
+    Dataset for grayscale JPEG MRI slices (AD vs NC).
+
+    Returns:
+        img: FloatTensor [1, H, W] (grayscale)
+        label: LongTensor 0 (NC) or 1 (AD)
+        subject_id: str
+    """
+
+    def __init__(
+        self,
+        root: str,                      # e.g. "path/to/ADNI/AD_NC"
+        split: str,                     # "train" or "test"
+        image_size: int = 224,
+        augment: bool = True,
+        limit_slices_per_subject: Optional[int] = None,
+    ):
+        super().__init__()
+        assert split in ("train", "test"), "split must be 'train' or 'test'"
+        self.root = root
+        self.split = split
+        self.image_size = image_size
+        self.limit_slices_per_subject = limit_slices_per_subject
+
+        # Map class names to labels
+        self.class_to_label = {"AD": 1, "NC": 0}
+
+        split_dir = os.path.join(root, split)
+        self.samples: List[Tuple[str, int, str]] = []  # (path, label, subject_id)
+
+        # Collect all JPEGs
+        for cls in ("AD", "NC"):
+            cls_dir = os.path.join(split_dir, cls)
+            paths = sorted(glob.glob(os.path.join(cls_dir, "*.jpeg"))) + \
+                    sorted(glob.glob(os.path.join(cls_dir, "*.jpg")))
+
+            label = self.class_to_label[cls]
+            by_subject = defaultdict(list)
+            for p in paths:
+                sid = _parse_subject_id(p)
+                by_subject[sid].append(p)
+
+            for sid, plist in by_subject.items():
+                if self.limit_slices_per_subject and len(plist) > self.limit_slices_per_subject:
+                    plist = plist[: self.limit_slices_per_subject]
+                for p in plist:
+                    self.samples.append((p, label, sid))
+
+        # data augmentations / transforms
+        if split == "train" and augment:
+            self.tf = T.Compose([
+                # --- geometric (PIL space) ---
+                T.RandomResizedCrop(
+                    image_size,
+                    scale=(0.80, 1.00),
+                    ratio=(0.90, 1.10),
+                    interpolation=IM.BICUBIC
+                ),
+                T.RandomHorizontalFlip(p=0.5),
+                T.RandomApply([
+                    T.RandomAffine(
+                        degrees=8,
+                        translate=(0.05, 0.05),
+                        scale=(0.95, 1.05),
+                        shear=(-5, 5),
+                        interpolation=IM.BILINEAR
+                    )
+                ], p=0.7),
+                T.RandomPerspective(distortion_scale=0.20, p=0.3),
+                T.ColorJitter(brightness=0.18, contrast=0.18),
+
+                # --- tensor space ---
+                T.ToTensor(),                       # → [1, H, W]
+                T.Normalize(mean=[0.5], std=[0.5]),
+                T.RandomApply([
+                    T.GaussianBlur(kernel_size=3,
sigma=(0.1, 1.2)) + ], p=0.3), + T.RandomErasing( + p=0.25, + scale=(0.01, 0.05), + ratio=(0.4, 2.5), + value='random' + ), + ]) + else: + self.tf = T.Compose([ + T.Resize((image_size, image_size), interpolation=IM.BICUBIC), + T.ToTensor(), + T.Normalize(mean=[0.5], std=[0.5]), + ]) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + path, label, sid = self.samples[idx] + img = Image.open(path).convert("L") # grayscale + img = self.tf(img) # tensor [1,H,W] + return img, torch.tensor(label, dtype=torch.long), sid diff --git a/recognition/adni_convnext_47068591/image/README/1762166201886.png b/recognition/adni_convnext_47068591/image/README/1762166201886.png new file mode 100644 index 000000000..1b02ca469 Binary files /dev/null and b/recognition/adni_convnext_47068591/image/README/1762166201886.png differ diff --git a/recognition/adni_convnext_47068591/image/README/1762166347838.png b/recognition/adni_convnext_47068591/image/README/1762166347838.png new file mode 100644 index 000000000..b4bb73e53 Binary files /dev/null and b/recognition/adni_convnext_47068591/image/README/1762166347838.png differ diff --git a/recognition/adni_convnext_47068591/image/README/1762166385881.png b/recognition/adni_convnext_47068591/image/README/1762166385881.png new file mode 100644 index 000000000..bcb90e3b6 Binary files /dev/null and b/recognition/adni_convnext_47068591/image/README/1762166385881.png differ diff --git a/recognition/adni_convnext_47068591/image/README/1762166411274.png b/recognition/adni_convnext_47068591/image/README/1762166411274.png new file mode 100644 index 000000000..750c187af Binary files /dev/null and b/recognition/adni_convnext_47068591/image/README/1762166411274.png differ diff --git a/recognition/adni_convnext_47068591/image/README/1762166491952.png b/recognition/adni_convnext_47068591/image/README/1762166491952.png new file mode 100644 index 000000000..120721dcf Binary files /dev/null and b/recognition/adni_convnext_47068591/image/README/1762166491952.png differ diff --git a/recognition/adni_convnext_47068591/modules.py b/recognition/adni_convnext_47068591/modules.py new file mode 100644 index 000000000..2c991fb13 --- /dev/null +++ b/recognition/adni_convnext_47068591/modules.py @@ -0,0 +1,237 @@ +""" +modules.py +----------- +Self-built ConvNeXt-like (Tiny) for 1-channel input and binary output. + +Key pieces: +- ConvNeXtBlock: DW-7x7 -> LN (channels-last) -> 1x1 MLP (expand 4x) -> GELU -> 1x1 (project) -> layer scale -> DropPath -> residual +- Downsample: LN (channels-last) + Conv2d (stride=2) +- ConvNeXtTiny1C: stem (4x4/4) + stages with depths [3,3,9,3], dims [96,192,384,768] +- Head: GAP -> LN -> Dropout -> Linear(1) => single logit + +Use with BCE-with-logits. +""" + +from typing import Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# ----------------------------- Utilities ------------------------------------- # + +class DropPath(nn.Module): + """Stochastic depth (per-sample) — as in ConvNeXt/DeiT. Set drop_prob=0 to disable.""" + def __init__(self, drop_prob: float = 0.0): + super().__init__() + self.drop_prob = float(drop_prob) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1.0 - self.drop_prob + # Work with (N, ...) shaped input. Broadcast along all but batch. 
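+        # Dividing by keep_prob keeps the expected activation magnitude unchanged, so no rescaling is needed at eval time.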
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob) + return x / keep_prob * mask + + +class LayerNorm2d(nn.Module): + """ + Convenience wrapper: apply LayerNorm expecting channels-last. + We permute (N,C,H,W) -> (N,H,W,C), LN over C, then back. + """ + def __init__(self, num_channels: int, eps: float = 1e-6): + super().__init__() + self.ln = nn.LayerNorm(num_channels, eps=eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 3, 1) # NCHW -> NHWC + x = self.ln(x) + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + return x + + +# --------------------------- ConvNeXt Building Blocks ------------------------- # + +class ConvNeXtBlock(nn.Module): + """ + One ConvNeXt block: + DWConv7x7 -> LN (channels-last) -> Linear(4x) -> GELU -> Linear(1x) -> LayerScale(gamma) -> DropPath -> Residual + """ + def __init__( + self, + dim: int, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-6 + ): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise 7x7 + self.norm = nn.LayerNorm(dim, eps=1e-6) # LN expects channels-last + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise (expand) + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) # pointwise (project) + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim)) if layer_scale_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + + x = self.dwconv(x) # (N,C,H,W) + x = x.permute(0, 2, 3, 1) # -> (N,H,W,C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # -> (N,C,H,W) + + x = shortcut + self.drop_path(x) + return x + + +class Downsample(nn.Module): + """ + ConvNeXt downsample layer: + LN (channels-last) -> Conv2d stride=2 + """ + def __init__(self, in_ch: int, out_ch: int): + super().__init__() + self.norm = LayerNorm2d(in_ch, eps=1e-6) + self.reduction = nn.Conv2d(in_ch, out_ch, kernel_size=2, stride=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = self.reduction(x) + return x + + +# ------------------------------ Full Model ----------------------------------- # + +class ConvNeXtTiny1C(nn.Module): + """ + ConvNeXt-Tiny-like network for 1-channel input. 
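+    Expects (N, 1, H, W) grayscale input (224x224 in this project) and returns an (N,) tensor of raw logits.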
+ - Stem: Conv(1->96, k=4,s=4) + LN + - Stages: depths [3,3,9,3], dims [96,192,384,768] + - Head: GAP -> LN -> Dropout -> Linear(1) (single logit for BCE-with-logits) + """ + def __init__( + self, + in_ch: int = 1, + num_classes: int = 1, # 1 => single logit for binary + depths: Tuple[int, int, int, int] = (3, 3, 9, 3), + dims: Tuple[int, int, int, int] = (96, 192, 384, 768), + drop_path_rate: float = 0.0, + head_drop: float = 0.1, + layer_scale_init_value: float = 1e-6 + ): + super().__init__() + + # Stem: patch embedding (4x4, stride 4) then LN + self.stem_conv = nn.Conv2d(in_ch, dims[0], kernel_size=4, stride=4) + self.stem_ln = LayerNorm2d(dims[0], eps=1e-6) + + # stochastic depth decay rule across all blocks + total_blocks = sum(depths) + dp_rates = torch.linspace(0, drop_path_rate, steps=total_blocks).tolist() + dp_iter = 0 + + # Stage 0 (no downsample before first stage) + stage0 = [] + for _ in range(depths[0]): + stage0.append(ConvNeXtBlock(dims[0], drop_path=dp_rates[dp_iter], layer_scale_init_value=layer_scale_init_value)) + dp_iter += 1 + self.stage0 = nn.Sequential(*stage0) + + # Stage 1 + self.down1 = Downsample(dims[0], dims[1]) + stage1 = [] + for _ in range(depths[1]): + stage1.append(ConvNeXtBlock(dims[1], drop_path=dp_rates[dp_iter], layer_scale_init_value=layer_scale_init_value)) + dp_iter += 1 + self.stage1 = nn.Sequential(*stage1) + + # Stage 2 + self.down2 = Downsample(dims[1], dims[2]) + stage2 = [] + for _ in range(depths[2]): + stage2.append(ConvNeXtBlock(dims[2], drop_path=dp_rates[dp_iter], layer_scale_init_value=layer_scale_init_value)) + dp_iter += 1 + self.stage2 = nn.Sequential(*stage2) + + # Stage 3 + self.down3 = Downsample(dims[2], dims[3]) + stage3 = [] + for _ in range(depths[3]): + stage3.append(ConvNeXtBlock(dims[3], drop_path=dp_rates[dp_iter], layer_scale_init_value=layer_scale_init_value)) + dp_iter += 1 + self.stage3 = nn.Sequential(*stage3) + + # Head + self.head_ln = nn.LayerNorm(dims[3], eps=1e-6) # applied after GAP, so vector LN + self.dropout = nn.Dropout(head_drop) + self.fc = nn.Linear(dims[3], num_classes) + + # Weight init (Kaiming for convs, xavier for linears) + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.Linear,)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Stem + x = self.stem_conv(x) # (N, C0, H/4, W/4) + x = self.stem_ln(x) # channels-last LN wrapper + + # Stages + x = self.stage0(x) + x = self.down1(x); x = self.stage1(x) + x = self.down2(x); x = self.stage2(x) + x = self.down3(x); x = self.stage3(x) + + # Global average pooling + x = F.adaptive_avg_pool2d(x, output_size=1).flatten(1) # (N, C3) + x = self.head_ln(x) # vector LN (no permute needed) + x = self.dropout(x) + x = self.fc(x) # (N, 1) for binary + return x.squeeze(1) # (N,) single logit + + +# ---------------------------- Loss & Metrics --------------------------------- # + +def bce_with_logits_loss( + logits: torch.Tensor, + targets: torch.Tensor, + pos_weight: Optional[float] = None +) -> torch.Tensor: + """ + Stable BCE-with-logits for binary classification (targets in {0,1}). + Set pos_weight>1.0 if AD is the minority class. 
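+    Example (hypothetical imbalance): bce_with_logits_loss(logits, targets, pos_weight=1.5) up-weights the AD class.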
+ """ + targets = targets.float() + if pos_weight is not None: + pw = torch.tensor([pos_weight], device=logits.device, dtype=logits.dtype) + return F.binary_cross_entropy_with_logits(logits, targets, pos_weight=pw) + return F.binary_cross_entropy_with_logits(logits, targets) + + +@torch.no_grad() +def binary_metrics(logits: torch.Tensor, targets: torch.Tensor) -> Tuple[float, float]: + """ + Returns: + acc: mean accuracy + mpt: mean probability for the positive class (on positive samples), NaN if none. + """ + probs = torch.sigmoid(logits) + preds = (probs >= 0.5).long() + acc = (preds == targets).float().mean().item() + mpt = probs[targets == 1].mean().item() if (targets == 1).any() else float('nan') + return acc, mpt diff --git a/recognition/adni_convnext_47068591/predict.py b/recognition/adni_convnext_47068591/predict.py new file mode 100644 index 000000000..10580bb14 --- /dev/null +++ b/recognition/adni_convnext_47068591/predict.py @@ -0,0 +1,364 @@ +""" +predict.py — demo/report artifacts from the saved checkpoint + +What this script does (no training here): + 1) Loads the TEST split via ADNIJPEGSlicesDataset. + 2) Builds the ConvNeXtTiny1C model and loads weights from a checkpoint. + 3) Computes slice-level and subject-level metrics (acc, AUC, sensitivity, specificity). + 4) Saves a confusion matrix image, ROC curve image, an example grid image, and a JSON with metrics. + +Usage: + python predict.py --root --ckpt runs/best_model.pt + example:python recognition/adni_convnext_47068591/predict.py --root /home/groups/comp3710/ADNI/AD_NC --ckpt runs/best_model.pt + +Outputs (in --out, default "runs/"): + demo_confusion.png, demo_roc.png, demo_examples.png, demo_metrics.json +""" + +import argparse +from pathlib import Path +from collections import defaultdict + +import numpy as np +import torch +from torch.utils.data import DataLoader +import matplotlib.pyplot as plt +from sklearn.metrics import roc_auc_score, confusion_matrix, roc_curve + +from dataset import ADNIJPEGSlicesDataset +from modules import ConvNeXtTiny1C + + +def compute_slice_metrics(model, loader, device): + """ + Evaluate model performance at the SLICE level. 
+ + Returns a dict with: + - acc: accuracy over slices + - auc: ROC AUC over slices (probabilities vs labels) + - sens/spec: sensitivity/specificity from a 0.5 threshold + - y_true, y_prob, y_pred: arrays for downstream plots (ROC/confusion) + """ + model.eval() + y_true, y_prob = [], [] + + # Disable grad; iterate over all test batches + with torch.inference_mode(): + for imgs, labels, _ in loader: + imgs = imgs.to(device) + logits = model(imgs) # raw logits (shape [N]) + probs = torch.sigmoid(logits).cpu().numpy() # convert to p(AD) + y_prob.append(probs) + y_true.append(labels.numpy()) + + # Concatenate across all batches + y_prob = np.concatenate(y_prob).astype(np.float64) + y_true = np.concatenate(y_true).astype(np.int64) + + # Threshold at 0.5 for hard predictions + y_pred = (y_prob >= 0.5).astype(np.int64) + + # Slice accuracy + acc = (y_pred == y_true).mean().item() + + # ROC AUC can fail if there is only one class present → guard with try/except + try: + auc = roc_auc_score(y_true, y_prob) + except ValueError: + auc = float('nan') + + # Confusion matrix and derived metrics at the 0.5 threshold + tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel() + sens = tp / max(tp + fn, 1) # recall for AD + spec = tn / max(tn + fp, 1) # recall for NC + + return dict(acc=acc, auc=auc, sens=sens, spec=spec, + y_true=y_true, y_prob=y_prob, y_pred=y_pred) + + +def compute_subject_metrics(model, loader, device): + """ + Evaluate model performance at the SUBJECT level. + + Approach: + - Accumulate slice probabilities per subject_id in a bucket. + - Average probabilities per subject (mean p(AD) across its slices). + - Threshold averaged probability at 0.5 to get a subject prediction. + """ + model.eval() + bucket = defaultdict(list) # sid -> list of (p, label) + + with torch.inference_mode(): + for imgs, labels, sids in loader: + imgs = imgs.to(device) + logits = model(imgs) + probs = torch.sigmoid(logits).cpu().numpy() + # Group each slice's prob with its subject id + for p, lab, sid in zip(probs, labels.numpy(), sids): + bucket[sid].append((float(p), int(lab))) + + # Aggregate per subject (mean probability); labels are consistent for a subject + y_true, y_prob = [], [] + for sid, items in bucket.items(): + mean_prob = sum(p for p, _ in items) / len(items) + y_prob.append(mean_prob) + y_true.append(items[0][1]) + + # Convert to arrays for metrics + y_true = np.array(y_true, dtype=np.int64) + y_prob = np.array(y_prob, dtype=np.float64) + + # Subject-level hard predictions at 0.5 + y_pred = (y_prob >= 0.5).astype(np.int64) + + # Subject accuracy + acc = (y_pred == y_true).mean().item() + + # Subject ROC AUC (may be NaN if only one class present) + try: + auc = roc_auc_score(y_true, y_prob) + except ValueError: + auc = float('nan') + + # Subject confusion matrix-derived metrics (using the 0.5 threshold) + tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel() + sens = tp / max(tp + fn, 1) + spec = tn / max(tn + fp, 1) + + return dict(acc=acc, auc=auc, sens=sens, spec=spec, + y_true=y_true, y_prob=y_prob, y_pred=y_pred) + + +def plot_subject_confusion(cm, outpath): + """ + Plot a labeled 2x2 confusion matrix for SUBJECT-LEVEL classification (NC vs AD). 
+ + Args: + cm (np.ndarray): 2x2 array in the form [[TN, FP], [FN, TP]] + outpath (Path or str): path to save the figure (PNG) + + Conventions: + - Rows are TRUE subject diagnoses (NC row, AD row) + - Columns are PREDICTED subject diagnoses (NC col, AD col) + """ + tn, fp, fn, tp = cm.ravel() + matrix = np.array([[tn, fp], [fn, tp]]) + + fig, ax = plt.subplots(figsize=(4, 4)) + im = ax.imshow(matrix, cmap="Blues") + + # Tick labels + ax.set_xticks([0, 1]) + ax.set_yticks([0, 1]) + ax.set_xticklabels(['Predicted NC', 'Predicted AD']) + ax.set_yticklabels(['True NC', 'True AD']) + + # Axis titles (explicitly labeled as subject-level) + ax.set_xlabel("Predicted Diagnosis (Subject-Level)", fontsize=11) + ax.set_ylabel("True Diagnosis (Subject-Level)", fontsize=11) + + # Value annotations + for (i, j), v in np.ndenumerate(matrix): + ax.text(j, i, f"{v}", ha="center", va="center", fontsize=13, color="black") + + # Add gridlines between cells + ax.set_xticks(np.arange(-0.5, 2, 1), minor=True) + ax.set_yticks(np.arange(-0.5, 2, 1), minor=True) + ax.grid(which="minor", color="gray", linestyle="-", linewidth=0.5) + ax.tick_params(which="minor", bottom=False, left=False) + + # Updated title + ax.set_title("Confusion Matrix (Subject-Level)", fontsize=14, pad=10) + + plt.tight_layout() + plt.savefig(outpath, dpi=150) + plt.close(fig) + + + +def plot_roc(y_true, y_prob, outpath, title="ROC"): + """ + Plot a standard ROC curve with diagonal (chance) line and AUC in legend. + + Args: + y_true (array-like): Ground-truth labels (0/1) + y_prob (array-like): Predicted probabilities for the positive class (AD) + outpath (Path/str): Where to save the PNG + title (str): Figure title + """ + # ROC points: varying threshold from 1 → 0 + fpr, tpr, _ = roc_curve(y_true, y_prob) + + # AUC can be undefined when only one class is present + auc = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else float('nan') + + plt.figure() + plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}") + plt.plot([0, 1], [0, 1], linestyle="--") # chance line + plt.xlabel("FPR") + plt.ylabel("TPR") + plt.legend() + plt.title(title) + plt.tight_layout() + plt.savefig(outpath, dpi=150) + plt.close() + + +def make_examples_grid(ds, model, device, outpath, n_per_class=6): + """ + Create a grid of TEST subjects with their subject-level predicted probability p(AD). + + Each subject's probability is computed as the mean of all slice probabilities. + The grid shows one representative slice (middle slice) for each subject, with a + banner that explicitly labels the following: + - Subject ID + - True Diagnosis (NC/AD) + - Model's predicted probability for AD (subject-level mean) + + Args: + ds: ADNIJPEGSlicesDataset (split="test"). + model: Trained model (in eval mode). + device: "cuda" or "cpu". + outpath: Path to save the resulting image grid. + n_per_class: Max subjects per class (NC/AD) to show. 
+ """ + import numpy as np + from collections import defaultdict + from PIL import Image, ImageOps, ImageDraw, ImageFont + + model.eval() + + subj_to_indices = defaultdict(list) + subj_to_label = {} + for idx, (_path, y, sid) in enumerate(ds.samples): + subj_to_indices[sid].append(idx) + subj_to_label.setdefault(sid, y) + + subj_probs = {} + with torch.inference_mode(): + for sid, idx_list in subj_to_indices.items(): + slice_probs = [] + for idx in idx_list: + path, _y, _sid = ds.samples[idx] + pil = Image.open(path).convert("L") + x = ds.tf(pil).unsqueeze(0).to(device) + p = torch.sigmoid(model(x)).item() + slice_probs.append(p) + subj_probs[sid] = float(np.mean(slice_probs)) if slice_probs else float("nan") + + ad_sids = [sid for sid, y in subj_to_label.items() if y == 1][:n_per_class] + nc_sids = [sid for sid, y in subj_to_label.items() if y == 0][:n_per_class] + chosen_sids = ad_sids + nc_sids + if len(chosen_sids) == 0: + return + + tiles = [] + with torch.inference_mode(): + for sid in chosen_sids: + idx_list = subj_to_indices[sid] + label = subj_to_label[sid] + p_subj = subj_probs.get(sid, float("nan")) + + idx_list_sorted = sorted(idx_list) + rep_idx = idx_list_sorted[len(idx_list_sorted) // 2] + path, _y, _sid = ds.samples[rep_idx] + + pil = Image.open(path).convert("L") + disp = ImageOps.equalize(pil.resize((224, 224))) + draw = ImageDraw.Draw(disp) + + # Use two-line banner for clarity and space efficiency + text_line1 = f"Subject ID: {sid}" + text_line2 = f"True: {'AD' if label==1 else 'NC'} | Pred p(AD): {p_subj:.2f}" + + # Draw a taller banner rectangle + banner_height = 40 + draw.rectangle([0, 0, 223, banner_height], fill=0) + + # Use a slightly smaller font if available + try: + font = ImageFont.truetype("arial.ttf", 12) + except: + font = None # fallback to default + + # Write each line separately + draw.text((4, 4), text_line1, fill=255, font=font) + draw.text((4, 20), text_line2, fill=255, font=font) + + tiles.append(disp.convert("RGB")) + + cols = 6 + rows = int(np.ceil(len(tiles) / cols)) + w, h = tiles[0].size + grid = Image.new("RGB", (cols * w, rows * h), color=(255, 255, 255)) + for idx, im in enumerate(tiles): + r, c = divmod(idx, cols) + grid.paste(im, (c * w, r * h)) + grid.save(outpath) + + + + +def main(): + """ + Entrypoint: + - Parse CLI args + - Construct test dataset/loader + - Load checkpointed model + - Compute and save metrics/plots/examples + """ + ap = argparse.ArgumentParser() + ap.add_argument("--root", type=str, required=True) # ADNI/AD_NC root + ap.add_argument("--ckpt", type=str, default="runs/best_model.pt") # checkpoint path + ap.add_argument("--batch", type=int, default=32) # test-time batch size + ap.add_argument("--image_size", type=int, default=224) # must match training/eval tf + ap.add_argument("--workers", type=int, default=4) # dataloader workers + ap.add_argument("--out", type=str, default="runs") # output directory + args = ap.parse_args() + + # Device + output directory + device = "cuda" if torch.cuda.is_available() else "cpu" + outdir = Path(args.out); outdir.mkdir(parents=True, exist_ok=True) + + # Dataset/loader (TEST split only — no training here) + test_ds = ADNIJPEGSlicesDataset(root=args.root, split="test", + image_size=args.image_size, augment=False) + test_loader = DataLoader(test_ds, batch_size=args.batch, shuffle=False, + num_workers=args.workers, pin_memory=True) + + # Model (same architecture as training) + checkpoint load + model = ConvNeXtTiny1C(in_ch=1, num_classes=1, drop_path_rate=0.15, head_drop=0.25).to(device) + 
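+    # The checkpoint is a plain state_dict (see save_checkpoint in train.py); on recent PyTorch versions,
+    # torch.load(..., weights_only=True) also works here and avoids the pickle-security warning.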
state = torch.load(args.ckpt, map_location=device)
+    model.load_state_dict(state)
+    model.eval()
+
+    # Compute metrics (slice- and subject-level)
+    slice_res = compute_slice_metrics(model, test_loader, device)
+    subj_res = compute_subject_metrics(model, test_loader, device)
+
+    # Save a compact JSON with the numeric metrics (arrays omitted)
+    metrics = {
+        "slice": {k: float(v) if not isinstance(v, (list, np.ndarray)) else None
+                  for k, v in slice_res.items() if k not in ("y_true", "y_prob", "y_pred")},
+        "subject": {k: float(v) if not isinstance(v, (list, np.ndarray)) else None
+                    for k, v in subj_res.items() if k not in ("y_true", "y_prob", "y_pred")}
+    }
+    with open(outdir / "demo_metrics.json", "w") as f:
+        import json
+        json.dump(metrics, f, indent=2)
+    print("Demo metrics:", metrics)
+
+    # Plots (subject-level confusion matrix + slice-level ROC, matching the README)
+    cm = confusion_matrix(subj_res["y_true"], subj_res["y_pred"], labels=[0, 1])
+    plot_subject_confusion(cm, outdir / "demo_confusion.png")
+    plot_roc(slice_res["y_true"], slice_res["y_prob"], outdir / "demo_roc.png", title="ROC (slice-level)")
+
+    # Qualitative example grid
+    make_examples_grid(test_ds, model, device, outdir / "demo_examples.png", n_per_class=6)
+
+    print("Demo complete.")
+    print(f"Artifacts saved to: {outdir.resolve()}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/recognition/adni_convnext_47068591/train.py b/recognition/adni_convnext_47068591/train.py
new file mode 100644
index 000000000..606d6357d
--- /dev/null
+++ b/recognition/adni_convnext_47068591/train.py
@@ -0,0 +1,375 @@
+"""
+train.py
+---------
+Main training script for Alzheimer's MRI slice classification (AD vs NC)
+using a ConvNeXt-Tiny backbone adapted for single-channel (grayscale) input.
+
+Key Features:
+- Loads grayscale JPEG slices via ADNIJPEGSlicesDataset.
+- Implements a ConvNeXt-Tiny-like CNN (ConvNeXtTiny1C) for binary classification.
+- Uses MixUp augmentation, gradient clipping, and AMP (mixed precision).
+- Optimized with AdamW, warmup + cosine learning rate scheduling.
+- Tracks both slice-level and subject-level accuracy during training.
+- Applies early stopping based on subject-level accuracy to prevent overfitting.
+- Saves training curves (loss, accuracy) and the best-performing checkpoint.
+ +Usage: + python train.py + (optional) Override root with: --root + +Outputs: + runs/ + ├── best_model.pt + ├── loss_curve.png + ├── acc_curve.png + ├── subject_acc_curve.png + └── history.json +""" +import json +import time +from pathlib import Path +from collections import defaultdict +from contextlib import nullcontext + +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.optim import AdamW +from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR +import matplotlib.pyplot as plt +import numpy as np + +from dataset import ADNIJPEGSlicesDataset +from modules import ConvNeXtTiny1C, bce_with_logits_loss, binary_metrics + +# ====================== USER CONFIG ============================= # +CONFIG = dict( + ROOT=r"C:\Users\harri\UQ\COMP3710\COMP3710_A3\PatternAnalysis-2025-47068591\data\ADNI\AD_NC", + EPOCHS=60, + BATCH=16, + LR=2e-4, + WEIGHT_DECAY=2e-3, + OUT="runs", + WORKERS=4, + IMAGE_SIZE=224, + LIMIT_SLICES_PER_SUBJECT=12, + SUBJECT_EVAL=True, + SEED=42, + + # Regularization + DROP_PATH_RATE=0.15, + HEAD_DROP=0.25, + WARMUP_EPOCHS=5, + ETA_MIN=1e-5, + + # Stability & Generalization + CLIP_NORM=1.0, # gradient clipping + MIXUP_ALPHA=0.2, # MixUp + + # Early stopping + EARLY_STOP_PATIENCE=8, +) +# ============================================================================ # + + +def set_seed(seed: int = 42): + """ + Fix random seeds for Python, NumPy, and PyTorch (CPU/CUDA). + """ + import random, numpy as np + random.seed(seed); np.random.seed(seed); torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def save_checkpoint(model: nn.Module, path: Path): + """ + Save only the model state_dict at 'path'. Creates parent dirs if needed. + """ + path.parent.mkdir(parents=True, exist_ok=True) + torch.save(model.state_dict(), path) + print(f"Saved checkpoint: {path}") + + +def plot_history(history: dict, outdir: Path): + """ + Write simple PNG plots for train/val loss and accuracy. + Also plots subject-level accuracy if present. + """ + outdir.mkdir(parents=True, exist_ok=True) + + plt.figure() + plt.plot(history["train_loss"], label="Train Loss") + plt.plot(history["val_loss"], label="Val Loss") + plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend() + plt.tight_layout(); plt.savefig(outdir / "loss_curve.png", dpi=150); plt.close() + + plt.figure() + plt.plot(history["train_acc"], label="Train Acc") + plt.plot(history["val_acc"], label="Val Acc") + plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.legend() + plt.tight_layout(); plt.savefig(outdir / "acc_curve.png", dpi=150); plt.close() + + if "val_subj_acc" in history and len(history["val_subj_acc"]) > 0: + plt.figure() + plt.plot(history["val_subj_acc"], label="Val Subject Acc") + plt.xlabel("Epoch"); plt.ylabel("Subject Acc"); plt.legend() + plt.tight_layout(); plt.savefig(outdir / "subject_acc_curve.png", dpi=150); plt.close() + + +# ---------- MixUp ---------- +def do_mixup(x, y, alpha=0.2): + """ + Standard MixUp: convex-combine inputs and labels within the batch. + Returns mixed images, mixed (soft) labels, and lambda. 
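+    Example (as used in train_one_epoch below): imgs, y_soft, _ = do_mixup(imgs, labels, alpha=0.2)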
+ """ + if alpha <= 0: + return x, y, 1.0 + lam = np.random.beta(alpha, alpha) + bs = x.size(0) + idx = torch.randperm(bs, device=x.device) + x_mix = lam * x + (1 - lam) * x[idx] + y = y.float() + y_mix = lam * y + (1 - lam) * y[idx] + return x_mix, y_mix, lam + + +# ---------- Training ---------- +def train_one_epoch(model, loader, optimizer, device, scaler=None, mixup_alpha=0.2, clip_norm=1.0): + """ + One full pass over the training set. + - Uses AMP if CUDA is available and scaler is provided. + - Applies MixUp to both inputs and labels. + - Clips gradients for stability. + - Tracks average loss and accuracy for reporting. + """ + model.train() + running_loss, running_acc, n_samples = 0.0, 0.0, 0 + autocast_ctx = torch.amp.autocast('cuda') if device == "cuda" else nullcontext() + + for imgs, labels, _sids in loader: + imgs, labels = imgs.to(device), labels.to(device) + optimizer.zero_grad(set_to_none=True) + + # MixUp augmentation (soft labels) + imgs, y_soft, _ = do_mixup(imgs, labels, alpha=mixup_alpha) + + if scaler is not None: + with autocast_ctx: + logits = model(imgs) + loss = F.binary_cross_entropy_with_logits(logits.view(-1), y_soft.view(-1)) + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + if clip_norm: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm) + scaler.step(optimizer) + scaler.update() + else: + with autocast_ctx: + logits = model(imgs) + loss = F.binary_cross_entropy_with_logits(logits.view(-1), y_soft.view(-1)) + loss.backward() + if clip_norm: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm) + optimizer.step() + + # Report hard accuracy against original hard labels (not soft) + acc, _ = binary_metrics(logits.detach(), labels) + bs = imgs.size(0) + running_loss += loss.item() * bs + running_acc += acc * bs + n_samples += bs + + return running_loss / n_samples, running_acc / n_samples + + +# ---------- Evaluation ---------- +@torch.inference_mode() +def evaluate_slice_level(model, loader, device): + """ + Slice-level evaluation on the provided loader. + Returns average loss and accuracy across slices. + """ + model.eval() + total_loss, total_acc, n_samples = 0.0, 0.0, 0 + autocast_ctx = torch.amp.autocast('cuda') if device == "cuda" else nullcontext() + + for imgs, labels, _sids in loader: + imgs, labels = imgs.to(device), labels.to(device) + with autocast_ctx: + logits = model(imgs) + loss = bce_with_logits_loss(logits, labels) + acc, _ = binary_metrics(logits, labels) + + bs = imgs.size(0) + total_loss += loss.item() * bs + total_acc += acc * bs + n_samples += bs + + return total_loss / n_samples, total_acc / n_samples + + +@torch.inference_mode() +def evaluate_subject_level(model, loader, device): + """ + Subject-level evaluation: + - Aggregate all slice logits per subject (mean logit). + - Return subject-level accuracy. 
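+    Note: this helper averages raw logits and thresholds at 0, whereas predict.py averages sigmoid
+    probabilities and thresholds at 0.5; since sigmoid is nonlinear, the two aggregations can disagree slightly.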
+ """ + model.eval() + bucket = defaultdict(list) + autocast_ctx = torch.amp.autocast('cuda') if device == "cuda" else nullcontext() + + for imgs, labels, sids in loader: + imgs, labels = imgs.to(device), labels.to(device) + with autocast_ctx: + logits = model(imgs) + + for lg, lab, sid in zip(logits.detach().cpu(), labels.cpu(), sids): + bucket[sid].append((float(lg), int(lab))) + + correct, total = 0, 0 + for sid, entries in bucket.items(): + mean_logit = sum(lg for lg, _ in entries) / len(entries) + pred = 1 if mean_logit >= 0.0 else 0 + true = entries[0][1] + correct += int(pred == true) + total += 1 + + return correct / max(total, 1) + + +# ---------- Main ---------- +def main(): + """ + Orchestrates: + - seeding, device setup + - dataset/dataloader creation (train uses 'train/', val uses 'test/') + - model/optimizer/scheduler/scaler setup + - training loop with subject-level early stopping + - history plots + JSON dump + """ + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--root", type=str, default=CONFIG["ROOT"]) + args, _ = parser.parse_known_args() + + set_seed(CONFIG["SEED"]) + device = "cuda" if torch.cuda.is_available() else "cpu" + + outdir = Path(CONFIG["OUT"]); outdir.mkdir(parents=True, exist_ok=True) + print(f"Device: {device}") + print(f"Root: {args.root}") + + # --- Datasets & loaders --- + # Train: strong augmentation; + # Val(Test): deterministic resize/normalize. + train_ds = ADNIJPEGSlicesDataset( + root=args.root, split="train", + image_size=CONFIG["IMAGE_SIZE"], augment=True, + limit_slices_per_subject=CONFIG["LIMIT_SLICES_PER_SUBJECT"] + ) + val_ds = ADNIJPEGSlicesDataset( + root=args.root, split="test", + image_size=CONFIG["IMAGE_SIZE"], augment=False + ) + + train_loader = DataLoader(train_ds, batch_size=CONFIG["BATCH"], shuffle=True, + num_workers=CONFIG["WORKERS"], pin_memory=True) + val_loader = DataLoader(val_ds, batch_size=CONFIG["BATCH"], shuffle=False, + num_workers=CONFIG["WORKERS"], pin_memory=True) + + # --- Model --- + model = ConvNeXtTiny1C( + in_ch=1, num_classes=1, + drop_path_rate=CONFIG["DROP_PATH_RATE"], + head_drop=CONFIG["HEAD_DROP"] + ).to(device) + + # --- Optimizer (AdamW with decoupled weight decay) --- + decay, no_decay = [], [] + for n, p in model.named_parameters(): + if not p.requires_grad: continue + # Norms & biases go to no_decay + if p.ndim == 1 or n.endswith(".bias") or ("norm" in n.lower()): + no_decay.append(p) + else: + decay.append(p) + optimizer = AdamW( + [{"params": decay, "weight_decay": CONFIG["WEIGHT_DECAY"]}, + {"params": no_decay, "weight_decay": 0.0}], + lr=CONFIG["LR"], betas=(0.9, 0.999) + ) + + # --- Scheduler: warmup (Linear) + cosine anneal --- + warmup_epochs = CONFIG["WARMUP_EPOCHS"] + main_epochs = CONFIG["EPOCHS"] - warmup_epochs + warmup = LinearLR(optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warmup_epochs) + cosine = CosineAnnealingLR(optimizer, T_max=main_epochs, eta_min=CONFIG["ETA_MIN"]) + scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]) + + # --- AMP scaler (CUDA) --- + scaler = torch.amp.GradScaler('cuda') if device == "cuda" else None + + # --- Training bookkeeping --- + history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": [], "val_subj_acc": []} + best_metric, best_path = -1.0, outdir / "best_model.pt" + + # --- Early stopping state --- + patience = int(CONFIG.get("EARLY_STOP_PATIENCE", 10)) + no_improve = 0 + + print("\n=== Training ===") + for epoch in range(1, CONFIG["EPOCHS"] + 1): + t0 = 
time.time() + + # ---- Train ---- + tr_loss, tr_acc = train_one_epoch( + model, train_loader, optimizer, device, + scaler=scaler, mixup_alpha=CONFIG["MIXUP_ALPHA"], + clip_norm=CONFIG["CLIP_NORM"] + ) + + # ---- Validate (slice + subject-level) ---- + val_loss, val_acc = evaluate_slice_level(model, val_loader, device) + subj_metric = evaluate_subject_level(model, val_loader, device) if CONFIG["SUBJECT_EVAL"] else val_acc + + # ---- Log / step LR ---- + history["train_loss"].append(tr_loss); history["val_loss"].append(val_loss) + history["train_acc"].append(tr_acc); history["val_acc"].append(val_acc) + history["val_subj_acc"].append(subj_metric) + scheduler.step() + + # ---- Progress line ---- + line = ( + f"Epoch {epoch:03d}/{CONFIG['EPOCHS']} | " + f"Train {tr_loss:.4f}/{tr_acc:.3f} | " + f"Val {val_loss:.4f}/{val_acc:.3f} | " + f"Subj {subj_metric:.3f} | " + f"LR={optimizer.param_groups[0]['lr']:.6g} | {time.time()-t0:.1f}s" + ) + print(line) + + # ---- Save best & Early stop ---- + if subj_metric > best_metric: + best_metric = subj_metric + save_checkpoint(model, best_path) + no_improve = 0 # reset patience on improvement + else: + no_improve += 1 + if no_improve >= patience: + print(f"Early stopping at epoch {epoch} (no improvement for {patience} epochs).") + break + + # --- Plots & history dump --- + plot_history(history, outdir) + with open(outdir / "history.json", "w") as f: + json.dump(history, f, indent=2) + + print(f"\nDone. Best subject acc: {best_metric:.3f}") + print(f"Best checkpoint: {best_path}") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..b0f33d74a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +--extra-index-url https://download.pytorch.org/whl/cu124 +torch==2.5.1+cu124 +torchvision==0.20.1+cu124 + +numpy>=1.26 +matplotlib>=3.8 +pillow>=10.3 +scikit-learn>=1.4 \ No newline at end of file