diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index a8899eb..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml new file mode 100644 index 0000000..1693d02 --- /dev/null +++ b/.github/workflows/build-and-test.yml @@ -0,0 +1,45 @@ +# +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# + +name: Build and Test + +on: + push: + branches: + - main + pull_request: + workflow_dispatch: + workflow_call: + +jobs: + pylint: + name: PyLint + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.12"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + - name: Install Infrastructure + run: | + pip install -r requirements.txt + pip install pylint + - name: Analysing the code with pylint + run: | + pylint $(git ls-files '*.py') + black_lint: + name: Black Code Formatter Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + - name: Install Black + run: pip install black[jupyter] + - name: Check code formatting with Black + run: black . --exclude '\.ipynb$' diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml new file mode 100644 index 0000000..c74fa09 --- /dev/null +++ b/.github/workflows/pull_request.yml @@ -0,0 +1,37 @@ +# +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# + +name: Pull Request + +on: + pull_request: + workflow_dispatch: + +jobs: + reuse_action: + name: REUSE Compliance Check + uses: DaneshjouLab/.github/.github/workflows/reuse.yml@main + markdown_link_check: + name: Markdown Link Check + uses: DaneshjouLab/.github/.github/workflows/markdown-link-check.yml@main + yamllint: + name: YAML Lint Check + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + + - name: Install yamllint + run: pip install yamllint + + - name: Run yamllint with custom config + run: yamllint -c .yamllint .github/workflows/*.yml diff --git a/.gitignore b/.gitignore index 9c9974f..070e362 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,14 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # Byte-compiled / optimized / DLL files __pycache__/ *.py[codz] *$py.class +.python-version # C extensions *.so @@ -206,5 +213,4 @@ marimo/_static/ marimo/_lsp/ __marimo__/ - -.DS_Store \ No newline at end of file +**/.DS_Store \ No newline at end of file diff --git a/.reuse/dep5.txt b/.reuse/dep5.txt new file mode 100644 index 0000000..6f0048b --- /dev/null +++ b/.reuse/dep5.txt @@ -0,0 +1,11 @@ +Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ + +Files: media/*.png +Copyright: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) +License: MIT +Comment: All files are part of the Daneshjou Lab projects. + +Files: results/*.json +Copyright: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) +License: MIT +Comment: All files are part of the Daneshjou Lab projects. \ No newline at end of file diff --git a/.yamllint b/.yamllint new file mode 100644 index 0000000..023e567 --- /dev/null +++ b/.yamllint @@ -0,0 +1,13 @@ +--- +extends: default + +rules: + truthy: + level: warning + allowed-values: ["false", "true", "on", "off"] + document-start: + level: warning + present: false + line-length: + max: 180 + level: warning \ No newline at end of file diff --git a/.yamllint.license b/.yamllint.license new file mode 100644 index 0000000..3184dcc --- /dev/null +++ b/.yamllint.license @@ -0,0 +1,5 @@ +# This source file is part of the ARPA-H CARE LLM project +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT \ No newline at end of file diff --git a/LICENSES/MIT.txt b/LICENSES/MIT.txt new file mode 100644 index 0000000..d817195 --- /dev/null +++ b/LICENSES/MIT.txt @@ -0,0 +1,18 @@ +MIT License + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the +following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial +portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT +LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO +EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md index a9969ef..b7278ae 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,14 @@ + + # Finetuning Pretrained Models for Compressed Dermatology Image Analysis This project explores how compressed and degraded dermatology images (from the ISIC 2019 dataset) affect classification performance using pretrained vision models. It compares fine-tuning vs. linear probing across multiple JPEG quality levels. -![System architecture diagram](<./CS231N Poster.png>) ## Project Goals @@ -30,6 +36,14 @@ This project explores how compressed and degraded dermatology images (from the I reduced-perception/ ├── configs/ │ └── example_config.yaml # Configs for job submissions +compressed-perception/ +├── README.md # Project overview & documentation +├── LICENSES/ # Directory containing license files (REUSE compliance) +│ └── MIT.txt # MIT license text +│ +├── pyproject.toml # Python packaging config +├── setup.py # Installation script for the package +├── setup.cfg # Configuration for setup tools │ ├── scripts/ # Lightweight utility or shell scripts │ ├── download_unpack_isic2019.sh # Downloads and unpacks ISIC data @@ -74,6 +88,101 @@ reduced-perception/ 3. View results We use weights and biases for logging, so output plots can be seen there +├── requirements.txt # Dependencies file +├── requirements.txt.license # Dependencies file license +├── .yamllint # YAML linter configuration +├── .yamllint.license # YAML linter configuration license +│ +├── .github/ # GitHub specific files +│ └── workflows/ # CI/CD workflow definitions +│ ├── build-and-test.yml +│ └── pull_request.yml +│ +├── .reuse/ # REUSE compliance configuration +│ └── dep5 # Copyright and license information +│ +├── docs/ # Documentation +│ └── pipeline.md # Pipeline documentation +│ +├── scripts/ # Standalone scripts +│ ├── ... +│ └── visualize_isic_results.py # Visualize metrics for model comparison (TODO) +│ +├── configs/ # Configuration files (TODO) +│ ├── datasets/ # Dataset configs +│ │ └── isic2019.yaml # ISIC 2019 dataset config +│ ├── models/ # Model configs +│ │ ├── vit.yaml # ViT model config +│ │ ├── dinov2.yaml # DINOv2 model config +│ │ └── simclr.yaml # SimCLR model config +│ ├── experiments/ # Experiment configs +│ │ ├── baseline.yaml # Baseline experiment +│ │ └── lr_sweep.yaml # Learning rate sweep experiment +│ └── example_config.yaml # Example configuration file +│ +│ +├── tests/ # Test suite (TODO) +│ ├── unit/ # Unit tests +│ │ └── test_transforms.py +│ ├── integration/ # Integration tests +│ │ └── test_pipeline.py +│ └── conftest.py # Test fixtures and configuration +│ +├── src/ +│ └── compressed_perception/ # Main package +│ ├── __init__.py # Package initialization +│ │ +│ ├── models/ # Model implementations +│ │ ├── __init__.py +│ │ ├── architectures/ # Model architecture definitions +│ │ │ ├── __init__.py +│ │ │ ├── vit.py # Vision Transformer adaptations +│ │ │ └── simclr.py # SimCLR adaptations +│ │ │ +│ │ ├── evaluation/ # Model evaluation code +│ │ │ ├── __init__.py +│ │ │ └── metrics.py # Evaluation metrics +│ │ │ +│ │ ├── comparison/ # Model comparison utilities +│ │ │ ├── __init__.py +│ │ │ ├── compare_baseline.py # Baseline comparison +│ │ │ └── compare_lr_sweep.py # Learning rate sweeping +│ │ │ +│ │ ├── training/ # Training infrastructure +│ │ │ ├── __init__.py +│ │ │ ├── trainers.py # Training loops +│ │ │ └── callbacks.py # Training callbacks +│ │ │ +│ │ └── utils/ # Model utilities +│ │ ├── __init__.py +│ │ ├── constants.py # Model constants +│ │ └── helpers.py # Helper functions +│ │ +│ ├── modules/ # Reusable modules +│ │ ├── __init__.py +│ │ ├── transforms/ # Image transformations +│ │ │ ├── __init__.py +│ │ │ ├── degradation.py # Image degradation transforms +│ │ │ └── augmentation.py # Data augmentation transforms +│ │ │ +│ │ └── data_preparation/ # Data preparation utilities +│ │ ├── __init__.py +│ │ └── preparation.py # Dataset preparation +│ │ +│ +│ +├── results/ +│ +├── jobs/ # Cluster job submission files +│ ├── job_template.slurm # SLURM job template +│ ├── run.sh # General run script +│ ├── rurun_compare_baseline.sh # Learning rate experiment script +│ ├── run_compare_lr_sweep.sh # Model comparison script +│ └── configs/ # Job configurations +│ +└── media/ # Media files for documentation + └── CS231N Poster.png # Project poster +``` ## 📦 Dataset diff --git a/docs/pipeline.md b/docs/pipeline.md new file mode 100644 index 0000000..da4efbf --- /dev/null +++ b/docs/pipeline.md @@ -0,0 +1,72 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +# Model Comparison Pipeline + +## Overview + +This script (`model_comparison_models.py`) provides a baseline for comparing different image classification models at various image compression levels, including the original images. It supports fine-tuning, linear probing, and optional image degradation transforms. + +--- + +## Features + +- **Model Support:** Vision Transformer (ViT), DINOv2, SimCLR (self-supervised backbone) +- **Data Augmentation:** Optional JPEG compression, Gaussian blur, and color quantization +- **Dataset Balancing:** Ensures equal samples per class for fair comparison +- **Training & Evaluation:** Uses Hugging Face Trainer for streamlined workflows +- **Experiment Tracking:** Integrated with Weights & Biases (`wandb`) +- **GPU Monitoring:** Optional support via `pynvml` + +--- + +## Workflow + +1. **Environment Setup** + - Loads required libraries and sets up cache directories. + - Checks for GPU availability. + +2. **Dataset Loading & Balancing** + - Loads ISIC_2019_224 dataset. + - Balances the dataset across filtered classes. + +3. **Model Initialization** + - Initializes model and preprocessor based on configuration. + +4. **Preprocessing & Augmentation** + - Applies resizing, normalization, and optional degradation transforms. + +5. **Training & Evaluation** + - Splits data into training and validation sets. + - Trains and evaluates each model, logging results to `wandb`. + +--- + +## Usage + +```bash +python [model_comparison_models.py](http://_vscodecontentref_/2) --resolution 224 --batch_size 256 --num_train_images 25000 --num_epochs 10 --eval_steps 10 +``` + +## Configuration +- Models: Edit the models list in the script to add or modify model configurations. +- Transforms: Toggle apply_transforms in prepare_datasets() to enable/disable augmentations. +- Hyperparameters: Adjust arguments in the main() function for batch size, epochs, etc. + +## Output +- Training and evaluation metrics are printed to the console and logged to Weights & Biases. +- Results can be used for further analysis or ablation studies. + +## Extending +- Add new models by updating the models list. +- Implement new transforms in utils/transforms.py. +- Add new datasets by modifying the dataset loading logic. + +## References +- Hugging Face Transformers +- PyTorch +- Weights & Biases +- ISIC 2019 Dataset \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 0acb2d9..c224c16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ +accelerate annotated-types anyio asttokens -accelerate beautifulsoup4 biopython certifi @@ -10,7 +10,7 @@ colorama comm contourpy cycler -datasets +datasets>=2.18.0 debugpy decorator distro @@ -21,6 +21,7 @@ fonttools h11 httpcore httpx +hydra-core==1.3.2 idna importlib_metadata ipykernel @@ -35,10 +36,11 @@ loguru matplotlib matplotlib-inline nest_asyncio -numpy +numpy>=1.26.0 +omegaconf==2.3.0 openai packaging -pandas +pandas>=2.2.0 parso pickleshare pillow @@ -46,34 +48,33 @@ platformdirs prompt_toolkit psutil pure_eval -pydantic -pydantic_core +pydantic>=2.6.0 +pydantic_core>=2.16.0 pynvml Pygments pyparsing python-dateutil python-dotenv requests -scikit-learn -scipy +scikit-learn>=1.3.0 +scipy>=1.11.0 seaborn setuptools six sniffio soupsieve stack-data -timm thop threadpoolctl -tornado -torch -torchvision -transformers -tqdm +timm +torch>=2.0.0 +torchvision>=0.15.0 +tqdm>=4.65.0 traitlets -typing-inspection -typing_extensions +transformers>=4.30.0 triton +typing-inspect +typing_extensions tzdata urllib3 wandb diff --git a/requirements.txt.licence b/requirements.txt.licence deleted file mode 100644 index 176ed16..0000000 --- a/requirements.txt.licence +++ /dev/null @@ -1 +0,0 @@ -# MIT \ No newline at end of file diff --git a/requirements.txt.license b/requirements.txt.license new file mode 100644 index 0000000..3cc951b --- /dev/null +++ b/requirements.txt.license @@ -0,0 +1,5 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT \ No newline at end of file diff --git a/scripts/test_transforms.py b/scripts/test_transforms.py deleted file mode 100644 index 933b2a3..0000000 --- a/scripts/test_transforms.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Minimal check: dataset loads and a degradation transform is applied to 5 images. -- Uses your BaseDataModule and get_degradation_transforms() -- No training, no W&B. -- Saves clean and degraded PNGs. -""" - -from types import SimpleNamespace -from pathlib import Path -import argparse - -from torchvision.utils import save_image -import torchvision.transforms.functional as TF -from PIL import Image - -from src.data.datamodule import BaseDataModule -from src.transformation.transforms import get_degradation_transforms -from src.utils import setup_environment - - -def main(): - ap = argparse.ArgumentParser(description="Quick dataset+transform smoke test (5 images).") - ap.add_argument("--dataset", type=str, required=True, choices=["isic", "tcga", "merlin"]) - ap.add_argument("--data_dir", type=str, required=True) - ap.add_argument("--out_dir", type=str, default="outputs/test_transforms") - ap.add_argument("--resolution", type=int, default=224) - ap.add_argument("--num_workers", type=int, default=2) - args = ap.parse_args() - - setup_environment() - - out_root = Path(args.out_dir) - out_root.mkdir(parents=True, exist_ok=True) - - # ---- minimal cfg your DataModule/get_dataset can read ---- - cfg = SimpleNamespace( - resolution=args.resolution, - batch_size=5, # just enough to grab 5 samples - ) - - # ---- build the DataModule exactly as your code expects ---- - dm = BaseDataModule( - cfg=cfg, - dataset_name=args.dataset, - data_dir=args.data_dir, - num_workers=args.num_workers, - batch_size=cfg.batch_size, - pin_memory=True, - drop_last=False, - ) - dm.setup(stage="fit") - - # ---- pull one small batch (5 images) from val ---- - val_loader = dm.val_dataloader() - batch = next(iter(val_loader)) - - if isinstance(batch, dict): - x = batch.get("pixel_values") or batch.get("images") or batch.get("x") - y = batch.get("labels") or batch.get("y") - else: - x, y = batch - - # keep exactly 5 - x = x[:5] - - for i in range(x.size(0)): - save_image(x[i], out_root / f"clean_{i}.png") - - degradations = get_degradation_transforms() - if not degradations: - print("No degradations returned by get_degradation_transforms(); saved only clean images.") - return - - degr = degradations[0] - tag = degr.__class__.__name__.lower() - print(f"Applying degradation: {tag}") - - for i in range(x.size(0)): - pil_img = TF.to_pil_image(x[i].cpu()) - pil_out = degr(pil_img) - t_out = TF.to_tensor(pil_out) - save_image(t_out, out_root / f"{tag}_{i}.png") - - print(f"✅ Done. Wrote images to: {out_root.resolve()}") - - -if __name__ == "__main__": - main() diff --git a/scripts/trial.py b/scripts/trial.py new file mode 100644 index 0000000..f967db0 --- /dev/null +++ b/scripts/trial.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: MIT +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# pylint: disable=all +""" +Trial: load ISIC (HF, 224×224), then save original vs. 112×112 downsampled. +""" + +import sys +from pathlib import Path + +current_dir = Path(__file__).resolve().parent +project_root = current_dir.parent +sys.path.insert(0, str(project_root)) + +from src.data.isic_loader import ISICBaseDataset # HF-backed +from src.transformation.transforms import ResolutionReductionTransform + +def main(): + out_dir = Path("outputs/trial_isic112") + out_dir.mkdir(parents=True, exist_ok=True) + + # IMPORTANT: no transform here, so we get true originals from HF (224×224) + ds = ISICBaseDataset( + repo_id="MKZuziak/ISIC_2019_224", + split="train", + transform=None, + ) + print(f"Loaded {len(ds)} samples") + + # Build the reduction transform to 112×112 + reduce112 = ResolutionReductionTransform( + target_resolution=(112, 112), + restore_original_size=False, + ) + + for i in range(5): + sample = ds[i] + img = sample["image"] + lbl = sample["label"] + + orig_path = out_dir / f"original_{i}.png" + img.save(orig_path) + + reduced = reduce112(img) + red_path = out_dir / f"reduced_{i}.png" + reduced.save(red_path) + + print(f"[{i}] label={lbl} | original={img.size} → reduced={reduced.size} | " + f"saved: {orig_path.name}, {red_path.name}") + + print(f"\n✅ Done. Check: {out_dir.resolve()}") + +if __name__ == "__main__": + main() diff --git a/src/__init__.py b/src/__init__.py index 1a66f1d..e69de29 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,3 +0,0 @@ -""" -Root package for the project. -""" diff --git a/src/cli/train.py b/src/cli/train.py index 5fe1460..1a9e5e5 100644 --- a/src/cli/train.py +++ b/src/cli/train.py @@ -1,4 +1,54 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# SPDX-License-Identifier: MIT + # src/cli/train.py +# -*- coding: utf-8 -*- +# pylint: disable=import-error, broad-exception-caught +""" +CLI entry point for training/evaluating models across paradigms (probe/finetune). +It normalizes dataset config → data config, builds a BaseDataModule using the +dataset factory, and dispatches to the selected training wrapper. + +Usage examples: + python -m src.cli.train train.mode=probe \ + dataset.name=isic2019 dataset.data_dir=/data/ISIC \ + dataset.image_size=224 dataset.batch_size=128 \ + model.type=vit model.model_id=google/vit-base-patch16-224 + + # With controlled degradation: downsample to 112px, then pipeline resizes to 224 + python -m src.cli.train train.mode=probe \ + dataset.name=isic2019 dataset.data_dir=/data/ISIC \ + dataset.image_size=224 dataset.batch_size=128 \ + dataset.degradation.target_resolution=112 \ + model.type=vit model.model_id=google/vit-base-patch16-224 + + # Random search (build-in basic sweep) + python -m src.cli.train -m \ + hydra.sweeper=basic \ + hydra.sweeper.n_trials=20 \ + hydra.sweeper.params='optim.lr=log(1e-5,1e-3); optim.weight_decay=uniform(0,0.1); \ +dataset.batch_size=choice(64,128)' \ + train.mode=probe \ + dataset.name=isic2019 dataset.data_dir=/data/ISIC \ + dataset.image_size=224 \ + model.type=vit model.model_id=google/vit-base-patch16-224 + + OR + + python -m src.cli.train -m \ + train.mode=probe \ + dataset.name=isic2019 dataset.data_dir=/data/ISIC \ + dataset.image_size=224 dataset.batch_size=128 \ + model.type=vit model.model_id=google/vit-base-patch16-224 \ + optim.lr=1e-5,3e-5,1e-4,3e-4 \ + optim.weight_decay=0.0,0.01 + +""" + +from __future__ import annotations + import os import sys import time @@ -6,38 +56,50 @@ import random import logging from pathlib import Path -from typing import Dict, Any +from typing import Dict, Any, Optional, Tuple, Union import torch import numpy as np -import hydra -from omegaconf import DictConfig, OmegaConf +import hydra # pylint: disable=import-error +from omegaconf import DictConfig, OmegaConf # pylint: disable=import-error -# ---- Optional: import wrappers (these implement run(cfg) and return a dict of metrics) -from src.wrappers import probe as probe_wrapper -from src.wrappers import finetune as finetune_wrapper -from src.wrappers import distill as distill_wrapper +# ---- Training wrappers (each provides run(cfg) -> dict of metrics) +from src.wrappers import probe as probe_wrapper # pylint: disable=import-error +from src.wrappers import finetune as finetune_wrapper # pylint: disable=import-error + +# ---- Data pipeline +from src.data.datamodule import BaseDataModule # pylint: disable=import-error +from src.transformations.transforms import ( + ResolutionReductionTransform, +) # pylint: disable=import-error + +# ---- Optional HF preprocessor (only needed when actually running a HF backbone) +try: + from transformers import AutoImageProcessor # noqa: F401 +except ImportError: # pragma: no cover + # If transformers is not installed, we continue without it + AutoImageProcessor = None # type: ignore log = logging.getLogger("train") +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + def _is_rank_zero() -> bool: - # Works for both torchrun (DDP) and single process local_rank = int(os.environ.get("LOCAL_RANK", "0")) return local_rank == 0 def _select_device(cfg: DictConfig) -> torch.device: - # Honor cfg.train.device if present; otherwise auto - if "device" in cfg.train and cfg.train.device: - dev = cfg.train.device - return torch.device(dev) - if torch.cuda.is_available(): - return torch.device("cuda") - return torch.device("cpu") + if "train" in cfg and "device" in cfg.train and cfg.train.device: + return torch.device(cfg.train.device) + return torch.device("cuda" if torch.cuda.is_available() else "cpu") -def _seed_everything(seed: int, deterministic: bool = False): +def _seed_everything(seed: int, deterministic: bool = False) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) @@ -49,22 +111,23 @@ def _seed_everything(seed: int, deterministic: bool = False): torch.backends.cudnn.benchmark = True -def _save_resolved_config(cfg: DictConfig, run_dir: Path): +def _save_resolved_config(cfg: DictConfig, run_dir: Path) -> None: if not _is_rank_zero(): return run_dir.mkdir(parents=True, exist_ok=True) - with open(run_dir / "resolved_config.yaml", "w") as f: + with open(run_dir / "resolved_config.yaml", "w", encoding="utf-8") as f: OmegaConf.save(config=cfg, f=f.name) -def _print_run_header(cfg: DictConfig, run_dir: Path, device: torch.device): +def _print_run_header(cfg: DictConfig, run_dir: Path, device: torch.device) -> None: if not _is_rank_zero(): return + model_name = getattr(cfg.model, "name", getattr(cfg.model, "type", "N/A")) banner = ( f"\n=== TRAIN START ===\n" f"mode : {cfg.train.mode}\n" f"dataset : {cfg.dataset.name}\n" - f"model : {getattr(cfg.model, 'name', 'N/A')}\n" + f"model : {model_name}\n" f"device : {device}\n" f"seed : {cfg.seed}\n" f"run_dir : {str(run_dir)}\n" @@ -74,33 +137,179 @@ def _print_run_header(cfg: DictConfig, run_dir: Path, device: torch.device): def _dispatch_wrapper(cfg: DictConfig) -> Dict[str, Any]: - mode = cfg.train.mode.lower() + mode = str(cfg.train.mode).lower() if mode == "probe": return probe_wrapper.run(cfg) - elif mode == "finetune": + if mode == "finetune": return finetune_wrapper.run(cfg) - elif mode == "distill": - return distill_wrapper.run(cfg) - else: - raise ValueError( - f"Unknown train.mode='{cfg.train.mode}'. " - f"Expected one of: probe | finetune | distill" + # if mode == "distill": + # return distill_wrapper.run(cfg) + raise ValueError( + f"Unknown train.mode='{cfg.train.mode}'. Expected one of: probe | finetune | distill" + ) + + +def _ensure_keys(d: DictConfig, keys_defaults: Dict[str, Any]) -> None: + """Ensure keys exist in DictConfig with defaults (in-place).""" + for k, v in keys_defaults.items(): + if k not in d or d[k] is None: + d[k] = v + + +def _normalize_dataset_into_data(cfg: DictConfig) -> None: + """ + Map cfg.dataset → cfg.data so wrappers and datamodules get a consistent schema. + + Expected cfg.dataset fields (typical): + - name: str (e.g., "isic2019", "chexpert") + - data_dir: str or null (some loaders use HF hub instead) + - image_size: int (model input size, e.g., 224) + - num_classes: int (optional; forwarded to model.config.num_labels) + - batch_size: int + - num_workers: int + - pin_memory: bool (optional) + - degradation: (optional group) + target_resolution: int | [w,h] + restore_original_size: bool + # or + reduction_factor: float in (0,1] + + After this, cfg.data contains: + dataset_name, data_dir, image_size, batch_size, num_workers, pin_memory + """ + if "dataset" not in cfg: + raise ValueError("Config is missing 'dataset' group (cfg.dataset.*).") + + ds = cfg.dataset + if "data" not in cfg or cfg.data is None: + cfg.data = OmegaConf.create() + + cfg.data.dataset_name = str(ds.get("name")) + cfg.data.data_dir = ds.get("data_dir") # can be None for HF datasets + cfg.data.image_size = int(ds.get("image_size", 224)) + cfg.data.batch_size = int( + ds.get("batch_size", getattr(cfg.train, "batch_size", 64)) + ) + cfg.data.num_workers = int(ds.get("num_workers", 4)) + cfg.data.pin_memory = bool(ds.get("pin_memory", True)) + + # Optional: propagate num_classes → model.config.num_labels + num_classes = ds.get("num_classes", None) + if num_classes is not None: + if "model" not in cfg: + cfg.model = OmegaConf.create() + if "config" not in cfg.model or cfg.model.config is None: + cfg.model.config = OmegaConf.create() + cfg.model.config.num_labels = int(num_classes) + + if not cfg.data.dataset_name: + raise ValueError("cfg.dataset.name must be set (e.g., 'isic2019').") + + +def _build_degradation_transform( + degr_cfg: DictConfig, +) -> Optional[ResolutionReductionTransform]: + """ + Build a ResolutionReductionTransform from a degradation config group. + Supports: + - target_resolution: int | [w,h] + - restore_original_size: bool + - reduction_factor: float + """ + if not degr_cfg: + return None + + # normalize target_resolution + target_res: Optional[Union[int, Tuple[int, int]]] = getattr( + degr_cfg, "target_resolution", None + ) + restore: bool = bool(getattr(degr_cfg, "restore_original_size", False)) + reduction_factor = getattr(degr_cfg, "reduction_factor", None) + + if target_res is not None: + if isinstance(target_res, int): + target_res = (int(target_res), int(target_res)) + else: + target_res = tuple(int(x) for x in target_res) + return ResolutionReductionTransform( + target_resolution=target_res, restore_original_size=restore + ) + + if reduction_factor is not None: + return ResolutionReductionTransform( + reduction_factor=float(reduction_factor), restore_original_size=restore ) + return None + + +def _build_datamodule(cfg: DictConfig) -> BaseDataModule: + """ + Construct the BaseDataModule using normalized cfg.data and optional + degradation transforms. Keeps HF preprocessor optional. + """ + # Optional degradation transform + degr_cfg = getattr(cfg.dataset, "degradation", None) + transform = _build_degradation_transform(degr_cfg) if degr_cfg else None + + # Optional HF preprocessor (only if actually running a HF backbone) + preproc = None + model_id = getattr(cfg.model, "model_id", None) + if model_id and AutoImageProcessor is not None: + try: + preproc = AutoImageProcessor.from_pretrained(model_id) + except Exception: + preproc = None # safe no-op path + + image_size = int(getattr(cfg.data, "image_size", 224)) + batch_size = int(getattr(cfg.data, "batch_size", 64)) + num_workers = int(getattr(cfg.data, "num_workers", 4)) + pin_memory = bool(getattr(cfg.data, "pin_memory", True)) + model_type = str(getattr(cfg.model, "type", "vit")) + + dm = BaseDataModule( + cfg=cfg, + dataset_name=str(cfg.data.dataset_name), + data_dir=getattr(cfg.data, "data_dir", None), + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + preprocessor=preproc, # None -> ModelPreprocessor no-op + resolution=image_size, # model input size (e.g., 224) + transform=transform, # may be None + model_type=model_type, + ) + + # Expose to wrappers + cfg.runtime = dict(getattr(cfg, "runtime", {})) + cfg.runtime["datamodule"] = dm + return dm + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + @hydra.main(config_path="../../configs", config_name="defaults", version_base=None) def main(cfg: DictConfig): - # ---- Resolve and freeze config for safety - OmegaConf.set_struct(cfg, False) # allow wrappers to attach runtime fields if needed + """Main training CLI entry point.""" + OmegaConf.set_struct(cfg, False) + + # Normalize dataset selection into cfg.data for wrappers/datamodules + _normalize_dataset_into_data(cfg) - # ---- Set up run directory (Hydra changes CWD to a unique run dir automatically) + # Set up run directory (Hydra sets CWD to the unique run dir) run_dir = Path(os.getcwd()) - # ---- Device & seeding + # Device & seeding device = _select_device(cfg) - _seed_everything(seed=int(cfg.seed), deterministic=getattr(cfg.train, "deterministic", False)) + _seed_everything( + seed=int(cfg.seed), + deterministic=bool(getattr(cfg.train, "deterministic", False)), + ) - # ---- Optional: attach runtime info for wrappers/engines + # Attach runtime info used by wrappers/engines cfg.runtime = { "device": str(device), "start_time": time.strftime("%Y-%m-%d %H:%M:%S"), @@ -109,11 +318,15 @@ def main(cfg: DictConfig): "world_size": int(os.environ.get("WORLD_SIZE", "1")), } - # ---- Save resolved config + # Build & setup datamodule using dataset factory logic + dm = _build_datamodule(cfg) + dm.setup(stage="fit") # prepares train/val (or splits train if no val split) + + # Persist resolved config and print header _save_resolved_config(cfg, run_dir) _print_run_header(cfg, run_dir, device) - # ---- Kick off the selected training paradigm via wrapper + # Kick off the selected training paradigm try: metrics = _dispatch_wrapper(cfg) except KeyboardInterrupt: @@ -121,19 +334,20 @@ def main(cfg: DictConfig): print("\n⚠️ Training interrupted by user.", flush=True) raise except Exception as e: - # Surface a readable error at rank zero; still propagate for proper exit codes if _is_rank_zero(): print(f"\n❌ Training failed: {e}\n", flush=True) raise - # ---- Persist final metrics + # Save final metrics if _is_rank_zero(): metrics = metrics or {} - with open(run_dir / "final_metrics.json", "w") as f: + with open(run_dir / "final_metrics.json", "w", encoding="utf-8") as f: json.dump(metrics, f, indent=2) - print("✅ Train done. Final metrics written to final_metrics.json\n", flush=True) + print( + "✅ Train done. Final metrics written to final_metrics.json\n", flush=True + ) if __name__ == "__main__": - # Support `python -m src.cli.train train=probe dataset=isic2019` + # pylint: disable=no-value-for-parameter sys.exit(main()) diff --git a/src/config.py b/src/config.py index 88abf6f..00c50c6 100644 --- a/src/config.py +++ b/src/config.py @@ -1,6 +1,20 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + """Configuration and constants.""" from dataclasses import dataclass -from typing import List, Dict, Any +from typing import Dict, Any + +# Optional import - used for hardware detection +try: + import torch # pylint: disable=import-error + CUDA_AVAILABLE = torch.cuda.is_available() +except ImportError: + # If torch is not available, assume no CUDA + CUDA_AVAILABLE = False # Model constants HF_MODELS = ["vit", "dinov2"] @@ -17,9 +31,11 @@ "std": [0.229, 0.224, 0.225], } + @dataclass -class TrainingConfig: +class TrainingConfig: # pylint: disable=too-many-instance-attributes """Training configuration.""" + num_train_images: int = 100 proportion_per_transform: float = 0.2 resolution: int = 224 @@ -28,37 +44,31 @@ class TrainingConfig: eval_steps: int = 10 learning_rate: float = 1e-4 weight_decay: float = 0.01 - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for logging.""" return self.__dict__ - + def to_wandb_config(self) -> Dict[str, Any]: """Create wandb configuration.""" - import torch return { **self.to_dict(), - "gpu_available": torch.cuda.is_available(), + "gpu_available": CUDA_AVAILABLE, } + # Model configurations MODEL_REGISTRY = [ { "name": "vit", "model_id": "google/vit-base-patch16-224", "type": "vit", - "config": { - "num_labels": NUM_FILTERED_CLASSES, - "ignore_mismatched_sizes": True - } + "config": {"num_labels": NUM_FILTERED_CLASSES, "ignore_mismatched_sizes": True}, }, { "name": "dinov2", "model_id": "facebook/dinov2-base", "type": "dinov2", - "config": { - "num_labels": NUM_FILTERED_CLASSES, - "ignore_mismatched_sizes": True - } + "config": {"num_labels": NUM_FILTERED_CLASSES, "ignore_mismatched_sizes": True}, }, -] \ No newline at end of file +] diff --git a/src/engines/distill_engine.py b/src/data/__init__.py similarity index 100% rename from src/engines/distill_engine.py rename to src/data/__init__.py diff --git a/src/data/data_utils.py b/src/data/data_utils.py deleted file mode 100644 index 351e018..0000000 --- a/src/data/data_utils.py +++ /dev/null @@ -1,200 +0,0 @@ -"""Dataset implementations and data utilities.""" -import numpy as np -import torch -from PIL import Image -from torch.utils.data import Dataset, ConcatDataset, Subset -from torchvision import transforms -from typing import Optional, List, Dict, Any, Union - -from src.config import HF_MODELS, DEFAULT_IMAGE_SIZE -from src.transformation.transforms import ResolutionReductionTransform - -class ISICDataset(Dataset): - """ISIC dataset with support for multiple model types and transformations.""" - - def __init__( - self, - dataset: Union[Dataset, Subset], - preprocessor: Optional[Any] = None, - resolution: int = DEFAULT_IMAGE_SIZE, - transform: Optional[transforms.Compose] = None, - model_type: str = "vit", - jpeg_quality: Optional[int] = None, - ): - self.dataset = dataset - self.preprocessor = preprocessor - self.resolution = resolution - self.transform = transform - self.model_type = model_type - self.jpeg_quality = jpeg_quality - - # Create base preprocessing - self.base_preprocessor = transforms.Compose([ - transforms.Resize((resolution, resolution), Image.LANCZOS), - transforms.ToTensor(), - ]) - - self.model_preprocessor = None - - def __len__(self) -> int: - return len(self.dataset) - - def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: - # Convert numpy types to Python int - if isinstance(idx, (np.integer, np.int64)): - idx = int(idx) - - # Handle both direct dataset and Subset access - if hasattr(self.dataset, 'dataset'): - # This is a Subset - subset_idx = int(self.dataset.indices[idx]) - item = self.dataset.dataset[subset_idx] - else: - # Direct dataset access - item = self.dataset[idx] - - image = item["image"] - label = item["label"] - - # Resize to target resolution - image = image.resize((self.resolution, self.resolution), Image.LANCZOS) - - # Apply optional transformations - if self.transform: - image = self.transform(image) - - if self.jpeg_quality is not None: - image = JPEGCompressionTransform(self.jpeg_quality)(image) - - # Apply model-specific preprocessing - if self.model_type in HF_MODELS: - # For HuggingFace models - if hasattr(self.preprocessor, 'size'): - self.preprocessor.size = self.resolution - encoding = self.preprocessor(images=image, return_tensors="pt") - pixel_values = encoding["pixel_values"].squeeze(0) - else: - raise ValueError(f"Unsupported model_type: {self.model_type}") - - label = torch.tensor(label, dtype=torch.long) - return {"pixel_values": pixel_values, "labels": label} - -def create_transformed_datasets( - train_dataset: Dataset, - val_dataset: Dataset, - transforms_list: List, - proportion_per_transform: float, - preprocessor: Optional[Any], - resolution: int, - model_type: str -) -> tuple[Dataset, Dataset]: - """ - Create train and validation datasets with transformations. - - Args: - train_dataset: Training dataset - val_dataset: Validation dataset - transforms_list: List of transforms to apply - proportion_per_transform: Proportion of data for each transform - preprocessor: Model preprocessor - resolution: Image resolution - model_type: Type of model - - Returns: - Tuple of (train_dataset, val_dataset) - """ - num_images = len(train_dataset) - images_per_transform = int(num_images * proportion_per_transform) - - # Shuffle indices - indices = np.arange(num_images) - np.random.shuffle(indices) - - transformed_datasets = [] - used_indices = [] - - # Apply each transform to a subset - for i, transform in enumerate(transforms_list): - start_idx = i * images_per_transform - end_idx = start_idx + images_per_transform - subset_indices = indices[start_idx:end_idx] - used_indices.extend(subset_indices) - - subset = Subset(train_dataset, subset_indices) - transform_compose = transforms.Compose([transform]) - - transformed_ds = ISICDataset( - subset, - preprocessor, - resolution, - transform_compose, - model_type - ) - transformed_datasets.append(transformed_ds) - - # Add remaining samples without transformation - remaining_indices = np.setdiff1d(indices, used_indices) - if len(remaining_indices) > 0: - remaining_subset = Subset(train_dataset, remaining_indices) - untransformed_ds = ISICDataset( - remaining_subset, - preprocessor, - resolution, - None, - model_type - ) - transformed_datasets.append(untransformed_ds) - - # Combine all training datasets - train_ds = ConcatDataset(transformed_datasets) - - # Create validation dataset (no transformations) - val_ds = ISICDataset( - val_dataset, - preprocessor, - resolution, - model_type=model_type, - ) - - return train_ds, val_ds - -def balance_dataset(dataset: Dataset, filtered_classes: List[str], num_train_images: int, seed: int = 42): - """ - Balance dataset by sampling equal numbers from each class. - - Args: - dataset: Input dataset - filtered_classes: Classes to keep - num_train_images: Total number of training images - seed: Random seed - - Returns: - Balanced dataset - """ - # Get class counts - class_counts = {label: 0 for label in filtered_classes} - class_indices = {label: [] for label in filtered_classes} - - for i, item in enumerate(dataset): - label_str = str(item["label"]) - if label_str in filtered_classes: - class_counts[label_str] += 1 - class_indices[label_str].append(i) - - print(f"Class counts before balancing: {class_counts}") - - # Calculate samples per class - min_class_size = min(class_counts.values()) - images_per_class = min(num_train_images // len(filtered_classes), min_class_size) - - # Sample from each class - np.random.seed(seed) - balanced_indices = [] - - for label in filtered_classes: - indices = class_indices[label] - sampled = np.random.choice(indices, images_per_class, replace=False) - balanced_indices.extend(sampled) - - np.random.shuffle(balanced_indices) - return dataset.select(balanced_indices) \ No newline at end of file diff --git a/src/data/datamodule.py b/src/data/datamodule.py index b2e9a02..ddf9dcb 100644 --- a/src/data/datamodule.py +++ b/src/data/datamodule.py @@ -1,30 +1,67 @@ -# src/data/datamodule.py -# -*- coding: utf-8 -*- -from torch.utils.data import DataLoader, random_split -from typing import Optional -from .datasets import get_dataset +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT +""" +DataModule for managing dataset loading, splitting, and dataloader creation. +Provides a unified interface for working with different datasets. +""" + +# Standard library imports +from typing import Optional, Any + +# Third-party imports +import torch # pylint: disable=import-error +from torch.utils.data import DataLoader, random_split, Subset # pylint: disable=import-error + +# Local imports +# pylint: disable=import-error +from src.data.dataset_factory import get_dataset class BaseDataModule: """ Dataset-agnostic data module. - Responsibilities: - - Instantiate train/val/test datasets using `get_dataset()`. - - Build and cache dataloaders. - - Apply consistent transforms and batching options across datasets. + - If get_dataset(dataset_name, split=...) exists, we use provided splits. + - Otherwise we create train/val via random_split from a single 'train' split. """ + # pylint: disable=too-many-instance-attributes def __init__( self, - cfg, + cfg: Any, dataset_name: str, data_dir: str, + *, num_workers: int = 8, batch_size: int = 32, pin_memory: bool = True, drop_last: bool = False, - ): + split_seed: int = 42, + preprocessor: Any = None, + resolution: int = 224, + transform: Any = None, + model_type: str = "vit", + ): # pylint: disable=too-many-arguments + """ + Initialize the DataModule with dataset and dataloader parameters. + + Args: + cfg: Configuration object + dataset_name: Name of the dataset to load + data_dir: Directory where dataset is stored + num_workers: Number of worker processes for data loading + batch_size: Number of samples per batch + pin_memory: Whether to pin memory for faster GPU transfer + drop_last: Whether to drop the last incomplete batch + split_seed: Random seed for train/val splitting + preprocessor: Data preprocessing pipeline + resolution: Image resolution to use + transform: Data augmentation transforms + model_type: Type of model (e.g., 'vit', 'cnn') + """ self.cfg = cfg self.dataset_name = dataset_name self.data_dir = data_dir @@ -32,51 +69,103 @@ def __init__( self.batch_size = batch_size self.pin_memory = pin_memory self.drop_last = drop_last + self.split_seed = split_seed + + # Plumb-through for dataset construction + self.preprocessor = preprocessor + self.resolution = resolution + self.transform = transform + self.model_type = model_type self.train_set = None self.val_set = None self.test_set = None # ------------------------------------------------------------------ - def setup(self, stage: Optional[str] = None): + def setup(self, _stage: Optional[str] = None): """ - Called once to initialize datasets. - stage: 'fit' | 'validate' | 'test' | None + Initialize datasets. _stage: 'fit' | 'validate' | 'test' | None + Tries explicit splits first; falls back to random_split from a 'train' split. + + Args: + _stage: Current stage of training pipeline (unused but kept for API compatibility) """ - ds_full = get_dataset(self.dataset_name, self.data_dir, split="train", cfg=self.cfg) - n_total = len(ds_full) - n_val = int(0.1 * n_total) - n_train = n_total - n_val - self.train_set, self.val_set = random_split(ds_full, [n_train, n_val]) + # Try to fetch explicit train/val + ds_train = get_dataset( + self.dataset_name, self.data_dir, split="train", cfg=self.cfg, + preprocessor=self.preprocessor, resolution=self.resolution, + transform=self.transform, model_type=self.model_type, + ) + + try: + ds_val = get_dataset( + self.dataset_name, self.data_dir, split="val", cfg=self.cfg, + preprocessor=self.preprocessor, resolution=self.resolution, + transform=None, model_type=self.model_type, + ) + except Exception as e: # pylint: disable=broad-exception-caught, unused-variable + # Exception is broad to handle any dataset-specific errors + ds_val = None + + try: + ds_test = get_dataset( + self.dataset_name, self.data_dir, split="test", cfg=self.cfg, + preprocessor=self.preprocessor, resolution=self.resolution, + transform=None, model_type=self.model_type, + ) + except Exception as e: # pylint: disable=broad-exception-caught, unused-variable + # Exception is broad to handle any dataset-specific errors + ds_test = None - # test set - self.test_set = get_dataset(self.dataset_name, self.data_dir, split="test", cfg=self.cfg) + if ds_val is None: + # Fallback: split train into train/val + n_total = len(ds_train) + n_val = max(1, int(0.1 * n_total)) + n_train = n_total - n_val + g = torch.Generator().manual_seed(self.split_seed) + self.train_set, self.val_set = random_split(ds_train, [n_train, n_val], generator=g) + else: + self.train_set, self.val_set = ds_train, ds_val + + self.test_set = ds_test # ------------------------------------------------------------------ def train_dataloader(self): + """ + Create and return the training data loader. + + Returns: + DataLoader: Configured training data loader + """ return DataLoader( - self.train_set, - batch_size=self.batch_size, - shuffle=True, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - drop_last=self.drop_last, + self.train_set, batch_size=self.batch_size, shuffle=True, + num_workers=self.num_workers, pin_memory=self.pin_memory, drop_last=self.drop_last, ) def val_dataloader(self): + """ + Create and return the validation data loader. + + Returns: + DataLoader: Configured validation data loader + """ return DataLoader( - self.val_set, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - pin_memory=self.pin_memory, + self.val_set, batch_size=self.batch_size, shuffle=False, + num_workers=self.num_workers, pin_memory=self.pin_memory, ) def test_dataloader(self): + """ + Create and return the test data loader. + If no test set exists, returns an empty loader. + + Returns: + DataLoader: Configured test data loader + """ + if self.test_set is None: + # Provide an empty loader if test isn't defined + return DataLoader(Subset(self.val_set, []), batch_size=self.batch_size) return DataLoader( - self.test_set, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - pin_memory=self.pin_memory, + self.test_set, batch_size=self.batch_size, shuffle=False, + num_workers=self.num_workers, pin_memory=self.pin_memory, ) diff --git a/src/data/dataset_factory.py b/src/data/dataset_factory.py new file mode 100644 index 0000000..5101b21 --- /dev/null +++ b/src/data/dataset_factory.py @@ -0,0 +1,405 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +"""Dataset implementations and data utilities.""" +# Standard library imports +from typing import Optional, List, Dict, Any, Union + +# Third-party imports +import numpy as np # pylint: disable=import-error +import torch # pylint: disable=import-error +from PIL import Image # pylint: disable=import-error +from torch.utils.data import Dataset, ConcatDataset, Subset # pylint: disable=import-error + +# Local imports +# pylint: disable=import-error,relative-beyond-top-level +from src.config import HF_MODELS, DEFAULT_IMAGE_SIZE +from src.data.isic_loader import ISICBaseDataset + +# ============================================================================ +# DATA LOADING +# ============================================================================ + +class DatasetWrapper(Dataset): + """Simple wrapper that handles dataset/subset access patterns.""" + + def __init__(self, dataset: Union[Dataset, Subset]): + self.dataset = dataset + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + """Get item from dataset, satisfying Dataset abstract method.""" + return self.get_raw_item(idx) + + def get_raw_item(self, idx: int) -> Dict[str, Any]: + """Get raw item from dataset, handling both Dataset and Subset.""" + # Convert numpy types to Python int + if isinstance(idx, (np.integer, np.int64)): + idx = int(idx) + + # Handle both direct dataset and Subset access + if hasattr(self.dataset, 'dataset'): + # This is a Subset + subset_idx = int(self.dataset.indices[idx]) + return self.dataset.dataset[subset_idx] + # Direct dataset access + return self.dataset[idx] + + +# ============================================================================ +# IMAGE PREPROCESSING +# ============================================================================ + +class ImageProcessor: + """Handles image preprocessing operations.""" + + def __init__(self, resolution: int = DEFAULT_IMAGE_SIZE): + self.resolution = resolution + + def resize_image(self, image: Image.Image) -> Image.Image: + """Resize image to target resolution.""" + return image.resize( + (self.resolution, self.resolution), + Image.Resampling.LANCZOS + ) + + def apply_transforms(self, image: Image.Image, transform: Optional[Any]) -> Image.Image: + """Apply optional transformations to image.""" + if transform: + return transform(image) + return image + + +# ============================================================================ +# MODEL-SPECIFIC PREPROCESSING +# ============================================================================ + +class ModelPreprocessor: + """Handles model-specific preprocessing.""" + # pylint: disable=too-few-public-methods + + def __init__(self, preprocessor: Optional[Any] = None, model_type: str = "vit"): + self.preprocessor = preprocessor + self.model_type = model_type + + if self.preprocessor is None: + # Will no-op and return a tensorized fallback later + return + + if self.model_type not in HF_MODELS: + raise ValueError(f"Unsupported model_type: {self.model_type}") + + def preprocess(self, image: Image.Image, resolution: int) -> torch.Tensor: + """Apply model-specific preprocessing; no-op if preprocessor is None.""" + if self.preprocessor is None: + return self._convert_pil_to_tensor(image) + + # HuggingFace path + if hasattr(self.preprocessor, "size") and self.preprocessor.size != resolution: + try: + self.preprocessor.size = resolution + except (AttributeError, TypeError) as e: # pylint: disable=unused-variable + # Some processors use tuples or different APIs for size + pass + + encoding = self.preprocessor(images=image, return_tensors="pt") + return encoding["pixel_values"].squeeze(0) + + def _convert_pil_to_tensor(self, image: Image.Image) -> torch.Tensor: + """Convert PIL image to tensor with proper format.""" + # Minimal safety: convert PIL -> tensor [C,H,W] in [0,1] + arr = np.asarray(image).astype(np.float32) / 255.0 + if arr.ndim == 2: # grayscale -> [H,W] -> [H,W,1] + arr = arr[:, :, None] + # HW(C) -> CHW + arr = np.transpose(arr, (2, 0, 1)) + return torch.from_numpy(arr) + + +# ============================================================================ +# COMBINED DATASET +# ============================================================================ + +class ISICDataset(Dataset): + """ISIC dataset that combines data loading, image processing, and model preprocessing.""" + + def __init__( + self, + dataset: Union[Dataset, Subset], + *, # Force keyword arguments after first positional arg + preprocessor: Optional[Any] = None, + resolution: int = DEFAULT_IMAGE_SIZE, + transform: Optional[Any] = None, + model_type: str = "vit", + ): # pylint: disable=too-many-arguments + # Data loading + self.data_wrapper = DatasetWrapper(dataset) + + # Image processing + self.image_processor = ImageProcessor(resolution) + + # Model preprocessing + self.model_preprocessor = ModelPreprocessor(preprocessor, model_type) + + # Transformation + self.transform = transform + + def __len__(self) -> int: + return len(self.data_wrapper) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + # 1. Load raw data + item = self.data_wrapper.get_raw_item(idx) + image = item["image"] + label = item["label"] + + # 2. Process image + image = self.image_processor.resize_image(image) + image = self.image_processor.apply_transforms(image, self.transform) + + # 3. Apply model-specific preprocessing + pixel_values = self.model_preprocessor.preprocess( + image, self.image_processor.resolution + ) + + # 4. Prepare output + label = torch.tensor(label, dtype=torch.long) + return {"pixel_values": pixel_values, "labels": label} + +# ============================================================================ +# DATASET CREATION UTILITIES +# ============================================================================ + +def split_dataset_for_transforms( + dataset: Dataset, + transforms_list: List[Any], + proportion_per_transform: float +) -> List[Subset]: + """Split dataset into subsets for applying different transforms.""" + num_images = len(dataset) + images_per_transform = int(num_images * proportion_per_transform) + + # Shuffle indices + indices = np.arange(num_images) + np.random.shuffle(indices) + + subsets = [] + used_indices = [] + + # Create subset for each transform + for i, _ in enumerate(transforms_list): + start_idx = i * images_per_transform + end_idx = start_idx + images_per_transform + subset_indices = indices[start_idx:end_idx] + used_indices.extend(subset_indices) + subsets.append(Subset(dataset, subset_indices)) + + # Add remaining samples + remaining_indices = np.setdiff1d(indices, used_indices) + if len(remaining_indices) > 0: + subsets.append(Subset(dataset, remaining_indices)) + + return subsets + + +def create_transformed_datasets( + train_dataset: Dataset, + val_dataset: Dataset, + transforms_list: List[Any], + proportion_per_transform: float, + *, # Force keyword arguments + preprocessor: Optional[Any] = None, + resolution: int = DEFAULT_IMAGE_SIZE, + model_type: str = "vit" +) -> tuple[Dataset, Dataset]: + """Create train and validation datasets with transformations.""" + # pylint: disable=too-many-arguments,too-many-locals + + # Split training data into subsets + train_subsets = split_dataset_for_transforms( + train_dataset, transforms_list, proportion_per_transform + ) + + # Create datasets with transforms + transformed_datasets = [] + + # Apply each transform to corresponding subset + for _, (subset, transform) in enumerate(zip(train_subsets[:-1], transforms_list)): + transformed_ds = ISICDataset( + subset, + preprocessor=preprocessor, + resolution=resolution, + transform=transform, + model_type=model_type + ) + transformed_datasets.append(transformed_ds) + + # Add untransformed subset (if any remaining) + if len(train_subsets) > len(transforms_list): + untransformed_ds = ISICDataset( + train_subsets[-1], + preprocessor=preprocessor, + resolution=resolution, + transform=None, + model_type=model_type + ) + transformed_datasets.append(untransformed_ds) + + # Combine all training datasets + train_ds = ConcatDataset(transformed_datasets) + + # Create validation dataset (no transformations) + val_ds = ISICDataset( + val_dataset, + preprocessor=preprocessor, + resolution=resolution, + transform=None, + model_type=model_type + ) + + return train_ds, val_ds + +# ============================================================================ +# DATASET BALANCING +# ============================================================================ + +def get_class_distribution(dataset: Dataset, filtered_classes: List[str]) -> Dict[str, List[int]]: + """Get class distribution and indices.""" + class_counts = {class_label: 0 for class_label in filtered_classes} + class_indices = {class_label: [] for class_label in filtered_classes} + + for i, item in enumerate(dataset): + label_str = str(item["label"]) + if label_str in filtered_classes: + class_counts[label_str] += 1 + class_indices[label_str].append(i) + + print(f"Class counts before balancing: {class_counts}") + return class_indices + + +def sample_balanced_indices( + class_indices: Dict[str, List[int]], + num_train_images: int, + seed: int = 42 +) -> List[int]: + """Sample balanced indices from each class.""" + np.random.seed(seed) + + # Calculate samples per class + num_classes = len(class_indices) + min_class_size = min(len(indices) for indices in class_indices.values()) + images_per_class = min(num_train_images // num_classes, min_class_size) + + print(f"Sampling {images_per_class} images per class") + + # Sample from each class + balanced_indices = [] + for _label, indices in class_indices.items(): + sampled = np.random.choice(indices, images_per_class, replace=False) + balanced_indices.extend(sampled) + + np.random.shuffle(balanced_indices) + return balanced_indices + + +def balance_dataset( + dataset: Dataset, + filtered_classes: List[str], + num_train_images: int, + seed: int = 42 +) -> Dataset: + """ + Balance dataset by sampling equal numbers from each class. + - If `dataset` is a PyTorch Dataset: returns a torch.utils.data.Subset + - If `dataset` is an HF Dataset (has `.select`): returns a selected HF Dataset + """ + class_indices = get_class_distribution(dataset, filtered_classes) + balanced_indices = sample_balanced_indices(class_indices, num_train_images, seed) + + # HF Dataset has .select; PyTorch does not + if hasattr(dataset, "select"): + return dataset.select(balanced_indices) + return Subset(dataset, balanced_indices) + +def _load_isic_split(data_dir: str, split: str) -> Dataset: + """ + Replace with your real ISIC split loader that yields dicts: + {"image": PIL.Image, "label": int} + + Expected interface: + class ISICRawSplit(Dataset): + def __init__(self, data_dir: str, split: str): ... + def __getitem__(self, i) -> {"image": PIL.Image, "label": int} + def __len__(self) -> int: ... + + If you already have it elsewhere, just import and return it here. + """ + # pylint: disable=import-outside-toplevel,relative-beyond-top-level,import-error + + try: + from src.data.isic_raw import ISICRawSplit # Use absolute import + except ImportError as e: + raise ImportError( + "ISICRawSplit not found. Create src/data/isic_raw.py with an ISICRawSplit " + "that returns {'image': PIL.Image, 'label': int}." + ) from e + + return ISICRawSplit(data_dir, split) + + +def get_dataset( + dataset_name: str, + data_dir: str, + split: str, + cfg=None, # Kept for API compatibility + *, + preprocessor=None, + resolution: int = DEFAULT_IMAGE_SIZE, + transform=None, + model_type: str = "vit", + mode: str = "model_ready", # "raw" or "model_ready" +) -> Dataset: + """ + Unified dataset factory used by BaseDataModule. + + Args: + dataset_name: e.g., "isic2019" + data_dir: root directory for the dataset + split: "train" | "val" | "test" (if "val" doesn't exist, DataModule can split) + cfg: optional config (unused but kept for API compatibility) + preprocessor: HF image processor or similar (used by model_ready) + resolution: model input resolution (used by model_ready) + transform: optional PIL->PIL degradation transform (applied before preprocess) + model_type: str key you use in HF_MODELS + mode: "raw" (no transforms/preprocessing) or "model_ready" (pipeline) + + Returns: + A torch.utils.data.Dataset + """ + # pylint: disable=too-many-arguments,unused-argument + + name = dataset_name.lower() + + if name in {"isic", "isic2019", "dermatology"}: + base_ds = _load_isic_split(data_dir, split) + + if mode == "raw": + # No resize/degradation or model preprocessing + return ISICBaseDataset(base_ds) + + # model_ready: resize + optional degradation + HF preprocessing + return ISICDataset( + dataset=base_ds, + preprocessor=preprocessor, + resolution=resolution, + transform=transform, + model_type=model_type, + ) + + raise ValueError(f"Unknown dataset_name: {dataset_name}") diff --git a/src/data/datasets.py b/src/data/datasets.py deleted file mode 100644 index 8568949..0000000 --- a/src/data/datasets.py +++ /dev/null @@ -1,14 +0,0 @@ -# src/data/datasets.py -from src.data.isic_dataset import ISICDataset -from src.data.tcga_dataset import TCGADataset -from src.data.merlin_dataset import MerlinDataset - -def get_dataset(dataset_name, data_dir, split, cfg): - if dataset_name.lower() == "isic": - return ISICDataset(data_dir=data_dir, split=split, cfg=cfg) - elif dataset_name.lower() == "tcga": - return TCGADataset(data_dir=data_dir, split=split, cfg=cfg) - elif dataset_name.lower() == "merlin": - return MerlinDataset(data_dir=data_dir, split=split, cfg=cfg) - else: - raise ValueError(f"Unknown dataset: {dataset_name}") diff --git a/src/data/isic_loader.py b/src/data/isic_loader.py new file mode 100644 index 0000000..7d0b7f6 --- /dev/null +++ b/src/data/isic_loader.py @@ -0,0 +1,170 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +"""ISIC dataset loader implementation for dermatology image datasets. + +This module provides a Hugging Face–backed loader for ISIC that returns +dicts shaped like: + {"image": PIL.Image, "label": int} + +Typical usage (HF-only, no DataModule required): + from src.data.isic_loader import ISICHFRawSplit + ds = ISICHFRawSplit(repo_id="MKZuziak/ISIC_2019_224", split="train") + sample = ds[0] + image, label = sample["image"], sample["label"] + +You may pass a PIL->PIL transform (e.g., ResolutionReductionTransform) via `transform=...`. +""" + +from __future__ import annotations + +from typing import Any, Dict, Optional, Sequence +from torch.utils.data import Dataset # pylint: disable=import-error +from PIL import Image + +try: + # Hugging Face datasets is required for this loader + from datasets import load_dataset, Image as HFImageFeature # type: ignore +except Exception as _e: # pragma: no cover + load_dataset = None + HFImageFeature = None + _HF_IMPORT_ERROR = _e + + +def _to_pil(x: Any) -> Image.Image: + """Best-effort conversion to PIL.Image.""" + if isinstance(x, Image.Image): + return x + try: + # HF Image feature usually yields PIL already; if not, try convert + return x.convert("RGB") + except Exception: + import numpy as np + return Image.fromarray(np.asarray(x)).convert("RGB") + + +class ISICHFRawSplit(Dataset): + """ + Hugging Face–backed ISIC split. Returns dicts: + {"image": PIL.Image, "label": int} + + Parameters + ---------- + repo_id : str + HF dataset repo id. Default: "MKZuziak/ISIC_2019_224" + split : str + Split name to load (e.g., "train"). Depends on the repo. + cache_dir : Optional[str] + Local HF cache directory. + image_column : str + Column name for image. Default: "image" + label_column : str + Column name for label. Default: "label" + transform : Optional[callable] + Optional PIL->PIL transform (e.g., ResolutionReductionTransform). + filter_fn : Optional[callable] + Optional row-level filter: receives an item dict and returns True/False. + If provided, the dataset builds an index of rows where filter_fn(item) is True. + keep_indices : Optional[Sequence[int]] + Optional explicit list of indices to keep (applied after filter_fn if both are given). + + Notes + ----- + - Model-specific preprocessing (tensor conversion, normalization) is intentionally not applied here. + That belongs in your model-ready pipeline/DataModule. + """ + + def __init__( + self, + *, + repo_id: str = "MKZuziak/ISIC_2019_224", + split: str = "train", + cache_dir: Optional[str] = None, + image_column: str = "image", + label_column: str = "label", + transform: Optional[Any] = None, + filter_fn: Optional[Any] = None, + keep_indices: Optional[Sequence[int]] = None, + ): + if load_dataset is None: # pragma: no cover + raise ImportError( + "Hugging Face `datasets` is required for ISICHFRawSplit. " + "Install via `pip install datasets`. " + f"Original import error: {_HF_IMPORT_ERROR}" + ) + + self.repo_id = repo_id + self.split = split + self.cache_dir = cache_dir + self.image_column = image_column + self.label_column = label_column + self.transform = transform + + # Load HF dataset split + self.ds = load_dataset(repo_id, split=split, cache_dir=cache_dir) + + # Validate image column if possible + try: + feats = self.ds.features + if image_column in feats and HFImageFeature is not None: + # If it's an HF Image feature, decoding to PIL happens on access. + pass + except Exception: + pass # proceed; we'll coerce per-sample in __getitem__ + + # Build index (filter + keep_indices) + idx = list(range(len(self.ds))) + if filter_fn is not None: + filtered = [] + for i in idx: + try: + if filter_fn(self.ds[i]): + filtered.append(i) + except Exception: + # Skip rows that fail the filter function + continue + idx = filtered + if keep_indices is not None: + # Intersect in original order + keep_set = set(int(k) for k in keep_indices) + idx = [i for i in idx if i in keep_set] + + self._indices = idx + + self.class_names = None + try: + label_feat = self.ds.features.get(label_column, None) + names = getattr(label_feat, "names", None) + if names: + self.class_names = tuple(names) + except Exception: + self.class_names = None + + def __len__(self) -> int: + return len(self._indices) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + real_idx = int(self._indices[idx]) + item = self.ds[real_idx] + + image = _to_pil(item[self.image_column]) + label = int(item[self.label_column]) + + if self.transform is not None: + image = self.transform(image) + + return {"image": image, "label": label} + + +class ISICBaseDataset(ISICHFRawSplit): # type: ignore[misc] + """ + Backwards-compatible alias for the previous ISICBaseDataset, now backed by HF. + + Usage (unchanged import path): + from src.data.isic_loader import ISICBaseDataset + ds = ISICBaseDataset(repo_id="MKZuziak/ISIC_2019_224", split="train") + """ + pass \ No newline at end of file diff --git a/src/engines/finetune_engine.py b/src/engines/finetune_engine.py index e69de29..f014bd3 100644 --- a/src/engines/finetune_engine.py +++ b/src/engines/finetune_engine.py @@ -0,0 +1,181 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University +# SPDX-License-Identifier: MIT + +# -*- coding: utf-8 -*- +"""Training engine for full fine-tuning (end-to-end).""" +# pylint: disable=duplicate-code +from __future__ import annotations +from typing import Dict, Any, Tuple, Optional + +import math +import torch +from torch import nn + +try: + # PyTorch 2.0+ unified AMP API + from torch.amp import autocast +except ImportError: + # Fallback for older PyTorch versions + from torch.cuda.amp import autocast + +# pylint: disable=import-error +from src.utils.logging_core import get_logger +from src.engines.training_core import ( + _maybe_scheduler_step, + _create_grad_scaler, + _update_history_and_log, + _preprocess_batch, + _run_validation_and_scheduler, + _update_best_model_state, +) + +log = get_logger(__name__) + + +def train_finetune( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + *, + model: nn.Module, + loaders: Dict[str, Any], # {"train": DataLoader, "val": DataLoader} + loss_fn, + optimizer: torch.optim.Optimizer, + scheduler: Optional[Tuple[Any, Dict[str, Any]]] = None, # (scheduler, meta) + device: torch.device, + epochs: int, + grad_clip: Optional[float] = None, + mixed_precision: bool = True, + log_interval: int = 50, + wandb_logger=None, + metric_key: str = "val_acc", + accumulation_steps: int = 1, + zero_grad_set_to_none: bool = True, +) -> Dict[str, Any]: + """ + Generic engine for full fine-tuning. Agnostic to datasets & transforms. + + Args: + model: torch.nn.Module to train end-to-end. + loaders: dict with "train" and "val" DataLoaders. + loss_fn: callable (logits, targets) -> loss tensor. + optimizer: torch optimizer over model parameters. + scheduler: optional (scheduler, meta) where meta may contain: + - step_per: {"batch","epoch","val"} (default "epoch") + - monitor: metric value for schedulers like ReduceLROnPlateau (filled automatically) + device: torch.device. + epochs: number of epochs. + grad_clip: optional float, max norm for gradient clipping. + mixed_precision: bool, enable autocast + GradScaler. + log_interval: int steps for logging. + wandb_logger: object with .log(dict) and optional .watch(...) methods. + metric_key: "val_acc" (maximize) or "...loss" (minimize) to pick best checkpoint. + accumulation_steps: gradient accumulation steps (>1 to simulate larger batch). + zero_grad_set_to_none: pass to optimizer.zero_grad for perf. + + Returns: + dict with "best_metric", "history", and "final_lr". + """ + assert accumulation_steps >= 1, "accumulation_steps must be >= 1" + + # Initialize GradScaler with backward compatibility + scaler = _create_grad_scaler(mixed_precision) + + sched, sched_meta = scheduler or (None, {}) + best_metric = -math.inf if not metric_key.endswith("loss") else math.inf + best_state = None + + history = {"train_loss": [], "val_loss": [], "val_acc": [], "lr": []} + + for epoch in range(1, epochs + 1): + model.train() + running_loss, n_seen = 0.0, 0 + optimizer.zero_grad(set_to_none=zero_grad_set_to_none) + + for step, batch in enumerate(loaders["train"], start=1): + x, y = _preprocess_batch(batch, device) + + with autocast(device_type=device.type, enabled=mixed_precision): + logits = model(x) + loss = loss_fn(logits, y) + loss_to_backprop = loss / accumulation_steps + + if mixed_precision: + scaler.scale(loss_to_backprop).backward() + else: + loss_to_backprop.backward() + + # Step on accumulation boundary + if step % accumulation_steps == 0: + if grad_clip is not None: + if mixed_precision: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + + if mixed_precision: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + + optimizer.zero_grad(set_to_none=zero_grad_set_to_none) + + # Per-batch scheduler step (if configured) + if sched is not None: + _maybe_scheduler_step(sched_meta, sched, on="batch") + + bsz = y.size(0) + running_loss += float(loss.item()) * bsz + n_seen += bsz + + if step % log_interval == 0: + cur_lr = optimizer.param_groups[0]["lr"] + if wandb_logger: + wandb_logger.log({"train/loss": float(loss.item()), "lr": cur_lr}) + + # ---- validation and scheduler step + val_loss, val_acc = _run_validation_and_scheduler( + model=model, + loaders=loaders, + loss_fn=loss_fn, + device=device, + mixed_precision=mixed_precision, + sched=sched, + sched_meta=sched_meta, + metric_key=metric_key, + ) + + # Aggregate + log + train_loss = running_loss / max(n_seen, 1) + cur_lr = optimizer.param_groups[0]["lr"] + + _update_history_and_log( + history=history, + epoch=epoch, + train_loss=train_loss, + val_loss=val_loss, + val_acc=val_acc, + cur_lr=cur_lr, + wandb_logger=wandb_logger, + log=log, + ) + + updated_state, best_metric, is_better = _update_best_model_state( + model=model, + metric_key=metric_key, + val_loss=val_loss, + val_acc=val_acc, + best_metric=best_metric, + ) + + if is_better: + best_state = updated_state + + # Restore best weights so caller can save/export + if best_state is not None: + model.load_state_dict(best_state) + + return { + "best_metric": best_metric, + "history": history, + "final_lr": optimizer.param_groups[0]["lr"], + } diff --git a/src/engines/linear_probe_engine.py b/src/engines/linear_probe_engine.py index 9567b6f..9b5d5e4 100644 --- a/src/engines/linear_probe_engine.py +++ b/src/engines/linear_probe_engine.py @@ -1,224 +1,153 @@ -# src/engines/linear_probe_engine.py +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University +# SPDX-License-Identifier: MIT + # -*- coding: utf-8 -*- -"""Linear probing engine for training classification heads on frozen backbones.""" +"""Training engine for linear probing.""" +# pylint: disable=duplicate-code from __future__ import annotations from typing import Dict, Any, Tuple, Optional + import math import torch from torch import nn -from torch.utils.data import DataLoader + +# --- AMP import (robust across versions) --- +try: + # PyTorch 2.0+ unified AMP API + from torch.amp import autocast +except ImportError: + # Fallback for older PyTorch versions + from torch.cuda.amp import autocast # pylint: disable=import-error -from src.utils.logging import get_logger, MetricAverager, WandbLogger -from src.utils.optim import step_scheduler +from src.utils.logging_core import get_logger +from src.engines.training_core import ( + _maybe_scheduler_step, + _create_grad_scaler, + _update_history_and_log, + _preprocess_batch, + _run_validation_and_scheduler, + _update_best_model_state, +) log = get_logger(__name__) -def _unpack_loaders(loaders: Any) -> Tuple[DataLoader, Optional[DataLoader], Optional[DataLoader]]: - """ - Accepts either: - - {"train": ..., "val": ..., "test": ...} - - (train_loader, val_loader) - Returns (train, val, test) - """ - if isinstance(loaders, dict): - return loaders.get("train"), loaders.get("val"), loaders.get("test") - if isinstance(loaders, (tuple, list)) and len(loaders) >= 2: - return loaders[0], loaders[1], loaders[2] if len(loaders) > 2 else None - raise ValueError( - "`loaders` must be a dict with keys train/val[/test] or a (train,val) tuple." - ) - - -@torch.no_grad() -def _evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Dict[str, float]: - model.eval() - correct = 0 - total = 0 - running_loss = 0.0 - - # if the model exposes a criterion attribute, prefer the external loss - # passed in train loop anyway. - ce = nn.CrossEntropyLoss() - - for batch in loader: - # support (x,y) or dict - if isinstance(batch, dict): - x = (batch.get("pixel_values") or - batch.get("images") or - batch.get("x")) - y = batch.get("labels") or batch.get("y") - else: - x, y = batch - - x = x.to(device, non_blocking=True) - y = y.to(device, non_blocking=True) - - logits = model(x) - if isinstance(logits, (tuple, list)): - logits = logits[0] - loss = ce(logits, y) - running_loss += float(loss.item()) * x.size(0) - - preds = logits.argmax(dim=-1) - correct += int((preds == y).sum().item()) - total += int(y.numel()) - - acc = correct / max(total, 1) - avg_loss = running_loss / max(total, 1) - return {"val_loss": avg_loss, "val_acc": acc} - - -def train_probe( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements +def train_probe( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + *, model: nn.Module, - loaders: Any, - loss_fn, # callable: (logits, targets) -> loss tensor + loaders: Dict[str, Any], # {"train": DataLoader, "val": DataLoader} + loss_fn, optimizer: torch.optim.Optimizer, - scheduler: Any = None, # either a scheduler or (scheduler, meta) - device: torch.device = torch.device("cpu"), - epochs: int = 10, + scheduler: Optional[Tuple[Any, Dict[str, Any]]] = None, # (scheduler, meta) + device: torch.device, + epochs: int, grad_clip: Optional[float] = None, mixed_precision: bool = True, log_interval: int = 50, - wandb_logger: Optional[WandbLogger] = None, + wandb_logger=None, metric_key: str = "val_acc", ) -> Dict[str, Any]: """ - Linear probing engine. - Expects the backbone to be frozen already; only the head should have requires_grad=True. + Generic engine for linear probing. Agnostic to dataset & transforms. + Returns a dict with best metric, histories, and final lr. """ - train_loader, val_loader, _ = _unpack_loaders(loaders) - assert train_loader is not None, "train_loader is required" - assert val_loader is not None, "val_loader is required" + model.train() - scaler = torch.amp.GradScaler(enabled=(mixed_precision and device.type == "cuda")) + # Initialize GradScaler safely across torch versions + scaler = _create_grad_scaler(mixed_precision) - # unwrap (scheduler, meta) if it came from our utils - sched, sched_meta = None, {"by": "epoch"} - if scheduler is not None: - if isinstance(scheduler, tuple) and len(scheduler) == 2: - sched, sched_meta = scheduler[0], scheduler[1] - else: - sched = scheduler + sched, sched_meta = scheduler or (None, {}) + best_metric = -math.inf if not metric_key.endswith("loss") else math.inf + best_state_dict = None - # sanity: log trainable parameter ratio - total_params = sum(p.numel() for p in model.parameters()) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - log.info(f"[probe] trainable params: {trainable_params:,} / {total_params:,} " - f"({100.0 * trainable_params / max(total_params,1):.2f}%)") - - history = {"train_loss": [], "val_loss": [], "val_acc": []} - best_metric = -math.inf - best_state: Optional[Dict[str, torch.Tensor]] = None - - model.to(device) + history = {"train_loss": [], "val_loss": [], "val_acc": [], "lr": []} for epoch in range(1, epochs + 1): model.train() - ma = MetricAverager() - - for it, batch in enumerate(train_loader, start=1): - if isinstance(batch, dict): - x = (batch.get("pixel_values") or - batch.get("images") or - batch.get("x")) - y = batch.get("labels") or batch.get("y") - else: - x, y = batch + running_loss, n_seen = 0.0, 0 - x = x.to(device, non_blocking=True) - y = y.to(device, non_blocking=True) + for step, batch in enumerate(loaders["train"], start=1): + x, y = _preprocess_batch(batch, device) optimizer.zero_grad(set_to_none=True) + with autocast(device_type=device.type, enabled=mixed_precision): + logits = model(x) + loss = loss_fn(logits, y) - if scaler.is_enabled(): - with torch.amp.autocast(device_type=device.type): - logits = model(x) - if isinstance(logits, (tuple, list)): - logits = logits[0] - loss = loss_fn(logits, y) + if mixed_precision: scaler.scale(loss).backward() - if grad_clip is not None and grad_clip > 0: + + if grad_clip is not None: + # Unscale before clipping scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_( - filter(lambda p: p.requires_grad, model.parameters()), - max_norm=grad_clip - ) + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + + # Step and update scaler scaler.step(optimizer) scaler.update() else: - logits = model(x) - if isinstance(logits, (tuple, list)): - logits = logits[0] - loss = loss_fn(logits, y) loss.backward() - if grad_clip is not None and grad_clip > 0: - torch.nn.utils.clip_grad_norm_( - filter(lambda p: p.requires_grad, model.parameters()), - max_norm=grad_clip - ) + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) optimizer.step() - # running metrics - with torch.no_grad(): - preds = logits.argmax(dim=-1) - acc = (preds == y).float().mean().item() - ma.update(loss=float(loss.item()), acc=acc) + if sched is not None: + _maybe_scheduler_step(sched_meta, sched, on="batch") - if it % log_interval == 0: - avgs = ma.averages() - log.info(f"[probe] epoch {epoch:03d} iter {it:05d} " - f"| loss {avgs['loss']:.4f} | acc {avgs['acc']:.4f}") - if wandb_logger: - wandb_logger.log({ - "train/loss": avgs["loss"], - "train/acc": avgs["acc"], - "epoch": epoch, - "iter": it, - }) - - # end epoch → validation - val_metrics = _evaluate(model, val_loader, device=device) - history["train_loss"].append(ma.averages()["loss"]) - history["val_loss"].append(val_metrics["val_loss"]) - history["val_acc"].append(val_metrics["val_acc"]) - - log.info(f"[probe] epoch {epoch:03d} | val_loss {val_metrics['val_loss']:.4f} " - f"| val_acc {val_metrics['val_acc']:.4f}") - - if wandb_logger: - wandb_logger.log({ - "val/loss": val_metrics["val_loss"], - "val/acc": val_metrics["val_acc"], - "epoch": epoch, - }) - - # scheduler step - if sched is not None: - # if ReduceLROnPlateau, step on val metric; else per-epoch - if hasattr(sched, "step") and sched_meta.get("by") == "val_metric": - # for plateau: lower is better typically; if you want higher-better, pass negative - step_scheduler(sched, sched_meta, epoch=epoch, val_metric=val_metrics["val_loss"]) - else: - step_scheduler(sched, sched_meta, epoch=epoch) + running_loss += float(loss.item()) * y.size(0) + n_seen += y.size(0) - # track best - current = val_metrics.get(metric_key, val_metrics["val_acc"]) - if current > best_metric: - best_metric = current - best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} - - # restore best weights (optional but standard) - if best_state is not None: - model.load_state_dict(best_state) + if step % log_interval == 0: + cur_lr = optimizer.param_groups[0]["lr"] + if wandb_logger: + wandb_logger.log({"train/loss": float(loss.item()), "lr": cur_lr}) + + train_loss = running_loss / max(n_seen, 1) + + # ---- validation and scheduler step + val_loss, val_acc = _run_validation_and_scheduler( + model=model, + loaders=loaders, + loss_fn=loss_fn, + device=device, + mixed_precision=mixed_precision, + sched=sched, + sched_meta=sched_meta, + metric_key=metric_key, + ) + + # logging + cur_lr = optimizer.param_groups[0]["lr"] + _update_history_and_log( + history=history, + epoch=epoch, + train_loss=train_loss, + val_loss=val_loss, + val_acc=val_acc, + cur_lr=cur_lr, + wandb_logger=wandb_logger, + log=log, + ) + + updated_state, best_metric, is_better = _update_best_model_state( + model=model, + metric_key=metric_key, + val_loss=val_loss, + val_acc=val_acc, + best_metric=best_metric, + ) + if is_better: + best_state_dict = updated_state + + # restore best (optional: caller can save now) + if best_state_dict is not None: + model.load_state_dict(best_state_dict) return { + "best_metric": best_metric, "history": history, - "best_metric": float(best_metric), - "best_metric_name": metric_key, - "final_val_acc": float(history["val_acc"][-1]) if history["val_acc"] else None, - "final_val_loss": float(history["val_loss"][-1]) if history["val_loss"] else None, - "trainable_params": trainable_params, - "total_params": total_params, + "final_lr": optimizer.param_groups[0]["lr"], } diff --git a/src/engines/training_core.py b/src/engines/training_core.py new file mode 100644 index 0000000..d914703 --- /dev/null +++ b/src/engines/training_core.py @@ -0,0 +1,276 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University +# SPDX-License-Identifier: MIT + +# -*- coding: utf-8 -*- +# pylint: disable=duplicate-code +"""Training engine utilities for fine-tuning and linear probing. + +This module provides helper functions used by other training engines: +- _accuracy: Computes top-1 accuracy +- _maybe_scheduler_step: Step schedulers based on configuration/meta +- _evaluate: Evaluate model on a loader with loss + accuracy +- _create_grad_scaler: Create a GradScaler with common configuration +- _update_history_and_log: Update training history and log results +- _is_metric_better: Check if current metric is better than best metric + +Imported by: +- src.engines.linear_probe_engine +- src.engines.finetune_engine +""" +from __future__ import annotations +from typing import Dict, Any, Tuple, List, Optional + +import torch +from torch import nn + +try: + # PyTorch 2.0+ unified AMP API + from torch.amp import autocast, GradScaler +except ImportError: + # Fallback for older PyTorch versions + from torch.cuda.amp import autocast, GradScaler + + +def _accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float: + with torch.no_grad(): + preds = torch.argmax(logits, dim=1) + correct = (preds == targets).sum().item() + total = targets.numel() + return correct / max(total, 1) + + +def _maybe_scheduler_step(scheduler_meta: Dict[str, Any], scheduler, *, on: str): + """Step scheduler if meta says so. `on` ∈ {'batch','epoch','val'}.""" + step_when = str(scheduler_meta.get("step_per", "epoch")) + if step_when == on: + if "monitor" in scheduler_meta: + # e.g., ReduceLROnPlateau + scheduler.step(scheduler_meta["monitor"]) + else: + scheduler.step() + + +def _evaluate( + model: nn.Module, + loader, + loss_fn, + device: torch.device, + mixed_precision: bool, +) -> Tuple[float, float]: + model.eval() + running_loss, running_acc, n = 0.0, 0.0, 0 + with torch.no_grad(): + for batch in loader: + x, y = _preprocess_batch(batch, device) + + # Device-aware autocast (CPU/GPU) and version-safe + with autocast( + device_type=getattr(device, "type", "cuda"), enabled=mixed_precision + ): + logits = model(x) + loss = loss_fn(logits, y) + + bsz = y.size(0) + running_loss += float(loss.item()) * bsz + running_acc += _accuracy(logits, y) * bsz + n += bsz + + return running_loss / max(n, 1), running_acc / max(n, 1) + + +def _create_grad_scaler(mixed_precision: bool = True) -> GradScaler: + """Create a GradScaler with common configuration for mixed precision training. + + Args: + mixed_precision: Whether to enable mixed precision training. + + Returns: + A configured GradScaler instance. + """ + try: + return GradScaler( + enabled=mixed_precision, + init_scale=2.0**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + ) + except TypeError: + # Fallback for older PyTorch versions that don't support all parameters + return GradScaler(enabled=mixed_precision) + + +def _update_history_and_log( # pylint: disable=too-many-arguments + *, + history: Dict[str, List[float]], + epoch: int, + train_loss: float, + val_loss: float, + val_acc: float, + cur_lr: float, + wandb_logger: Optional[Any] = None, + log: Optional[Any] = None, +) -> None: + """Update training history and log results. + + Args: + history: Dictionary to store training history. + epoch: Current epoch number. + train_loss: Training loss for current epoch. + val_loss: Validation loss for current epoch. + val_acc: Validation accuracy for current epoch. + cur_lr: Current learning rate. + wandb_logger: Optional wandb logger instance. + log: Optional logger instance. + """ + # Update history + history["train_loss"].append(train_loss) + history["val_loss"].append(val_loss) + history["val_acc"].append(val_acc) + history["lr"].append(cur_lr) + + # Log to wandb if available + if wandb_logger: + wandb_logger.log( + { + "epoch": epoch, + "train/loss_epoch": train_loss, + "val/loss": val_loss, + "val/acc": val_acc, + "lr": cur_lr, + } + ) + + # Log to console if logger is available + if log: + log.info( + "Epoch %d | train_loss=%.4f | val_loss=%.4f | val_acc=%.4f | lr=%.2e", + epoch, + train_loss, + val_loss, + val_acc, + cur_lr, + ) + + +def _is_metric_better( + metric_key: str, current_metric: float, best_metric: float +) -> Tuple[bool, float]: + """Check if current metric is better than best metric. + + Args: + metric_key: Metric name, if ends with "loss" will minimize, otherwise maximize. + current_metric: Current metric value. + best_metric: Best metric value so far. + + Returns: + Tuple of (is_better, best_metric_value) + """ + minimize = metric_key.endswith("loss") + is_better = ( + (current_metric < best_metric) if minimize else (current_metric > best_metric) + ) + + if is_better: + return True, current_metric + return False, best_metric + + +def _preprocess_batch(batch, device): + """Preprocess a batch by moving data to the appropriate device. + + Args: + batch: Either a dictionary with 'image'/'label' keys or a tuple (x, y). + device: The target device for tensors. + + Returns: + Tuple of (input_tensor, target_tensor) on the specified device. + """ + if isinstance(batch, dict): + x, y = batch.get("image", batch.get("pixel_values")), batch.get( + "label", batch.get("labels") + ) + else: + x, y = batch + + return x.to(device, non_blocking=True), y.to(device, non_blocking=True) + + +def _run_validation_and_scheduler( # pylint: disable=too-many-arguments + *, + model: nn.Module, + loaders: Dict[str, Any], + loss_fn: Any, + device: torch.device, + mixed_precision: bool, + sched: Any, + sched_meta: Dict[str, Any], + metric_key: str, +) -> Tuple[float, float]: + """Run validation and handle scheduler steps. + + Args: + model: The model to evaluate. + loaders: Dictionary of data loaders. + loss_fn: Loss function. + device: Target device. + mixed_precision: Whether to use mixed precision. + sched: Optional scheduler. + sched_meta: Scheduler metadata. + metric_key: Metric key to monitor. + + Returns: + Tuple of (validation_loss, validation_accuracy). + """ + # Run validation + val_loss, val_acc = _evaluate( + model=model, + loader=loaders["val"], + loss_fn=loss_fn, + device=device, + mixed_precision=mixed_precision, + ) + + # Handle scheduler steps + if sched is not None: + if isinstance(sched, torch.optim.lr_scheduler.ReduceLROnPlateau): + sched_meta["monitor"] = val_loss if metric_key.endswith("loss") else val_acc + _maybe_scheduler_step(sched_meta, sched, on="val") + else: + _maybe_scheduler_step(sched_meta, sched, on="epoch") + + return val_loss, val_acc + + +def _update_best_model_state( + *, + model: nn.Module, + metric_key: str, + val_loss: float, + val_acc: float, + best_metric: float, +) -> Tuple[Dict[str, torch.Tensor], float, bool]: + """Update best model state if current metric is better. + + Args: + model: The model to get state from. + metric_key: Metric key to monitor. + val_loss: Validation loss. + val_acc: Validation accuracy. + best_metric: Best metric value so far. + + Returns: + Tuple of (best_state_dict, best_metric, is_better). + """ + monitor = val_loss if metric_key.endswith("loss") else val_acc + is_better, updated_best_metric = _is_metric_better(metric_key, monitor, best_metric) + + best_state_dict = None + if is_better: + best_state_dict = { + k: v.detach().cpu().clone() for k, v in model.state_dict().items() + } + + return best_state_dict, updated_best_metric, is_better diff --git a/src/evaluation/metrics.py b/src/evaluation/metrics.py index 459c275..de693be 100644 --- a/src/evaluation/metrics.py +++ b/src/evaluation/metrics.py @@ -1,3 +1,16 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +"""Evaluation metrics for classification tasks.""" +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/metrics/metrics.py # -*- coding: utf-8 -*- from __future__ import annotations diff --git a/src/evaluation/visualization.py b/src/evaluation/visualization.py index 913f99d..216414b 100644 --- a/src/evaluation/visualization.py +++ b/src/evaluation/visualization.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/metrics/visualization.py # -*- coding: utf-8 -*- from __future__ import annotations diff --git a/src/evaluation/visualize_results.py b/src/evaluation/visualize_results.py index 3e3fb53..2da32d1 100644 --- a/src/evaluation/visualize_results.py +++ b/src/evaluation/visualize_results.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/visualize_results.py """Visualize robustness results across models and degradations.""" import json @@ -9,138 +15,138 @@ def create_robustness_heatmap(results_path: str): """Create heatmap showing model performance across degradations.""" - + # Load results with open(results_path, 'r') as f: results = json.load(f) - + # Prepare data for heatmap models = [] degradations = [] accuracies = [] - + for training_mode in ['finetune', 'linear_probe']: for model_name, model_results in results[training_mode].items(): if 'error' not in model_results: eval_results = model_results['eval_results_by_degradation'] - + for degradation, metrics in eval_results.items(): models.append(f"{model_name}_{training_mode[:2]}") degradations.append(degradation) accuracies.append(metrics['accuracy']) - + # Create DataFrame df = pd.DataFrame({ 'Model': models, 'Degradation': degradations, 'Accuracy': accuracies }) - + # Pivot for heatmap pivot_df = df.pivot(index='Model', columns='Degradation', values='Accuracy') - + # Reorder columns logically - column_order = ['clean', + column_order = ['clean', 'jpeg_90', 'jpeg_50', 'jpeg_20', 'blur_1.0', 'blur_3.0', 'blur_5.0', 'color_64', 'color_16', 'color_4'] pivot_df = pivot_df[column_order] - + # Create figure plt.figure(figsize=(14, 8)) - + # Create heatmap - sns.heatmap(pivot_df, - annot=True, - fmt='.3f', + sns.heatmap(pivot_df, + annot=True, + fmt='.3f', cmap='RdYlGn', - vmin=0.5, + vmin=0.5, vmax=1.0, cbar_kws={'label': 'Accuracy'}) - + plt.title('Model Robustness Across Different Degradations') plt.xlabel('Degradation Type') plt.ylabel('Model (ft=finetune, lp=linear probe)') plt.tight_layout() - + # Save figure output_path = Path(results_path).parent / 'robustness_heatmap.png' plt.savefig(output_path, dpi=300, bbox_inches='tight') plt.show() - + return pivot_df def create_degradation_curves(results_path: str): """Create line plots showing accuracy degradation.""" - + with open(results_path, 'r') as f: results = json.load(f) - + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) - + # JPEG degradation jpeg_qualities = [90, 50, 20] ax = axes[0] - + for training_mode in ['finetune', 'linear_probe']: for model_name, model_results in results[training_mode].items(): if 'error' not in model_results: eval_results = model_results['eval_results_by_degradation'] - + clean_acc = eval_results['clean']['accuracy'] jpeg_accs = [eval_results[f'jpeg_{q}']['accuracy'] for q in jpeg_qualities] - + label = f"{model_name} ({training_mode})" - ax.plot([100] + jpeg_qualities, [clean_acc] + jpeg_accs, + ax.plot([100] + jpeg_qualities, [clean_acc] + jpeg_accs, marker='o', label=label) - + ax.set_xlabel('JPEG Quality') ax.set_ylabel('Accuracy') ax.set_title('JPEG Compression Robustness') ax.legend() ax.grid(True, alpha=0.3) ax.invert_xaxis() - + # Blur degradation blur_radii = [0, 1.0, 3.0, 5.0] ax = axes[1] - + for training_mode in ['finetune', 'linear_probe']: for model_name, model_results in results[training_mode].items(): if 'error' not in model_results: eval_results = model_results['eval_results_by_degradation'] - + blur_accs = [] blur_accs.append(eval_results['clean']['accuracy']) for r in blur_radii[1:]: blur_accs.append(eval_results[f'blur_{r:.1f}']['accuracy']) - + label = f"{model_name} ({training_mode})" ax.plot(blur_radii, blur_accs, marker='o', label=label) - + ax.set_xlabel('Blur Radius') ax.set_ylabel('Accuracy') ax.set_title('Gaussian Blur Robustness') ax.legend() ax.grid(True, alpha=0.3) - + # Color quantization color_levels = [256, 64, 16, 4] ax = axes[2] - + for training_mode in ['finetune', 'linear_probe']: for model_name, model_results in results[training_mode].items(): if 'error' not in model_results: eval_results = model_results['eval_results_by_degradation'] - + color_accs = [] color_accs.append(eval_results['clean']['accuracy']) for c in color_levels[1:]: color_accs.append(eval_results[f'color_{c}']['accuracy']) - + label = f"{model_name} ({training_mode})" ax.plot(color_levels, color_accs, marker='o', label=label) - + ax.set_xlabel('Number of Colors') ax.set_ylabel('Accuracy') ax.set_title('Color Quantization Robustness') @@ -148,10 +154,10 @@ def create_degradation_curves(results_path: str): ax.grid(True, alpha=0.3) ax.set_xscale('log', base=2) ax.invert_xaxis() - + plt.suptitle('Model Robustness to Different Degradation Types') plt.tight_layout() - + # Save figure output_path = Path(results_path).parent / 'degradation_curves.png' plt.savefig(output_path, dpi=300, bbox_inches='tight') @@ -160,11 +166,11 @@ def create_degradation_curves(results_path: str): # Usage if __name__ == "__main__": results_path = "results/results_comprehensive_lr0.0001_bs128_ep3.json" - + # Create visualizations pivot_df = create_robustness_heatmap(results_path) create_degradation_curves(results_path) - + # Print summary statistics print("\nModel Rankings by Robustness:") print("-" * 40) diff --git a/src/losses/classification.py b/src/losses/classification.py index 4168dee..a8c4767 100644 --- a/src/losses/classification.py +++ b/src/losses/classification.py @@ -1,8 +1,23 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +""" +Classification loss functions for model training. + +This module provides common loss functions used for classification tasks, +including cross-entropy loss with various options like label smoothing +and class weighting. +""" + # src/losses/classification.py # -*- coding: utf-8 -*- from typing import Optional +# pylint: disable=import-error import torch.nn.functional as F -from torch import Tensor +from torch import Tensor # pylint: disable=import-error def cross_entropy_loss( diff --git a/src/losses/distillation.py b/src/losses/distillation.py index 27e985e..e63c305 100644 --- a/src/losses/distillation.py +++ b/src/losses/distillation.py @@ -1,8 +1,23 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +""" +Knowledge distillation loss functions. + +This module provides various loss functions used in knowledge distillation, +including cosine embedding loss, KL divergence, and hybrid distillation loss +for transferring knowledge from teacher to student models. +""" + # src/losses/distillation.py # -*- coding: utf-8 -*- from typing import Dict +# pylint: disable=import-error import torch.nn.functional as F -from torch import Tensor +from torch import Tensor # pylint: disable=import-error def cosine_loss(reduction: str = "mean"): @@ -17,7 +32,7 @@ def _loss(s_embed: Tensor, t_embed: Tensor) -> Tensor: loss = 1.0 - sim # minimize (1 - cos) if reduction == "mean": return loss.mean() - elif reduction == "sum": + if reduction == "sum": return loss.sum() return loss return _loss @@ -27,14 +42,14 @@ def kl_divergence_loss(temperature: float = 1.0, reduction: str = "batchmean"): """ KL divergence on logits with temperature scaling (Hinton et al., 2015). """ - T = float(temperature) - T2 = T * T + temp = float(temperature) + temp_squared = temp * temp def _loss(s_logits: Tensor, t_logits: Tensor) -> Tensor: - log_p_s = F.log_softmax(s_logits / T, dim=-1) - p_t = F.softmax(t_logits / T, dim=-1) + log_p_s = F.log_softmax(s_logits / temp, dim=-1) + p_t = F.softmax(t_logits / temp, dim=-1) kl = F.kl_div(log_p_s, p_t, reduction=reduction) - return kl * T2 + return kl * temp_squared return _loss diff --git a/src/models/factory.py b/src/models/factory.py index b7210a0..b382598 100644 --- a/src/models/factory.py +++ b/src/models/factory.py @@ -1,12 +1,21 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/models/factory.py # -*- coding: utf-8 -*- """Unified model factory: HF vision models (ViT/DINOv2), optional timm, plus matching preprocessors and helpers (freeze_backbone, save_model).""" +from __future__ import annotations from typing import Dict, Any import os -import torch.nn as nn +import json +import torch +from torch import nn from PIL import Image # --- Hugging Face --- @@ -21,28 +30,35 @@ try: import timm # type: ignore _TIMM_AVAILABLE = True -except Exception: +except ImportError: _TIMM_AVAILABLE = False -# --- Project constants (small change from your code: avoid importing configs directly) --- +# --- Project constants (optional) --- try: from src.utils.constants import HF_MODELS # e.g., {"vit", "dinov2"} -except Exception: - # Fallback if constants not present yet +except ImportError: HF_MODELS = {"vit", "dinov2"} +# --- Pillow resampling constant (handles both new and old Pillow versions) --- +try: + # Pillow ≥9.1 uses Image.Resampling + RESAMPLING_LANCZOS = Image.Resampling.LANCZOS # type: ignore[attr-defined] +except AttributeError: + # Pillow <9.1 fallback; use getattr to avoid pylint false positives + RESAMPLING_LANCZOS = getattr(Image, "LANCZOS", None) or getattr(Image, "BICUBIC", None) + # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- -def create_model(model_info: Dict[str, Any], resolution: int = 224): +def create_model(model_info: Dict[str, Any], resolution: int = 224) -> nn.Module: """ Factory function to create models based on type. Args: - model_info: {"type": "vit"|"dinov2"|("timm"), "model_id": str, "config": {...}} - resolution: input image resolution + model_info: {"type": "vit"|"dinov2"|"timm", "model_id": str, "config": {...}} + resolution: input image resolution (passed to HF heads when supported) Returns: nn.Module @@ -56,30 +72,28 @@ def create_model(model_info: Dict[str, Any], resolution: int = 224): return ViTForImageClassification.from_pretrained( model_id, num_labels=config["num_labels"], - ignore_mismatched_sizes=config.get("ignore_mismatched_sizes", True), + ignore_mismatched_sizes=bool(config.get("ignore_mismatched_sizes", True)), image_size=resolution, ) # --- HuggingFace DINOv2 (AutoModel) --- - elif model_type == "dinov2": + if model_type == "dinov2": return AutoModelForImageClassification.from_pretrained( model_id, num_labels=config["num_labels"], - ignore_mismatched_sizes=config.get("ignore_mismatched_sizes", True), + ignore_mismatched_sizes=bool(config.get("ignore_mismatched_sizes", True)), image_size=resolution, ) - # --- Optional timm branch (kept minimal, no breaking changes) --- - elif model_type == "timm": + # --- timm --- + if model_type == "timm": if not _TIMM_AVAILABLE: raise RuntimeError("timm is not installed but model_type='timm' was requested.") num_classes = int(config.get("num_labels", 1000)) pretrained = bool(config.get("pretrained", True)) - # If you need image_size-specific config, many timm models accept it via `img_size` return timm.create_model(model_id, pretrained=pretrained, num_classes=num_classes) - else: - raise ValueError(f"Unknown model type: {model_type}") + raise ValueError(f"Unknown model type: {model_type}") def create_preprocessor(model_info: Dict[str, Any], resolution: int = 224): @@ -101,39 +115,40 @@ def create_preprocessor(model_info: Dict[str, Any], resolution: int = 224): model_id, size=resolution, do_resize=True, - resample=Image.LANCZOS, + resample=RESAMPLING_LANCZOS, do_normalize=True, + # ImageNet normalization statistics [red, green, blue] image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225], ) - elif model_type == "dinov2": + if model_type == "dinov2": return AutoImageProcessor.from_pretrained( model_id, size=resolution, do_resize=True, - resample=Image.LANCZOS, + resample=RESAMPLING_LANCZOS, do_normalize=True, + # ImageNet normalization statistics [red, green, blue] image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225], ) - elif model_type == "timm": + if model_type == "timm": # timm uses torchvision transforms; return None and build transforms in your datamodule return None - else: - raise ValueError(f"Unknown model type: {model_type}") + raise ValueError(f"Unknown model type: {model_type}") -def freeze_backbone(model: nn.Module, model_type: str): +def freeze_backbone(model: nn.Module, model_type: str) -> None: """ Freeze backbone parameters for transfer learning. For HF classifiers, keep 'classifier' or 'head' trainable; freeze the rest. Args: model: nn.Module - model_type: 'vit' | 'dinov2' | ('timm' if you wire it similarly) + model_type: 'vit' | 'dinov2' | 'timm' """ if model_type in HF_MODELS: for name, param in model.named_parameters(): @@ -142,18 +157,25 @@ def freeze_backbone(model: nn.Module, model_type: str): param.requires_grad = True else: param.requires_grad = False - elif model_type == "timm": - # Optional: implement project-specific rules (e.g., freeze all except last classifier) + return + + if model_type == "timm": for name, param in model.named_parameters(): if ("classifier" in name) or ("fc" in name) or ("head" in name): param.requires_grad = True else: param.requires_grad = False - else: - raise ValueError(f"Unsupported model_type: {model_type}") + return + raise ValueError(f"Unsupported model_type: {model_type}") -def save_model(model: nn.Module, model_info: Dict[str, Any], save_dir: str, preprocessor=None): + +def save_model( + model: nn.Module, + model_info: Dict[str, Any], + save_dir: str, + preprocessor=None +) -> None: """ Save model based on its type. @@ -170,22 +192,23 @@ def save_model(model: nn.Module, model_info: Dict[str, Any], save_dir: str, prep model.save_pretrained(save_dir) if preprocessor is not None: preprocessor.save_pretrained(save_dir) - elif model_type == "timm": + return + + if model_type == "timm": # Torch-style checkpoint for timm models - import torch ckpt_path = os.path.join(save_dir, "pytorch_model.bin") torch.save(model.state_dict(), ckpt_path) # Minimal config export - with open(os.path.join(save_dir, "config.json"), "w") as f: - import json + with open(os.path.join(save_dir, "config.json"), "w", encoding="utf-8") as f: json.dump( { "model_type": "timm", - "model_id": model_info["model_id"], + "model_id": model_info.get("model_id"), "num_labels": model_info.get("config", {}).get("num_labels", None), }, f, indent=2, ) - else: - raise ValueError(f"Unsupported model_type: {model_type}") + return + + raise ValueError(f"Unsupported model_type: {model_type}") diff --git a/src/requirements.txt b/src/requirements.txt deleted file mode 100644 index 2d33048..0000000 Binary files a/src/requirements.txt and /dev/null differ diff --git a/src/transformation/transforms.py b/src/transformation/transforms.py index fd39fe3..87ce6f3 100644 --- a/src/transformation/transforms.py +++ b/src/transformation/transforms.py @@ -1,44 +1,54 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + """Image transformation utilities.""" +# Standard library imports import io -from typing import Optional -import numpy as np -from PIL import Image, ImageFilter +from typing import Optional, Tuple -class ResolutionReductionTransform: # pylint: disable=too-few-public-methods - """Reduce spatial resolution of images.""" +# Third-party imports +import numpy as np # pylint: disable=import-error +from PIL import Image, ImageFilter # pylint: disable=import-error - def __init__(self, reduction_factor: Optional[float] = None): - """ - Args: - reduction_factor: Factor to reduce resolution by (0.1-1.0). - For example, 0.5 reduces to half resolution. - If None, random factor is used. - """ +class ResolutionReductionTransform: # pylint: disable=too-few-public-methods + """Reduce image resolution by factor or target resolution.""" + + def __init__( + self, + reduction_factor: Optional[float] = None, + target_resolution: Optional[Tuple[int, int]] = None, + restore_original_size: bool = False, + ): self.reduction_factor = reduction_factor + self.target_resolution = target_resolution + self.restore_original_size = restore_original_size def __call__(self, img: Image.Image) -> Image.Image: - """Apply resolution reduction.""" - if self.reduction_factor is None: - # Random reduction factor between 0.2 and 0.8 - reduction_factor = np.random.uniform(0.2, 0.8) + ow, oh = img.size + + if self.target_resolution is not None: + nw, nh = self.target_resolution else: - reduction_factor = self.reduction_factor + factor = ( + np.random.uniform(0.2, 0.8) + if self.reduction_factor is None + else self.reduction_factor + ) + factor = max(0.1, min(1.0, factor)) + nw, nh = max(1, int(ow * factor)), max(1, int(oh * factor)) - # Clamp reduction factor to valid range - reduction_factor = max(0.1, min(1.0, reduction_factor)) + # Downsample + reduced = img.resize((nw, nh), Image.Resampling.LANCZOS) - # Calculate new size - original_width, original_height = img.size - new_width = int(original_width * reduction_factor) - new_height = int(original_height * reduction_factor) + # Either return the reduced image as-is, or restore to original size + if self.restore_original_size: + return reduced.resize((ow, oh), Image.Resampling.LANCZOS) + return reduced - # Ensure minimum size of 1x1 - new_width = max(1, new_width) - new_height = max(1, new_height) - # Downsample and then upsample back to original size - downsampled = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - return downsampled.resize((original_width, original_height), Image.Resampling.LANCZOS) class JPEGCompressionTransform: # pylint: disable=too-few-public-methods """Apply JPEG compression to images.""" @@ -84,7 +94,7 @@ def __call__(self, img: Image.Image) -> Image.Image: return img.filter(ImageFilter.GaussianBlur(radius=radius)) -class ColorQuantizationTransform: +class ColorQuantizationTransform: # pylint: disable=too-few-public-methods """Reduce color palette of images.""" def __init__(self, n_colors: Optional[int] = None): diff --git a/src/utils/callbacks_hf.py b/src/utils/callbacks_hf.py index 630399a..928209b 100644 --- a/src/utils/callbacks_hf.py +++ b/src/utils/callbacks_hf.py @@ -1,9 +1,17 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +# pylint: disable=all + # src/utils/callbacks_hf.py # -*- coding: utf-8 -*- from __future__ import annotations import os import json -from typing import Optional, Dict, Any +from typing import Dict, Any from transformers import TrainerCallback # type: ignore from src.utils.training_utils import get_gpu_memory diff --git a/src/utils/constants.py b/src/utils/constants.py index 9f97406..29e2172 100644 --- a/src/utils/constants.py +++ b/src/utils/constants.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/utils/constants.py # -*- coding: utf-8 -*- """Global constants and lightweight enums used across the training pipeline.""" @@ -15,5 +21,3 @@ NUM_CLASSES = 1000 # update dynamically per dataset if needed NUM_FILTERED_CLASSES = 8 # for ISIC filtered subset example - - diff --git a/src/utils/logging.py b/src/utils/logging.py deleted file mode 100644 index 5a5a69c..0000000 --- a/src/utils/logging.py +++ /dev/null @@ -1,114 +0,0 @@ -# src/utils/logging.py -# -*- coding: utf-8 -*- -from __future__ import annotations -import logging -import json -import os -from pathlib import Path -from typing import Dict, Any, Optional - -# Optional: Weights & Biases -_WANDB_AVAILABLE = False -try: - import wandb # type: ignore - _WANDB_AVAILABLE = True -except Exception: - _WANDB_AVAILABLE = False - - -def setup_logging(level: int = logging.INFO) -> None: - """Configure root logger format/level.""" - fmt = "[%(asctime)s] %(levelname)s - %(name)s: %(message)s" - datefmt = "%H:%M:%S" - logging.basicConfig(level=level, format=fmt, datefmt=datefmt) - - -def get_logger(name: str) -> logging.Logger: - """Get a module-specific logger.""" - return logging.getLogger(name) - - -class MetricAverager: - """ - Track running averages of named scalars. - Usage: - ma = MetricAverager() - ma.update(loss=0.1, acc=0.9) - avgs = ma.averages() # {"loss": ..., "acc": ...} - """ - def __init__(self) -> None: - self.totals: Dict[str, float] = {} - self.counts: Dict[str, int] = {} - - def update(self, **kwargs: float) -> None: - for k, v in kwargs.items(): - self.totals[k] = self.totals.get(k, 0.0) + float(v) - self.counts[k] = self.counts.get(k, 0) + 1 - - def averages(self) -> Dict[str, float]: - return {k: (self.totals[k] / max(self.counts[k], 1)) for k in self.totals} - - def reset(self) -> None: - self.totals.clear() - self.counts.clear() - - -class WandbLogger: - """ - Thin W&B wrapper that is safe when wandb is not installed. - """ - def __init__( - self, - project: str, - run_name: Optional[str] = None, - config: Any = None, - enabled: Optional[bool] = None, - entity: Optional[str] = None, - tags: Optional[list[str]] = None, - ) -> None: - self.enabled = (_WANDB_AVAILABLE if enabled is None else enabled) - self.run = None - if self.enabled: - self.run = wandb.init( - project=project, - name=run_name, - config=_maybe_serialize_config(config), - entity=entity, - tags=tags or [], - reinit=True, - ) - - def log(self, metrics: Dict[str, Any], step: Optional[int] = None, commit: bool = True) -> None: - if self.enabled and self.run is not None: - wandb.log(metrics, step=step, commit=commit) - - def watch_model(self, model, log: str = "gradients", log_freq: int = 100) -> None: - if self.enabled and self.run is not None: - wandb.watch(model, log=log, log_freq=log_freq) - - def save_artifact(self, path: str, name: Optional[str] = None, type_: str = "file") -> None: - if self.enabled and self.run is not None and os.path.exists(path): - art = wandb.Artifact(name or Path(path).name, type=type_) - art.add_file(path) - self.run.log_artifact(art) - - def finish(self) -> None: - if self.enabled and self.run is not None: - self.run.finish() - - -def _maybe_serialize_config(cfg: Any) -> Dict[str, Any]: - """Best-effort: convert config objects to plain dicts.""" - try: - from omegaconf import OmegaConf # type: ignore - if isinstance(cfg, dict): - return cfg - if OmegaConf.is_config(cfg): - return OmegaConf.to_container(cfg, resolve=True) # type: ignore - except Exception: - pass - try: - json.dumps(cfg) # type: ignore - return cfg # type: ignore - except Exception: - return {} diff --git a/src/utils/logging_core.py b/src/utils/logging_core.py new file mode 100644 index 0000000..5fe232c --- /dev/null +++ b/src/utils/logging_core.py @@ -0,0 +1,195 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +""" +Logging and metrics utilities for machine learning experiments. + +Provides functionality for: +- Standard Python logging setup +- Metric averaging for tracking training stats +- Optional Weights & Biases integration +""" + +# -*- coding: utf-8 -*- +from __future__ import annotations +import logging +import json +import os +from pathlib import Path +from typing import Dict, Any, Optional + +# Optional: Weights & Biases +wandb_available = False # Module-level flag for wandb availability +try: + import wandb # type: ignore # pylint: disable=import-error + wandb_available = True +except ImportError: + # wandb is an optional dependency + wandb_available = False + + +def setup_logging(level: int = logging.INFO) -> None: # pylint: disable=no-member + """Configure root logger format/level.""" + fmt = "[%(asctime)s] %(levelname)s - %(name)s: %(message)s" + datefmt = "%H:%M:%S" + logging.basicConfig(level=level, format=fmt, datefmt=datefmt) # pylint: disable=no-member + + +def get_logger(name: str) -> logging.Logger: # pylint: disable=no-member + """Get a module-specific logger.""" + return logging.getLogger(name) # pylint: disable=no-member + + +class MetricAverager: + """ + Track running averages of named scalars. + Usage: + ma = MetricAverager() + ma.update(loss=0.1, acc=0.9) + avgs = ma.averages() # {"loss": ..., "acc": ...} + """ + def __init__(self) -> None: + self.totals: Dict[str, float] = {} + self.counts: Dict[str, int] = {} + + def update(self, **kwargs: float) -> None: + """ + Update metrics with new values. + + Args: + **kwargs: Keyword arguments of metric names and values + """ + for k, v in kwargs.items(): + self.totals[k] = self.totals.get(k, 0.0) + float(v) + self.counts[k] = self.counts.get(k, 0) + 1 + + def averages(self) -> Dict[str, float]: + """ + Get current averages for all tracked metrics. + + Returns: + Dictionary mapping metric names to their averages + """ + return {k: (v / max(self.counts[k], 1)) for k, v in self.totals.items()} + + def reset(self) -> None: + """Clear all tracked metrics.""" + self.totals.clear() + self.counts.clear() + + +class WandbLogger: + """ + Thin W&B wrapper that is safe when wandb is not installed. + """ + # pylint: disable=too-many-arguments,too-many-positional-arguments + def __init__( + self, + project: str, + run_name: Optional[str] = None, + config: Any = None, + enabled: Optional[bool] = None, + entity: Optional[str] = None, + tags: Optional[list[str]] = None, + ) -> None: + """ + Initialize the W&B logger. + + Args: + project: W&B project name + run_name: Optional name for this run + config: Configuration to log (dict, OmegaConf or compatible) + enabled: Whether to enable logging (defaults to wandb_available) + entity: Optional W&B team/entity name + tags: Optional list of tags for the run + """ + self.enabled = (wandb_available if enabled is None else enabled) + self.run = None + if self.enabled: + self.run = wandb.init( + project=project, + name=run_name, + config=_maybe_serialize_config(config), + entity=entity, + tags=tags or [], + reinit=True, + ) + + def log(self, metrics: Dict[str, Any], step: Optional[int] = None, commit: bool = True) -> None: + """ + Log metrics to W&B. + + Args: + metrics: Dictionary of metrics to log + step: Optional step/iteration number + commit: Whether to immediately commit to W&B + """ + if self.enabled and self.run is not None: + wandb.log(metrics, step=step, commit=commit) + + def watch_model(self, model, log: str = "gradients", log_freq: int = 100) -> None: + """ + Watch a model's parameters and gradients. + + Args: + model: PyTorch model to watch + log: What to log ('gradients', 'parameters', 'all', or None) + log_freq: How frequently to log + """ + if self.enabled and self.run is not None: + wandb.watch(model, log=log, log_freq=log_freq) + + def save_artifact(self, path: str, name: Optional[str] = None, type_: str = "file") -> None: + """ + Save a file as a W&B artifact. + + Args: + path: Path to the file to save + name: Optional artifact name (defaults to filename) + type_: Artifact type + """ + if self.enabled and self.run is not None and os.path.exists(path): + art = wandb.Artifact(name or Path(path).name, type=type_) + art.add_file(path) + self.run.log_artifact(art) + + def finish(self) -> None: + """Mark the W&B run as complete.""" + if self.enabled and self.run is not None: + self.run.finish() + + +def _maybe_serialize_config(cfg: Any) -> Dict[str, Any]: + """ + Best-effort: convert config objects to plain dicts. + + Args: + cfg: Configuration object (dict, OmegaConf, or other) + + Returns: + JSON-serializable dictionary + """ + # If it's already a dict, return it + if isinstance(cfg, dict): + return cfg + + # Try OmegaConf conversion + try: + # pylint: disable=import-error,import-outside-toplevel + from omegaconf import OmegaConf # type: ignore + if OmegaConf.is_config(cfg): + return OmegaConf.to_container(cfg, resolve=True) # type: ignore + except ImportError: + # OmegaConf not available + pass + + # Try direct JSON serialization + try: + json.dumps(cfg) # type: ignore + return cfg # type: ignore + except (TypeError, ValueError): + # Not JSON serializable + return {} diff --git a/src/utils/optim.py b/src/utils/optim.py index 7190423..a563006 100644 --- a/src/utils/optim.py +++ b/src/utils/optim.py @@ -1,11 +1,26 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +""" +Optimization utilities for PyTorch models. + +This module provides functionality for: +- Building optimizers with weight decay parameter groups +- Creating learning rate schedulers with various policies +- Handling scheduler updates during training +""" + # src/utils/optim.py # -*- coding: utf-8 -*- from __future__ import annotations from typing import Iterable, Tuple, Dict, Any import math -import torch -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler, LambdaLR, StepLR, ReduceLROnPlateau +import torch # pylint: disable=import-error +from torch.optim import Optimizer # pylint: disable=import-error +from torch.optim.lr_scheduler import _LRScheduler, LambdaLR, StepLR, ReduceLROnPlateau # pylint: disable=import-error def _param_groups_decay(model_or_params: Iterable, weight_decay: float) -> list: @@ -35,12 +50,16 @@ def _param_groups_decay(model_or_params: Iterable, weight_decay: float) -> list: def _build_optimizer(cfg, params) -> Optimizer: - opt_cfg = getattr(cfg, "train").get("optimizer", {}) if hasattr(cfg, "train") else {} + opt_cfg = ( + getattr(cfg, "train").get("optimizer", {}) if hasattr(cfg, "train") else {} + ) name = str(opt_cfg.get("name", "adamw")).lower() lr = float(opt_cfg.get("lr", 1e-4)) wd = float(opt_cfg.get("weight_decay", 0.05)) - if isinstance(params, torch.nn.Module) or (hasattr(params, "__iter__") and hasattr(next(iter(params)), "ndim")): + if isinstance(params, torch.nn.Module) or ( + hasattr(params, "__iter__") and hasattr(next(iter(params)), "ndim") + ): param_groups = _param_groups_decay(params, wd) else: # already groups @@ -50,12 +69,14 @@ def _build_optimizer(cfg, params) -> Optimizer: betas = tuple(opt_cfg.get("betas", (0.9, 0.999))) eps = float(opt_cfg.get("eps", 1e-8)) return torch.optim.AdamW(param_groups, lr=lr, betas=betas, eps=eps) - elif name == "sgd": + if name == "sgd": momentum = float(opt_cfg.get("momentum", 0.9)) nesterov = bool(opt_cfg.get("nesterov", True)) - return torch.optim.SGD(param_groups, lr=lr, momentum=momentum, nesterov=nesterov) - else: - raise ValueError(f"Unsupported optimizer: {name}") + return torch.optim.SGD( + param_groups, lr=lr, momentum=momentum, nesterov=nesterov + ) + # If we reach this point, the optimizer is not supported + raise ValueError(f"Unsupported optimizer: {name}") def _warmup_cosine_lambda_fn(epochs: int, warmup_epochs: int, min_lr_ratio: float): @@ -63,45 +84,67 @@ def lr_lambda(current_epoch: int): if current_epoch < warmup_epochs: return float(current_epoch + 1) / float(max(1, warmup_epochs)) # cosine from 1.0 -> min_lr_ratio - progress = (current_epoch - warmup_epochs) / float(max(1, epochs - warmup_epochs)) + progress = (current_epoch - warmup_epochs) / float( + max(1, epochs - warmup_epochs) + ) cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) return min_lr_ratio + (1.0 - min_lr_ratio) * cosine + return lr_lambda -def _build_scheduler(cfg, optimizer: Optimizer) -> Tuple[_LRScheduler | None, Dict[str, Any]]: - sch_cfg = getattr(cfg, "train").get("scheduler", {}) if hasattr(cfg, "train") else {} +def _build_scheduler( + cfg, optimizer: Optimizer +) -> Tuple[_LRScheduler | None, Dict[str, Any]]: + sch_cfg = ( + getattr(cfg, "train").get("scheduler", {}) if hasattr(cfg, "train") else {} + ) name = str(sch_cfg.get("name", "cosine")).lower() epochs = int(sch_cfg.get("epochs", getattr(cfg.train, "epochs", 50))) warmup_epochs = int(sch_cfg.get("warmup_epochs", 0)) if name in ("cosine", "cosineanneal", "cosine_anneal"): min_lr = float(sch_cfg.get("min_lr", 1e-6)) - base_lr = float(getattr(cfg.train.get("optimizer", {}), "lr", 1e-4) if hasattr(cfg, "train") else 1e-4) + base_lr = float( + getattr(cfg.train.get("optimizer", {}), "lr", 1e-4) + if hasattr(cfg, "train") + else 1e-4 + ) min_lr_ratio = max(min_lr / max(base_lr, 1e-12), 0.0) lr_lambda = _warmup_cosine_lambda_fn(epochs, warmup_epochs, min_lr_ratio) scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) return scheduler, {"by": "epoch"} - elif name == "step": + if name == "step": step_size = int(sch_cfg.get("step_size", 30)) gamma = float(sch_cfg.get("gamma", 0.1)) scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma) return scheduler, {"by": "epoch"} - elif name in ("plateau", "reduceonplateau"): + if name in ("plateau", "reduceonplateau"): patience = int(sch_cfg.get("patience", 5)) factor = float(sch_cfg.get("factor", 0.5)) scheduler = ReduceLROnPlateau(optimizer, patience=patience, factor=factor) return scheduler, {"by": "val_metric"} - elif name in ("none", "off"): + if name in ("none", "off"): return None, {"by": "none"} - else: - raise ValueError(f"Unsupported scheduler: {name}") + # If we reach this point, the scheduler is not supported + raise ValueError(f"Unsupported scheduler: {name}") -def step_scheduler(scheduler, meta: Dict[str, Any], epoch: int, val_metric: float | None = None): +def step_scheduler( + scheduler, + meta: Dict[str, Any], + epoch: int = None, # pylint: disable=unused-argument + val_metric: float | None = None +): """ Step the scheduler depending on configuration (by epoch or by val metric). meta["by"] is returned from _build_scheduler. + + Args: + scheduler: The scheduler to step + meta: Metadata dictionary with scheduling policy + epoch: Current epoch (not used but kept for API compatibility) + val_metric: Validation metric for ReduceLROnPlateau schedulers """ if scheduler is None: return @@ -112,7 +155,9 @@ def step_scheduler(scheduler, meta: Dict[str, Any], epoch: int, val_metric: floa scheduler.step(val_metric) -def make_optimizer_and_scheduler(cfg, params) -> Tuple[Optimizer, Tuple[_LRScheduler | None, Dict[str, Any]]]: +def make_optimizer_and_scheduler( + cfg, params +) -> Tuple[Optimizer, Tuple[_LRScheduler | None, Dict[str, Any]]]: """ Factory: returns (optimizer, (scheduler, meta)). Use step_scheduler(scheduler, meta, epoch, val_metric) to step it. diff --git a/src/utils/training_utils.py b/src/utils/training_utils.py index 9459c6c..2420d2f 100644 --- a/src/utils/training_utils.py +++ b/src/utils/training_utils.py @@ -1,3 +1,11 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +# pylint: disable=all + # src/utils/training_utils.py # -*- coding: utf-8 -*- from __future__ import annotations diff --git a/src/utils/utils.py b/src/utils/utils.py index daa6c7b..9aad2b4 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -1,3 +1,10 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +# pylint: disable=all """General utilities for environment, GPU, and I/O operations.""" import os import json diff --git a/src/wrappers/distill.py b/src/wrappers/distill.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/wrappers/finetune.py b/src/wrappers/finetune.py index f76c6fa..72ebf15 100644 --- a/src/wrappers/finetune.py +++ b/src/wrappers/finetune.py @@ -1,38 +1,147 @@ -from typing import Dict, Any +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University +# SPDX-License-Identifier: MIT + +# -*- coding: utf-8 -*- +"""Fine-tuning wrapper for end-to-end optimization.""" +from __future__ import annotations +from typing import Any, Dict + +import os import torch +from torch.utils.data import DataLoader -from src.engines.finetune_engine import train_finetune -from src.data.datamodule import build_datamodule -from src.models.factory import build_classifier -from src.losses.classification import cross_entropy_loss +# pylint: disable=import-error +from src.utils.logging_core import setup_logging, get_logger, WandbLogger from src.utils.optim import make_optimizer_and_scheduler -from src.utils.logging import get_logger # TODO +from src.losses.classification import cross_entropy_loss +from src.models.factory import create_model, create_preprocessor, save_model +from src.data.datamodule import BaseDataModule +from src.engines.finetune_engine import train_finetune +from src.utils.training_utils import profile_model log = get_logger(__name__) -def run(cfg) -> Dict[str, Any]: - device = torch.device(cfg.runtime["device"]) - dm = build_datamodule(cfg) - model = build_classifier(cfg).to(device) # backbone + classification head - - # Unfreeze all params for full finetuning - if hasattr(model, "unfreeze_all"): - model.unfreeze_all() - - loss_fn = cross_entropy_loss() - - optimizer, scheduler = make_optimizer_and_scheduler(cfg, model.parameters()) - - log.info("Starting end-to-end finetuning...") - metrics = train_finetune( - model=model, - loaders=dm.loaders(), - loss_fn=loss_fn, - optimizer=optimizer, - scheduler=scheduler, - device=device, - epochs=cfg.train.epochs, - grad_clip=getattr(cfg.train, "grad_clip_norm", None), - mixed_precision=getattr(cfg.train, "amp", True), - ) - return metrics + +class FinetuneWrapper: # pylint: disable=too-many-instance-attributes,too-few-public-methods + """ + Orchestrates full fine-tuning: + - builds model (+preprocessor if needed), + - prepares dataloaders via DataModule, + - creates optimizer/scheduler/loss, + - calls the finetune engine, + - saves best model. + """ + + def __init__(self, cfg: Any): + """ + Expected cfg fields (suggested): + cfg.model, cfg.data, cfg.train, cfg.loss, cfg.logging, cfg.runtime + """ + self.cfg = cfg + setup_logging() + + # Model + self.model_info = cfg.model + self.model = create_model(self.model_info, resolution=cfg.data.image_size) + + # Optional preprocessor + try: + self.preprocessor = create_preprocessor( + self.model_info, resolution=cfg.data.image_size + ) + except (ImportError, AttributeError, KeyError) as e: + log.debug(f"Preprocessor creation failed: {e}") + self.preprocessor = None + + # Data + self.dm = BaseDataModule( + cfg=cfg, + dataset_name=cfg.data.dataset_name, + data_dir=cfg.data.data_dir, + batch_size=cfg.data.batch_size, + num_workers=cfg.data.num_workers, + pin_memory=True, + ) + self.dm.setup("fit") + + # Optimizer & scheduler + self.optimizer, (self.scheduler, self.sched_meta) = ( + make_optimizer_and_scheduler(cfg, self.model.parameters()) + ) + + # Loss + self.loss_fn = cross_entropy_loss( + label_smoothing=float(getattr(cfg.loss, "label_smoothing", 0.0)), + class_weight=None, + ignore_index=int(getattr(cfg.loss, "ignore_index", -100)), + reduction=str(getattr(cfg.loss, "reduction", "mean")), + ) + + # W&B + self.wandb = WandbLogger( + project=getattr(cfg.logging, "project", "resolution-aware-finetune"), + run_name=getattr(cfg.logging, "run_name", None), + config=cfg, + enabled=bool(getattr(cfg.logging, "wandb_enabled", True)), + entity=getattr(cfg.logging, "entity", None), + tags=getattr(cfg.logging, "tags", ["finetune"]), + ) + + # Device + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model.to(self.device) + + def _make_loaders(self) -> Dict[str, DataLoader]: + return { + "train": self.dm.train_dataloader(), + "val": self.dm.val_dataloader(), + } + + def train(self) -> Dict[str, Any]: + """Run fine-tuning training.""" + log.info("Starting fine-tune...") + + gflops = profile_model(self.model, self.cfg.data.image_size) + if self.wandb: + self.wandb.log({"model/gflops": gflops}) + + results = train_finetune( + model=self.model, + loaders=self._make_loaders(), + loss_fn=self.loss_fn, + optimizer=self.optimizer, + scheduler=(self.scheduler, self.sched_meta), + device=self.device, + epochs=int(self.cfg.train.epochs), + grad_clip=getattr(self.cfg.train, "grad_clip", None), + mixed_precision=bool(getattr(self.cfg.train, "mixed_precision", True)), + log_interval=int(getattr(self.cfg.train, "log_interval", 50)), + wandb_logger=self.wandb, + metric_key=str(getattr(self.cfg.train, "metric_key", "val_acc")), + ) + + run_dir = getattr(self.cfg.runtime, "run_dir", "./runs/finetune") + os.makedirs(run_dir, exist_ok=True) + try: + save_model( + self.model, + self.model_info, + save_dir=run_dir, + preprocessor=self.preprocessor + ) + except (OSError, ValueError, AttributeError) as e: + log.warning(f"Failed to save model in HF format; error: {e}") + + if self.wandb: + self.wandb.log({"best/metric": results.get("best_metric", None)}) + self.wandb.finish() + + log.info("Fine-tune finished.") + return results + + +def run(cfg: Any) -> Dict[str, Any]: + wrapper = FinetuneWrapper(cfg) + return wrapper.train() diff --git a/src/wrappers/probe.py b/src/wrappers/probe.py index f1c7c12..3c872c8 100644 --- a/src/wrappers/probe.py +++ b/src/wrappers/probe.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/wrappers/probe.py # -*- coding: utf-8 -*- """Linear probing wrapper for training classification heads on frozen backbones.""" @@ -9,7 +15,7 @@ from torch.utils.data import DataLoader # pylint: disable=import-error -from src.utils.logging import setup_logging, get_logger, WandbLogger +from src.utils.logging_core import setup_logging, get_logger, WandbLogger from src.utils.optim import make_optimizer_and_scheduler from src.losses.classification import cross_entropy_loss from src.models.factory import (