From 66d27a47c30ed449cf0d9d325814b9e9b612bcde Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 12:51:40 -0700 Subject: [PATCH 01/26] code refactoring and removal of deprecated files --- .DS_Store | Bin 6148 -> 6148 bytes .../results_metrics_finetune.json | 0 .../results_metrics_lr_experiment.json | 0 results_metrics_linear_probe.json | 3 - run.sh | 17 - .../baseline_finetuning.py | 0 src/models/.DS_Store | Bin 0 -> 6148 bytes ... => model_comparison_baseline_not_used.py} | 10 +- src/models/model_comparison_lr.py | 2 +- src/models/model_comparison_models.py | 616 +++++++----------- ...arison.py => model_comparison_not_used.py} | 2 +- src/models/utils/transforms.py | 83 ++- src/models/utils/transforms_test.py | 61 -- src/models/utils/util_classes.py | 119 ---- ...{util_classes_test.py => utils_classes.py} | 68 +- vit_lr_0.0005/config.json | 25 - vit_lr_0.0005/preprocessor_config.json | 20 - vit_lr_0.001/config.json | 25 - vit_lr_0.001/preprocessor_config.json | 20 - 19 files changed, 369 insertions(+), 702 deletions(-) rename results_metrics_finetune.json => results/results_metrics_finetune.json (100%) rename results_metrics_lr_experiment.json => results/results_metrics_lr_experiment.json (100%) delete mode 100644 results_metrics_linear_probe.json delete mode 100644 run.sh rename src/{finetune => finetune_not_used}/baseline_finetuning.py (100%) create mode 100644 src/models/.DS_Store rename src/models/{model_comparison_baseline.py => model_comparison_baseline_not_used.py} (99%) rename src/models/{model_comparison.py => model_comparison_not_used.py} (99%) delete mode 100644 src/models/utils/transforms_test.py delete mode 100644 src/models/utils/util_classes.py rename src/models/utils/{util_classes_test.py => utils_classes.py} (69%) delete mode 100644 vit_lr_0.0005/config.json delete mode 100644 vit_lr_0.0005/preprocessor_config.json delete mode 100644 vit_lr_0.001/config.json delete mode 100644 vit_lr_0.001/preprocessor_config.json diff --git a/.DS_Store b/.DS_Store index 2c9bc6d440774839ebc8173648889279e76eeed6..3af1e0fb5b2d9cf4def8daebc47418e1416a7038 100644 GIT binary patch delta 39 vcmZoMXfc@J&&a(oU^g=(_hufJr;L+>*-|DyW(%3Tl09x=1LJ0Pj=%f>1qls# delta 202 zcmZoMXfc@J&&alV`9hisvw7GUPF&G9)q-F~l>7R#~$#Y%x{^cad6GCXtglaq4tlNcBn1Q-~YQYX8yB`5PTlmSgC0qV~I sn_>V|YXF9(sOlN-GcYh-{tpH~6BrnHaOgBd(-%2;7F+ISc8OV diff --git a/results_metrics_finetune.json b/results/results_metrics_finetune.json similarity index 100% rename from results_metrics_finetune.json rename to results/results_metrics_finetune.json diff --git a/results_metrics_lr_experiment.json b/results/results_metrics_lr_experiment.json similarity index 100% rename from results_metrics_lr_experiment.json rename to results/results_metrics_lr_experiment.json diff --git a/results_metrics_linear_probe.json b/results_metrics_linear_probe.json deleted file mode 100644 index 9370969..0000000 --- a/results_metrics_linear_probe.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "vit": {} -} \ No newline at end of file diff --git a/run.sh b/run.sh deleted file mode 100644 index f48d21b..0000000 --- a/run.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name=231n_job -#SBATCH --time=2-23:59:00 -#SBATCH --output=job_output_%j.out -#SBATCH --gres=gpu:1 -#SBATCH -p roxanad -#SBATCH --mem=128G - - -ml python/3.12 -python3 -m venv venv -source venv/bin/activate - -pip install -r /home/groups/roxanad/eric/CS231N/requirements.txt - -python3 src/models/model_comparison_baseline.py \ No newline at end of file diff --git a/src/finetune/baseline_finetuning.py b/src/finetune_not_used/baseline_finetuning.py similarity index 100% rename from src/finetune/baseline_finetuning.py rename to src/finetune_not_used/baseline_finetuning.py diff --git a/src/models/.DS_Store b/src/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..26c154e2727ff8c63d16b48d252a6395506fc4d3 GIT binary patch literal 6148 zcmeHKI|>3Z5S{S@f{mqRuHX%V=n3`$7J>+(;IH1wb9pr1e41sk(?WRzlb1~9CFB)5 zJ0haX+jb!`6OjqrP#!k)&GyZEHpqwq;W*=RZ_dZV>A36Vz6%(4EH}BzUJf0;?a-(I z6`%rCfC^B7Pb-iWb~63+!90%&P=TLUz`hR!ZdeoBK>u`L@D>0#Lf8#+?GNs z1|kB}paO%c*+Nm*NakF#1^;2XH*JmF@TI|YN6W1yE~EUX;QJt^{v&9Pq- U+d!uy?sOo3222+k75KISF9lB&>Hq)$ literal 0 HcmV?d00001 diff --git a/src/models/model_comparison_baseline.py b/src/models/model_comparison_baseline_not_used.py similarity index 99% rename from src/models/model_comparison_baseline.py rename to src/models/model_comparison_baseline_not_used.py index 4cfc4ba..c3a03a0 100644 --- a/src/models/model_comparison_baseline.py +++ b/src/models/model_comparison_baseline_not_used.py @@ -41,14 +41,6 @@ # Weights & Biases import wandb -# Metrics -from sklearn.metrics import ( - accuracy_score, - confusion_matrix, - f1_score, - roc_auc_score, -) - # Model Profiling & Vision Backbones import timm from thop import profile @@ -60,7 +52,7 @@ GaussianBlurTransform, ColorQuantizationTransform, ) -from utils.util_classes import ( +from utils.utils_classes import ( ISICDataset, SimCLRForClassification, LossLoggerCallback, diff --git a/src/models/model_comparison_lr.py b/src/models/model_comparison_lr.py index 08cba4a..1985973 100644 --- a/src/models/model_comparison_lr.py +++ b/src/models/model_comparison_lr.py @@ -64,7 +64,7 @@ GaussianBlurTransform, ColorQuantizationTransform, ) -from utils.util_classes_test import ( +from utils.util_classes import ( ISICDataset, SimCLRForClassification, LossLoggerCallback, diff --git a/src/models/model_comparison_models.py b/src/models/model_comparison_models.py index f2caf77..af0e28a 100644 --- a/src/models/model_comparison_models.py +++ b/src/models/model_comparison_models.py @@ -13,21 +13,15 @@ # Standard Libraries import io -import json -import random -import time -import argparse # Scientific & Visualization Libraries import numpy as np -import matplotlib.pyplot as plt -import seaborn as sns from PIL import Image # PyTorch & Torchvision import torch import torch.nn as nn -from torch.utils.data import Dataset, Subset, ConcatDataset +from torch.utils.data import Dataset from torchvision import transforms # Hugging Face Transformers & Datasets @@ -40,32 +34,17 @@ ViTFeatureExtractor, ViTForImageClassification, ) -from datasets import load_dataset, ClassLabel +from datasets import load_dataset # Weights & Biases import wandb -# Metrics -from sklearn.metrics import ( - accuracy_score, - confusion_matrix, - f1_score, - roc_auc_score, -) - # Model Profiling & Vision Backbones import timm -from thop import profile # Local Application Imports -from utils.constants import HF_MODELS, SSL_MODEL, SIMCLR_BACKBONE, NUM_CLASSES, FILTERED_CLASSES, NUM_FILTERED_CLASSES -from utils.transforms_test import ( - JPEGCompressionTransform, - GaussianBlurTransform, - ColorQuantizationTransform, -) -from utils.util_classes_test import ( - ISICDataset, +from utils.constants import SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES +from utils.util_classes import ( SimCLRForClassification, LossLoggerCallback, ) @@ -123,9 +102,220 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs): # Log evaluation metrics wandb.log(metrics) +def initialize_model_and_preprocessor(model_info, resolution): + """ + Initialize the model and preprocessor based on the model type. + + Args: + model_info (dict): Dictionary containing model details (name, model_id, type, config). + resolution (int): Image resolution. + + Returns: + model (torch.nn.Module): Initialized model. + preprocessor (transformers.PreTrainedTokenizer or None): Preprocessor for the model. + """ + name, model_id, typ, config = ( + model_info["name"], + model_info["model_id"], + model_info["type"], + model_info["config"], + ) + + if typ == "vit": + preprocessor = ViTFeatureExtractor.from_pretrained( + model_id, + size=resolution, + do_resize=True, + resample=Image.LANCZOS, + do_normalize=True, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + model = ViTForImageClassification.from_pretrained( + model_id, + num_labels=NUM_FILTERED_CLASSES, + ignore_mismatched_sizes=True, + image_size=resolution, + ) + elif typ == "dinov2": + preprocessor = AutoImageProcessor.from_pretrained( + model_id, + size=resolution, + do_resize=True, + resample=Image.LANCZOS, + do_normalize=True, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + model = AutoModelForImageClassification.from_pretrained( + model_id, + num_labels=NUM_FILTERED_CLASSES, + ignore_mismatched_sizes=True, + image_size=resolution, + ) + elif typ == SSL_MODEL: + backbone = timm.create_model( + SIMCLR_BACKBONE, + pretrained=True, + num_classes=0, # Remove classification head + ) + model = SimCLRForClassification(backbone, NUM_FILTERED_CLASSES) + freeze_backbone(model, SSL_MODEL) + preprocessor = None + else: + raise ValueError(f"Unsupported model type: {typ}") + + return model, preprocessor + +def balance_dataset(dataset, num_train_images, filtered_classes): + """ + Balance the dataset by sampling an equal number of images per class. + + Args: + dataset (Dataset): The dataset to balance. + num_train_images (int): Total number of training images to use. + filtered_classes (list): List of class labels to filter. + + Returns: + balanced_dataset (Dataset): Balanced dataset with equal images per class. + """ + print("Balancing dataset...") + class_counts = {label: 0 for label in filtered_classes} + for label in dataset["label"]: + class_counts[str(label)] += 1 + + print(f"Class counts: {class_counts}") # Debug print to verify counts + + min_class_size = min(class_counts.values()) + images_per_class = min(num_train_images // len(filtered_classes), min_class_size) + + np.random.seed(42) + balanced_indices = [] + for label in filtered_classes: + class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] + sampled_indices = np.random.choice(class_indices, images_per_class, replace=False) + balanced_indices.extend(sampled_indices) + + np.random.shuffle(balanced_indices) + return dataset.select(balanced_indices) + +def train_model(model, train_ds, val_ds, name, typ, resolution, batch_size, num_epochs, learning_rate, eval_steps, wandb_config): + """ + Train the model using the Hugging Face Trainer. + + Args: + model (torch.nn.Module): The model to train. + train_ds (Dataset): Training dataset. + val_ds (Dataset): Validation dataset. + name (str): Model name. + typ (str): Model type. + resolution (int): Image resolution. + batch_size (int): Batch size. + num_epochs (int): Number of epochs. + learning_rate (float): Learning rate. + eval_steps (int): Evaluation steps. + wandb_config (dict): Configuration for wandb logging. + + Returns: + dict: Training results and metrics. + """ + wandb.init( + entity="ericcui-use-stanford-university", + project="CS231N Test", + name=f"{name}_{resolution}_{num_epochs}_epochs_finetune", + config={**wandb_config, "model_name": name, "model_type": typ}, + tags=["baseline", "model-comparison", "finetune", name, f"res_{resolution}"], + reinit=True, + ) + + train_args = TrainingArguments( + output_dir=os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}"), + num_train_epochs=num_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + learning_rate=learning_rate, + lr_scheduler_type="cosine", + weight_decay=0.01, + logging_dir=os.path.join(env_path("LOG_DIR", "."), f"{name}"), + logging_steps=1, + eval_strategy="steps", + eval_steps=eval_steps, + save_strategy="steps", + save_steps=eval_steps, + load_best_model_at_end=False, + metric_for_best_model="accuracy", + save_total_limit=1, + ) + + trainer = Trainer( + model=model, + args=train_args, + train_dataset=train_ds, + eval_dataset=val_ds, + compute_metrics=lambda pred: compute_metrics(pred, name), + callbacks=[ + LossLoggerCallback( + log_dir=env_path("LOG_DIR", "./logs"), + phase="finetune", + model_name=name, + ), + WandbCallback(name, "finetune"), + ], + ) + + trainer.train() + eval_results = trainer.evaluate() + + wandb.finish() + return eval_results + +def prepare_datasets(dataset, preprocessor, resolution, proportion_per_transform): + """ + Prepare training and validation datasets with optional preprocessing. + + Args: + dataset: The balanced HuggingFace dataset. + preprocessor: Preprocessing function or None. + resolution: Image resolution. + proportion_per_transform: Proportion for each transform. + + Returns: + train_ds, val_ds: Torch-compatible datasets for training and validation. + """ + # Split dataset into train and validation (80/20 split) + train_size = int(0.8 * len(dataset)) + val_size = len(dataset) - train_size + train_dataset = dataset.select(range(train_size)) + val_dataset = dataset.select(range(train_size, train_size + val_size)) + + # Define basic transform + transform = transforms.Compose([ + transforms.Resize((resolution, resolution)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) + + class TorchDataset(Dataset): + def __init__(self, hf_dataset, transform): + self.hf_dataset = hf_dataset + self.transform = transform + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + item = self.hf_dataset[idx] + image = Image.open(io.BytesIO(item["image"])).convert("RGB") + if self.transform: + image = self.transform(image) + label = int(item["label"]) + return {"pixel_values": image, "labels": label} + + train_ds = TorchDataset(train_dataset, transform) + val_ds = TorchDataset(val_dataset, transform) + return train_ds, val_ds + def main(num_train_images=100, proportion_per_transform=0.2, resolution=224, batch_size=256, num_epochs=3, eval_steps=10, learning_rate=1e-4): - - # Initialize wandb config wandb_config = { "num_train_images": num_train_images, "proportion_per_transform": proportion_per_transform, @@ -144,380 +334,26 @@ def main(num_train_images=100, proportion_per_transform=0.2, resolution=224, bat "num_labels": NUM_FILTERED_CLASSES, "ignore_mismatched_sizes": True }}, - # {"name": "dinov2", "model_id": "facebook/dinov2-base", "type": "dinov2", "config": { - # "image_size": resolution, - # "num_labels": NUM_FILTERED_CLASSES, - # "ignore_mismatched_sizes": True - # }}, - # {"name": "simclr", "model_id": "resnet50", "type": "simclr", "config": { - # "img_size": resolution - # }}, ] - results = {m["name"]: {} for m in models} - results_linear_probe = {m["name"]: {} for m in models} - dataset = load_dataset( "MKZuziak/ISIC_2019_224", cache_dir=os.environ["HF_DATASETS_CACHE"], split="train", ) - print(f"Initial dataset size: {len(dataset)} images") - - # Get indices of images with desired labels - filtered_indices = [ - i for i, label in enumerate(dataset["label"]) - if str(label) in FILTERED_CLASSES # Convert to string for comparison - ] - - # Select only those indices - dataset = dataset.select(filtered_indices) - print(f"Number of images after filtering for classes {FILTERED_CLASSES}: {len(dataset)}") - dataset = dataset.cast_column("label", ClassLabel(num_classes=NUM_FILTERED_CLASSES)) - - # Get class counts and balance dataset - optimized version - print("Balancing dataset...") - # Get counts for each class - class_counts = {label: 0 for label in FILTERED_CLASSES} - for label in dataset["label"]: - class_counts[str(label)] += 1 # Convert to string for dictionary key - - print(f"Class counts: {class_counts}") # Debug print to verify counts - - # Calculate how many images to use per class - min_class_size = min(class_counts.values()) - images_per_class = min(num_train_images // 2, min_class_size) - - # Sample indices for each class - np.random.seed(42) - balanced_indices = [] - for label in FILTERED_CLASSES: - class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] # Convert to string for comparison - print(f"Found {len(class_indices)} images for class {label}") # Debug print - sampled_indices = np.random.choice(class_indices, images_per_class, replace=False) - balanced_indices.extend(sampled_indices) - - np.random.shuffle(balanced_indices) - balanced_dataset = dataset.select(balanced_indices) - - # Split into train and validation - full_dataset = balanced_dataset.train_test_split( - test_size=0.2, stratify_by_column="label", seed=42 - ) - - train_dataset, val_dataset = full_dataset["train"], full_dataset["test"] - - degradation_transforms = [ - JPEGCompressionTransform(), - GaussianBlurTransform(), - ColorQuantizationTransform(), - ] - - num_transforms = len(degradation_transforms) - num_images = len(train_dataset) - images_per_transform = int(num_images * proportion_per_transform) + dataset = balance_dataset(dataset, num_train_images, FILTERED_CLASSES) - # Create a single preprocessor for each model type - preprocessors = {} for model_info in models: - name, model_id, typ, config = ( - model_info["name"], - model_info["model_id"], - model_info["type"], - model_info["config"], - ) - if typ == "vit": - preprocessors[typ] = ViTFeatureExtractor.from_pretrained( - model_id, - size=resolution, - do_resize=True, - resample=Image.LANCZOS, - do_normalize=True, - image_mean=[0.485, 0.456, 0.406], - image_std=[0.229, 0.224, 0.225] - ) - elif typ == "dinov2": - preprocessors[typ] = AutoImageProcessor.from_pretrained( - model_id, - size=resolution, - do_resize=True, - resample=Image.LANCZOS, - do_normalize=True, - image_mean=[0.485, 0.456, 0.406], - image_std=[0.229, 0.224, 0.225] - ) - else: - preprocessors[typ] = None - - # Process each model type separately - for model_info in models: - name, model_id, typ, config = ( - model_info["name"], - model_info["model_id"], - model_info["type"], - model_info["config"], - ) - - # Initialize wandb for this specific model run - wandb.init( - entity="ericcui-use-stanford-university", - project="CS231N Test", - name=f"{name}_{resolution}_{num_epochs}_epochs_finetune", - config={**wandb_config, "model_config": config}, - tags=["baseline", "model-comparison", "finetune", name, f"res_{resolution}"], - reinit=True - ) + model, preprocessor = initialize_model_and_preprocessor(model_info, resolution) - transformed_datasets = [] - indices = np.arange(num_images) - np.random.shuffle(indices) + # Prepare datasets + train_ds, val_ds = prepare_datasets(dataset, preprocessor, resolution, proportion_per_transform) - used_indices = [] - for i, transform in enumerate(degradation_transforms): - subset_indices = indices[i * images_per_transform:(i + 1) * images_per_transform] - used_indices.extend(subset_indices) - subset = Subset(train_dataset, subset_indices) - transform_compose = transforms.Compose([transform]) - - transformed_ds = ISICDataset( - subset, - preprocessors[typ], - resolution, - transform_compose, - typ - ) - transformed_datasets.append(transformed_ds) - - remaining_indices = np.setdiff1d(indices, used_indices) - - if len(remaining_indices) > 0: - remaining_subset = Subset(train_dataset, remaining_indices) - untransformed_ds = ISICDataset( - remaining_subset, - preprocessors[typ], - resolution, - None, - typ - ) - transformed_datasets.append(untransformed_ds) - - train_ds = ConcatDataset(transformed_datasets) - val_ds = ISICDataset( - val_dataset, - preprocessors[typ], - resolution, - model_type=typ, - ) - - if typ == "vit": - model = ViTForImageClassification.from_pretrained( - model_id, - num_labels=NUM_FILTERED_CLASSES, - ignore_mismatched_sizes=True, - image_size=resolution, - ) - elif typ == "dinov2": - model = AutoModelForImageClassification.from_pretrained( - model_id, - num_labels=NUM_FILTERED_CLASSES, - ignore_mismatched_sizes=True, - image_size=resolution - ) - elif typ == SSL_MODEL: - backbone = timm.create_model( - SIMCLR_BACKBONE, - pretrained=True, - num_classes=0 # Remove classification head - ) - model = SimCLRForClassification(backbone, NUM_FILTERED_CLASSES) - freeze_backbone(model, SSL_MODEL) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - - try: - dummy_input = torch.randn(1, 3, resolution, resolution).to(device) - model.to(device) - flops, _ = profile(model, inputs=(dummy_input,)) - flops /= 1e9 - except Exception as e: - print(f"FLOP profiling failed: {e}") - flops = -1 - - train_args = TrainingArguments( - output_dir=os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}"), - num_train_epochs=num_epochs, - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - learning_rate=learning_rate, - lr_scheduler_type="cosine", - weight_decay=0.01, - logging_dir=os.path.join(env_path("LOG_DIR", "."), f"{name}"), - logging_steps=1, - eval_strategy="steps", - eval_steps=eval_steps, - save_strategy="steps", - save_steps=eval_steps, - load_best_model_at_end=False, - metric_for_best_model="accuracy", - save_total_limit=1, - save_safetensors=False, - hub_model_id=None, - hub_strategy="end", - push_to_hub=False, - save_only_model=True, + # Train the model + results = train_model( + model, train_ds, val_ds, model_info["name"], model_info["type"], + resolution, batch_size, num_epochs, learning_rate, eval_steps, wandb_config ) - - # Clean up old model directories before training - model_dirs = [ - os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}"), - os.path.join(env_path("MODEL_DIR", "."), f"{name}"), - os.path.join(env_path("LOG_DIR", "."), f"{name}"), - ] - - for dir_path in model_dirs: - if os.path.exists(dir_path): - print(f"Cleaning up directory: {dir_path}") - import shutil - shutil.rmtree(dir_path) - os.makedirs(dir_path, exist_ok=True) - - # Monitor disk space before training - import shutil - total, used, free = shutil.disk_usage("/") - print(f"Disk space before training {name}:") - print(f"Total: {total // (2**30)} GB") - print(f"Used: {used // (2**30)} GB") - print(f"Free: {free // (2**30)} GB") - - # If less than 1GB free, raise an error - if free < 1 * (2**30): # 1GB in bytes - raise RuntimeError("Not enough disk space to train model. Please free up at least 1GB of space.") - - trainer = Trainer( - model=model, - args=train_args, - train_dataset=train_ds, - eval_dataset=val_ds, - compute_metrics=lambda pred: compute_metrics(pred, name), - callbacks=[ - LossLoggerCallback( - log_dir=env_path("LOG_DIR", "./logs"), - phase="finetune", - model_name=name, - ), - WandbCallback(name, "finetune"), - ], - ) - - # ---- TRAINING PHASE ---- - start_time = time.time() - peak_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - - # Log model architecture - if typ in HF_MODELS: - wandb.watch(model, log="all", log_freq=100) - elif typ == SSL_MODEL: - wandb.watch(model.backbone, log="all", log_freq=100) - - trainer.train() - - current_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - peak_memory = max(peak_memory, current_memory) - - eval_start_time = time.time() - eval_results = trainer.evaluate() - eval_time = time.time() - eval_start_time - train_time = time.time() - start_time - eval_time - - # Log model-specific metrics - model_metrics = { - "model_name": name, - "model_type": typ, - "peak_memory_mb": peak_memory, - "flops_giga": flops, - "train_time_seconds": train_time, - "eval_time_seconds": eval_time, - "eval_metrics": eval_results, - } - wandb.log(model_metrics) - - # Create parameterized directory path - dir_name = f"{name}_{typ}_lr{learning_rate}_bs{batch_size}" - model_dir = os.path.join(env_path("MODEL_DIR", "."), dir_name) - os.makedirs(model_dir, exist_ok=True) - - if typ in HF_MODELS: - model.save_pretrained(model_dir) - preprocessors[typ].save_pretrained(model_dir) - elif typ == SSL_MODEL: - torch.save( - model.state_dict(), os.path.join(model_dir, "pytorch_model.bin") - ) - with open(os.path.join(model_dir, "config.json"), "w") as f: - json.dump( - { - "model_type": SSL_MODEL, - "backbone": "resnet50", - "num_classes": NUM_FILTERED_CLASSES, - }, - f, - ) - - # Save model as wandb artifact - artifact = wandb.Artifact( - name=f"{dir_name}_model", - type="model", - description=f"Trained {name} model with {typ} architecture" - ) - artifact.add_dir(model_dir) - wandb.log_artifact(artifact) - - results[name] = model_metrics - - print(f"[Finetune] {name}: {results[name]}") - - # Close the wandb run for this model - wandb.finish() - - # Clear GPU memory after model is done - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Remove the final wandb.finish() since we're now closing each run individually - with open( - os.path.join( - env_path("TRAIN_OUTPUT_DIR", "."), f"results_metrics_finetune_{dir_name}.json" - ), - "w", - ) as f: - json.dump(results, f, indent=4) - - with open( - os.path.join( - env_path("TRAIN_OUTPUT_DIR", "."), f"results_metrics_linear_probe_{dir_name}.json" - ), - "w", - ) as f: - json.dump(results_linear_probe, f, indent=4) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Model comparison script for image classification.") - parser.add_argument('--resolution', type=int, default=224, help='Input image resolution (default: 224)') - parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training and evaluation (default: 128)') - parser.add_argument('--num_train_images', type=int, default=500, help='Number of training images to use per class (default: 500)') - parser.add_argument('--num_epochs', type=int, default=3, help='Number of training epochs (default: 3)') - parser.add_argument('--eval_steps', type=int, default=100, help='Number of steps between evaluations (default: 100)') - parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate (default: 1e-4)') - args = parser.parse_args() - main( - resolution=args.resolution, - batch_size=args.batch_size, - num_train_images=args.num_train_images, - num_epochs=args.num_epochs, - eval_steps=args.eval_steps, - learning_rate=args.learning_rate - ) + print(f"Results for {model_info['name']}: {results}") diff --git a/src/models/model_comparison.py b/src/models/model_comparison_not_used.py similarity index 99% rename from src/models/model_comparison.py rename to src/models/model_comparison_not_used.py index 77cca79..d4a8639 100644 --- a/src/models/model_comparison.py +++ b/src/models/model_comparison_not_used.py @@ -55,7 +55,7 @@ # Local Application Imports from utils.constants import HF_MODELS, SSL_MODEL, SIMCLR_BACKBONE, NUM_CLASSES, FILTERED_CLASSES, NUM_FILTERED_CLASSES -from utils.transforms import ( +from src.models.utils.transforms import ( JPEGCompressionTransform, GaussianBlurTransform, ColorQuantizationTransform, diff --git a/src/models/utils/transforms.py b/src/models/utils/transforms.py index 46acd96..93d0fa4 100644 --- a/src/models/utils/transforms.py +++ b/src/models/utils/transforms.py @@ -7,39 +7,118 @@ from PIL import Image from torchvision import transforms + class JPEGCompressionTransform: def __init__(self, quality=75): + """ + Apply JPEG compression to an image. + + Args: + quality (int): Compression quality (1-100, higher is better quality). + """ self.quality = quality def __call__(self, img): + """ + Apply JPEG compression to the input image. + + Args: + img (PIL.Image or Tensor): Input image. + + Returns: + PIL.Image: Compressed image. + """ if not isinstance(img, Image.Image): img = transforms.ToPILImage()(img) + + # Store original size + original_size = img.size + + # Apply JPEG compression buffer = io.BytesIO() img.save(buffer, format="JPEG", quality=self.quality) buffer.seek(0) - return Image.open(buffer) + img = Image.open(buffer) + + # Ensure size is maintained + if img.size != original_size: + img = img.resize(original_size, Image.LANCZOS) + + return img + class GaussianBlurTransform: def __init__(self, p=1): + """ + Apply Gaussian blur to an image with a given probability. + + Args: + p (float): Probability of applying the blur (0 to 1). + """ self.p = p def __call__(self, img): + """ + Apply Gaussian blur to the input image. + + Args: + img (PIL.Image or Tensor): Input image. + + Returns: + PIL.Image: Blurred image. + """ if not isinstance(img, Image.Image): img = transforms.ToPILImage()(img) + + # Store original size + original_size = img.size + + # Apply Gaussian blur with probability p if random.random() < self.p: kernel_size = random.choice([3, 5, 7]) sigma = random.uniform(0.1, 2.0) img = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)(img) + + # Ensure size is maintained + if img.size != original_size: + img = img.resize(original_size, Image.LANCZOS) + return img + class ColorQuantizationTransform: def __init__(self, p=1): + """ + Apply color quantization to an image with a given probability. + + Args: + p (float): Probability of applying the quantization (0 to 1). + """ self.p = p def __call__(self, img): + """ + Apply color quantization to the input image. + + Args: + img (PIL.Image or Tensor): Input image. + + Returns: + PIL.Image: Quantized image. + """ if not isinstance(img, Image.Image): img = transforms.ToPILImage()(img) + + # Store original size + original_size = img.size + + # Apply color quantization with probability p if random.random() < self.p: num_colors = random.randint(16, 64) img = img.quantize(colors=num_colors, method=Image.Quantize.MAXCOVERAGE).convert("RGB") - return img \ No newline at end of file + + # Ensure size is maintained + if img.size != original_size: + img = img.resize(original_size, Image.LANCZOS) + + return img \ No newline at end of file diff --git a/src/models/utils/transforms_test.py b/src/models/utils/transforms_test.py deleted file mode 100644 index 5450f22..0000000 --- a/src/models/utils/transforms_test.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -Image transformation utilities for data augmentation and degradation. -""" - -import io -import random -from PIL import Image -from torchvision import transforms - -class JPEGCompressionTransform: - def __init__(self, quality=75): - self.quality = quality - - def __call__(self, img): - if not isinstance(img, Image.Image): - img = transforms.ToPILImage()(img) - # Store original size - original_size = img.size - buffer = io.BytesIO() - img.save(buffer, format="JPEG", quality=self.quality) - buffer.seek(0) - img = Image.open(buffer) - # Ensure size is maintained - if img.size != original_size: - img = img.resize(original_size, Image.LANCZOS) - return img - -class GaussianBlurTransform: - def __init__(self, p=1): - self.p = p - - def __call__(self, img): - if not isinstance(img, Image.Image): - img = transforms.ToPILImage()(img) - # Store original size - original_size = img.size - if random.random() < self.p: - kernel_size = random.choice([3, 5, 7]) - sigma = random.uniform(0.1, 2.0) - img = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)(img) - # Ensure size is maintained - if img.size != original_size: - img = img.resize(original_size, Image.LANCZOS) - return img - -class ColorQuantizationTransform: - def __init__(self, p=1): - self.p = p - - def __call__(self, img): - if not isinstance(img, Image.Image): - img = transforms.ToPILImage()(img) - # Store original size - original_size = img.size - if random.random() < self.p: - num_colors = random.randint(16, 64) - img = img.quantize(colors=num_colors, method=Image.Quantize.MAXCOVERAGE).convert("RGB") - # Ensure size is maintained - if img.size != original_size: - img = img.resize(original_size, Image.LANCZOS) - return img diff --git a/src/models/utils/util_classes.py b/src/models/utils/util_classes.py deleted file mode 100644 index e35cb5a..0000000 --- a/src/models/utils/util_classes.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -Utility classes for model training and data handling. -""" - -import os -import json -import numpy as np -import torch -import torch.nn as nn -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms -from transformers import TrainerCallback - -from .constants import HF_MODELS, NUM_FILTERED_CLASSES, SSL_MODEL -from .transforms import JPEGCompressionTransform - -class ISICDataset(Dataset): - def __init__( - self, - dataset, - preprocessor=None, - resolution=224, - transform=None, - model_type="vit", - jpeg_quality=None, - ): - self.dataset = dataset - self.preprocessor = preprocessor - self.resolution = resolution - self.transform = transform - self.model_type = model_type - self.jpeg_quality = jpeg_quality - if model_type == SSL_MODEL: - self.preprocessor = transforms.Compose( - [ - transforms.Resize((resolution, resolution)), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), - ] - ) - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - # Convert numpy.int64 to Python int if necessary - if isinstance(idx, (np.integer, np.int64)): - idx = int(idx) - - # Handle both direct dataset access and Subset access - if hasattr(self.dataset, 'dataset'): - # This is a Subset - subset_idx = int(self.dataset.indices[idx]) # Convert the index from the indices array - item = self.dataset.dataset[subset_idx] - else: - # This is a direct dataset - item = self.dataset[idx] - - image = item["image"] - label = item["label"] - - if self.resolution != 224: - image = image.resize((self.resolution, self.resolution), Image.LANCZOS) - - if self.transform: - image = self.transform(image) - - if self.jpeg_quality is not None: - image = JPEGCompressionTransform(self.jpeg_quality)(image) - - if self.model_type in HF_MODELS: - encoding = self.preprocessor(images=image, return_tensors="pt") - pixel_values = encoding["pixel_values"].squeeze(0) - elif self.model_type == SSL_MODEL: - pixel_values = self.preprocessor(image) - else: - raise ValueError(f"Unsupported model_type: {self.model_type}") - - label = torch.tensor(label, dtype=torch.long) - return {"pixel_values": pixel_values, "labels": label} - - -class SimCLRForClassification(nn.Module): - def __init__(self, backbone, num_classes=NUM_FILTERED_CLASSES): - super().__init__() - self.backbone = backbone - self.classifier = nn.Linear(2048, num_classes) - - def forward(self, pixel_values, labels=None): - features = self.backbone(pixel_values) - logits = self.classifier(features) - loss = None - if labels is not None: - loss = nn.CrossEntropyLoss()(logits, labels) - return ( - {"logits": logits, "loss": loss} if loss is not None else {"logits": logits} - ) - - -class LossLoggerCallback(TrainerCallback): - """ - Logs each training step's loss and other metrics to a structured JSON Lines file. - """ - - def __init__(self, log_dir: str, phase: str, model_name: str): - os.makedirs(log_dir, exist_ok=True) - self.log_file = os.path.join( - log_dir, f"{model_name}_{phase}_log.jsonl" - ) - - def on_log(self, args, state, control, logs=None, **kwargs): - if logs is None: - return - with open(self.log_file, "a") as f: - json.dump({"step": state.global_step, **logs}, f) - f.write("\n") \ No newline at end of file diff --git a/src/models/utils/util_classes_test.py b/src/models/utils/utils_classes.py similarity index 69% rename from src/models/utils/util_classes_test.py rename to src/models/utils/utils_classes.py index 25c1e18..0fe3e88 100644 --- a/src/models/utils/util_classes_test.py +++ b/src/models/utils/utils_classes.py @@ -15,6 +15,7 @@ from .constants import HF_MODELS, NUM_FILTERED_CLASSES, SSL_MODEL from .transforms import JPEGCompressionTransform + class ISICDataset(Dataset): def __init__( self, @@ -25,24 +26,36 @@ def __init__( model_type="vit", jpeg_quality=None, ): + """ + Dataset class for handling ISIC image data. + + Args: + dataset: The dataset to load. + preprocessor: Preprocessing function for Hugging Face models. + resolution: Target image resolution. + transform: Additional transformations to apply. + model_type: Type of model (e.g., "vit", "ssl"). + jpeg_quality: JPEG compression quality (if applicable). + """ self.dataset = dataset self.preprocessor = preprocessor self.resolution = resolution self.transform = transform self.model_type = model_type self.jpeg_quality = jpeg_quality - - # Create a base preprocessing pipeline that always resizes to the target resolution + + # Base preprocessing pipeline for resizing and tensor conversion self.base_preprocessor = transforms.Compose([ transforms.Resize((resolution, resolution), Image.LANCZOS), transforms.ToTensor(), ]) - + + # Preprocessor for SSL models if model_type == SSL_MODEL: self.preprocessor = transforms.Compose([ transforms.Resize((resolution, resolution), Image.LANCZOS), transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def __len__(self): @@ -52,7 +65,7 @@ def __getitem__(self, idx): # Convert numpy.int64 to Python int if necessary if isinstance(idx, (np.integer, np.int64)): idx = int(idx) - + # Handle both direct dataset access and Subset access if hasattr(self.dataset, 'dataset'): # This is a Subset @@ -61,21 +74,24 @@ def __getitem__(self, idx): else: # This is a direct dataset item = self.dataset[idx] - + image = item["image"] label = item["label"] # Always resize to target resolution first image = image.resize((self.resolution, self.resolution), Image.LANCZOS) + # Apply additional transformations if provided if self.transform: image = self.transform(image) + # Apply JPEG compression if specified if self.jpeg_quality is not None: image = JPEGCompressionTransform(self.jpeg_quality)(image) + # Preprocessing for Hugging Face models if self.model_type in HF_MODELS: - # For HF models, ensure the preprocessor doesn't resize again + # Ensure the preprocessor doesn't resize again if hasattr(self.preprocessor, 'size'): self.preprocessor.size = self.resolution encoding = self.preprocessor(images=image, return_tensors="pt") @@ -87,15 +103,32 @@ def __getitem__(self, idx): label = torch.tensor(label, dtype=torch.long) return {"pixel_values": pixel_values, "labels": label} - + class SimCLRForClassification(nn.Module): def __init__(self, backbone, num_classes=NUM_FILTERED_CLASSES): + """ + SimCLR-based classification model. + + Args: + backbone: The backbone model (e.g., ResNet). + num_classes: Number of output classes. + """ super().__init__() self.backbone = backbone self.classifier = nn.Linear(2048, num_classes) def forward(self, pixel_values, labels=None): + """ + Forward pass for the model. + + Args: + pixel_values: Input image tensors. + labels: Ground truth labels (optional). + + Returns: + dict: Dictionary containing logits and loss (if labels are provided). + """ features = self.backbone(pixel_values) logits = self.classifier(features) loss = None @@ -112,14 +145,31 @@ class LossLoggerCallback(TrainerCallback): """ def __init__(self, log_dir: str, phase: str, model_name: str): + """ + Initialize the callback. + + Args: + log_dir: Directory to save the log file. + phase: Training phase (e.g., "finetune"). + model_name: Name of the model. + """ os.makedirs(log_dir, exist_ok=True) self.log_file = os.path.join( log_dir, f"{model_name}_{phase}_log.jsonl" ) def on_log(self, args, state, control, logs=None, **kwargs): + """ + Log metrics to a JSON Lines file. + + Args: + args: Training arguments. + state: Trainer state. + control: Trainer control. + logs: Metrics to log. + """ if logs is None: return with open(self.log_file, "a") as f: json.dump({"step": state.global_step, **logs}, f) - f.write("\n") + f.write("\n") \ No newline at end of file diff --git a/vit_lr_0.0005/config.json b/vit_lr_0.0005/config.json deleted file mode 100644 index f72ae5f..0000000 --- a/vit_lr_0.0005/config.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "architectures": [ - "ViTForImageClassification" - ], - "attention_probs_dropout_prob": 0.0, - "encoder_stride": 16, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.0, - "hidden_size": 768, - "image_size": 56, - "initializer_range": 0.02, - "intermediate_size": 3072, - "layer_norm_eps": 1e-12, - "model_type": "vit", - "num_attention_heads": 12, - "num_channels": 3, - "num_hidden_layers": 12, - "patch_size": 16, - "pooler_act": "tanh", - "pooler_output_size": 768, - "problem_type": "single_label_classification", - "qkv_bias": true, - "torch_dtype": "float32", - "transformers_version": "4.52.4" -} diff --git a/vit_lr_0.0005/preprocessor_config.json b/vit_lr_0.0005/preprocessor_config.json deleted file mode 100644 index 2349c28..0000000 --- a/vit_lr_0.0005/preprocessor_config.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "do_convert_rgb": null, - "do_normalize": true, - "do_rescale": true, - "do_resize": true, - "image_mean": [ - 0.485, - 0.456, - 0.406 - ], - "image_processor_type": "ViTFeatureExtractor", - "image_std": [ - 0.229, - 0.224, - 0.225 - ], - "resample": 1, - "rescale_factor": 0.00392156862745098, - "size": 56 -} diff --git a/vit_lr_0.001/config.json b/vit_lr_0.001/config.json deleted file mode 100644 index f72ae5f..0000000 --- a/vit_lr_0.001/config.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "architectures": [ - "ViTForImageClassification" - ], - "attention_probs_dropout_prob": 0.0, - "encoder_stride": 16, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.0, - "hidden_size": 768, - "image_size": 56, - "initializer_range": 0.02, - "intermediate_size": 3072, - "layer_norm_eps": 1e-12, - "model_type": "vit", - "num_attention_heads": 12, - "num_channels": 3, - "num_hidden_layers": 12, - "patch_size": 16, - "pooler_act": "tanh", - "pooler_output_size": 768, - "problem_type": "single_label_classification", - "qkv_bias": true, - "torch_dtype": "float32", - "transformers_version": "4.52.4" -} diff --git a/vit_lr_0.001/preprocessor_config.json b/vit_lr_0.001/preprocessor_config.json deleted file mode 100644 index 2349c28..0000000 --- a/vit_lr_0.001/preprocessor_config.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "do_convert_rgb": null, - "do_normalize": true, - "do_rescale": true, - "do_resize": true, - "image_mean": [ - 0.485, - 0.456, - 0.406 - ], - "image_processor_type": "ViTFeatureExtractor", - "image_std": [ - 0.229, - 0.224, - 0.225 - ], - "resample": 1, - "rescale_factor": 0.00392156862745098, - "size": 56 -} From a2418277958fd0aad5b5121ec87d102f96f03b9f Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 13:02:49 -0700 Subject: [PATCH 02/26] clean model scripts and renaming --- docs/pipeline.md | 66 ++ src/finetune_not_used/baseline_finetuning.py | 260 -------- src/models/model_comparison_baseline.py | 250 ++++++++ .../model_comparison_baseline_not_used.py | 567 ------------------ ...son_lr.py => model_comparison_lr_sweep.py} | 11 + src/models/model_comparison_models.py | 359 ----------- src/models/model_comparison_not_used.py | 476 --------------- 7 files changed, 327 insertions(+), 1662 deletions(-) create mode 100644 docs/pipeline.md delete mode 100644 src/finetune_not_used/baseline_finetuning.py create mode 100644 src/models/model_comparison_baseline.py delete mode 100644 src/models/model_comparison_baseline_not_used.py rename src/models/{model_comparison_lr.py => model_comparison_lr_sweep.py} (97%) delete mode 100644 src/models/model_comparison_models.py delete mode 100644 src/models/model_comparison_not_used.py diff --git a/docs/pipeline.md b/docs/pipeline.md new file mode 100644 index 0000000..772baa5 --- /dev/null +++ b/docs/pipeline.md @@ -0,0 +1,66 @@ +# Model Comparison Pipeline + +## Overview + +This script (`model_comparison_models.py`) provides a baseline for comparing different image classification models at various image compression levels, including the original images. It supports fine-tuning, linear probing, and optional image degradation transforms. + +--- + +## Features + +- **Model Support:** Vision Transformer (ViT), DINOv2, SimCLR (self-supervised backbone) +- **Data Augmentation:** Optional JPEG compression, Gaussian blur, and color quantization +- **Dataset Balancing:** Ensures equal samples per class for fair comparison +- **Training & Evaluation:** Uses Hugging Face Trainer for streamlined workflows +- **Experiment Tracking:** Integrated with Weights & Biases (`wandb`) +- **GPU Monitoring:** Optional support via `pynvml` + +--- + +## Workflow + +1. **Environment Setup** + - Loads required libraries and sets up cache directories. + - Checks for GPU availability. + +2. **Dataset Loading & Balancing** + - Loads ISIC_2019_224 dataset. + - Balances the dataset across filtered classes. + +3. **Model Initialization** + - Initializes model and preprocessor based on configuration. + +4. **Preprocessing & Augmentation** + - Applies resizing, normalization, and optional degradation transforms. + +5. **Training & Evaluation** + - Splits data into training and validation sets. + - Trains and evaluates each model, logging results to `wandb`. + +--- + +## Usage + +```bash +python [model_comparison_models.py](http://_vscodecontentref_/2) --resolution 224 --batch_size 256 --num_train_images 25000 --num_epochs 10 --eval_steps 10 +``` + +## Configuration +- Models: Edit the models list in the script to add or modify model configurations. +- Transforms: Toggle apply_transforms in prepare_datasets() to enable/disable augmentations. +- Hyperparameters: Adjust arguments in the main() function for batch size, epochs, etc. + +## Output +- Training and evaluation metrics are printed to the console and logged to Weights & Biases. +- Results can be used for further analysis or ablation studies. + +## Extending +- Add new models by updating the models list. +- Implement new transforms in utils/transforms.py. +- Add new datasets by modifying the dataset loading logic. + +## References +- Hugging Face Transformers +- PyTorch +- Weights & Biases +- ISIC 2019 Dataset \ No newline at end of file diff --git a/src/finetune_not_used/baseline_finetuning.py b/src/finetune_not_used/baseline_finetuning.py deleted file mode 100644 index 6d43dcd..0000000 --- a/src/finetune_not_used/baseline_finetuning.py +++ /dev/null @@ -1,260 +0,0 @@ -""" -This script fine-tunes a ViT model on the ISIC 2019 dataset with various resolutions. -It includes data augmentation, model evaluation, and GPU memory monitoring. - -Example run: -python finetune_isic.py \ - --dataset_name MKZuziak/ISIC_2019_224 \ - --resolutions 224 112 \ - --num_epochs 5 \ - --output_dir ./results_vit_isic -""" - -# TODO: pass the file paths as os.imports - -# Standard libraries -import json -import time - -# Third-party libraries -import numpy as np -import matplotlib.pyplot as plt -from scipy.special import softmax -import seaborn as sns -from sklearn.metrics import ( - accuracy_score, - confusion_matrix, - f1_score, - roc_auc_score, -) -import torch -from torchvision import transforms -from transformers import ( - ViTForImageClassification, - ViTFeatureExtractor, - TrainingArguments, - Trainer, -) -from datasets import load_dataset - -# Profiling -from thop import profile - -# GPU memory monitoring via pynvml -try: - import pynvml - pynvml.nvmlInit() - GPU_AVAILABLE = True -except ImportError: - GPU_AVAILABLE = False - print("pynvml not installed, GPU memory monitoring disabled.") - -import argparse -from pathlib import Path - -def parse_args(): - parser = argparse.ArgumentParser(description="Fine-tune ViT on ISIC 2019 dataset.") - parser.add_argument("--dataset_name", type=str, default="MKZuziak/ISIC_2019_224", - help="HuggingFace dataset name") - parser.add_argument("--output_dir", type=Path, default=Path("./results"), - help="Where to save models and results") - parser.add_argument("--resolutions", type=int, nargs="+", default=[224, 112, 56], - help="List of resolutions to fine-tune on") - parser.add_argument("--num_epochs", type=int, default=3, - help="Number of training epochs") - return parser.parse_args() - - -# Compute metrics for evaluation -def compute_metrics(eval_pred, model_name, resolution): - """ - Compute accuracy, F1 score, and AUC for the model predictions. - Also generates a confusion matrix and saves it as an image. - """ - logits, labels = eval_pred - predictions = np.argmax(logits, axis=-1) - acc = accuracy_score(labels, predictions) - f1 = f1_score(labels, predictions, average="weighted") - - probs = softmax(logits, axis=1) - auc = roc_auc_score(labels, probs, multi_class="ovr") - # You're computing AUC from logits, not probabilities. Use soft-max first. - # auc = roc_auc_score(labels, logits, multi_class="ovr") - - # Plot confusion matrix - conf_mat = confusion_matrix(labels, predictions) - _, ax = plt.subplots(figsize=(10, 10)) - sns.heatmap(conf_mat, annot=True, cmap="Blues") - ax.set_xlabel("Predicted labels") - ax.set_ylabel("True labels") - ax.set_title(f"{model_name}_{resolution}_conf_mat") - plt.savefig(f"{model_name}_{resolution}_conf_mat.png", dpi=300, bbox_inches="tight") - plt.close() - - # Classification breakdown - unique, counts = np.unique(predictions, return_counts=True) - class_breakdown = dict(zip(unique, counts)) - with open(f"{model_name}_{resolution}_class_breakdown.json", "w") as f: - json.dump(class_breakdown, f) - - return {"accuracy": acc, "f1": f1, "auc": auc} - - -# Measure GPU memory usage -def get_gpu_memory(device_id=0): - if not GPU_AVAILABLE: - return -1 - try: - handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) - mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) - return mem_info.used / 1024**2 # MB - except: - return -1 - - -# Main function for fine-tuning -def main(): - """ - Main function to fine-tune the model on the ISIC 2019 dataset. - """ - # Models and resolutions to compare - models = [ - {"name": "vit", "model_id": "google/vit-base-patch16-224", "type": "vit"}, - ] - resolutions = args.resolutions - - # Results storage - results = {model["name"]: {} for model in models} - - # Load dataset - dataset = load_dataset(args.dataset_name) - if dataset is None: - print(f"Dataset {args.dataset_name} not found.") - return - full_dataset = dataset["train"].train_test_split( - test_size=0.2, stratify_by_column="label", seed=42 - ) - train_dataset = full_dataset["train"] - val_dataset = full_dataset["test"] - - # Data augmentation - transform = transforms.Compose( - [ - transforms.RandomHorizontalFlip(), - transforms.RandomRotation(20), - transforms.ColorJitter(brightness=0.2, contrast=0.2), - ] - ) - - for model_info in models: - model_name = model_info["name"] - model_id = model_info["model_id"] - model_type = model_info["type"] - print(f"\nFine-tuning model: {model_name}") - - for resolution in resolutions: - print(f"Resolution: {resolution}x{resolution}") - - # Load preprocessor - if model_type == "vit": - preprocessor = ViTFeatureExtractor.from_pretrained( - model_id, size=resolution - ) - else: - preprocessor = None # timm models use manual preprocessing - - # Create datasets - train_ds = ISICDataset( - train_dataset, preprocessor, resolution, transform, model_type - ) - val_ds = ISICDataset( - val_dataset, preprocessor, resolution, model_type=model_type - ) - - # Load model - if model_type == "vit": - model = ViTForImageClassification.from_pretrained( - model_id, num_labels=8, ignore_mismatched_sizes=True - ) - else: - pass # Load other models as needed - - # Estimate FLOPs - input_tensor = torch.randn(1, 3, resolution, resolution) - try: - flops, _ = profile(model, inputs=(input_tensor,)) - flops = flops / 1e9 # GFLOPs - except: - flops = -1 # Fallback if FLOPs estimation fails - - # Training arguments - training_args = TrainingArguments( - output_dir=args.output_dir / f"{model_name}_{resolution}", - num_train_epochs=args.num_epochs, - per_device_train_batch_size=16, - per_device_eval_batch_size=16, - warmup_steps=500, - weight_decay=0.01, - logging_dir=f"./logs_{model_name}_{resolution}", - logging_steps=10, - eval_strategy="epoch", - save_strategy="epoch", - load_best_model_at_end=True, - metric_for_best_model="accuracy", - ) - - # Initialize Trainer - trainer = Trainer( - model=model, - args=training_args, - train_dataset=train_ds, - eval_dataset=val_ds, - compute_metrics=lambda pred: compute_metrics( - pred, model_name, resolution - ), - ) - - # Measure memory and time - start_time = time.time() - peak_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - - # Fine-tune - trainer.train() - - # Update peak memory - current_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - peak_memory = max(peak_memory, current_memory) - - # Evaluate - eval_start_time = time.time() - eval_results = trainer.evaluate() - eval_time = time.time() - eval_start_time - - # Total training time - train_time = time.time() - start_time - eval_time - - # Save model - model.save_pretrained(f"./finetuned_{model_name}_{resolution}") - if model_type == "vit": - preprocessor.save_pretrained(f"./finetuned_{model_name}_{resolution}") - - # Store results - results[model_name][resolution] = { - "peak_memory_mb": peak_memory, - "flops_giga": flops, - "train_time_seconds": train_time, - "eval_time_seconds": eval_time, - "eval_metrics": eval_results, - } - print( - f"Results for {model_name} at {resolution}x{resolution}: {results[model_name][resolution]}" - ) - - # Save results to JSON - args.output_dir.mkdir(parents=True, exist_ok=True) - with open(args.output_dir / "results_metrics.json", "w", encoding="utf-8") as f: - json.dump(results, f, indent=4) - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/src/models/model_comparison_baseline.py b/src/models/model_comparison_baseline.py new file mode 100644 index 0000000..89da7a4 --- /dev/null +++ b/src/models/model_comparison_baseline.py @@ -0,0 +1,250 @@ +""" +This script is a baseline for comparing different image classification models +at three different image compression levels, in comparison to the original. +It supports fine-tuning, linear probing, and optional degradation transforms. + +Purpose: General pipeline for comparing image classification models (ViT, DINOv2, SimCLR) with optional image degradation transforms. +Features: +- Supports fine-tuning and linear probing. +- Balances dataset across classes. +- Applies optional transforms (JPEG, blur, quantization). +- Uses a fixed model list (currently ViT). +- No command-line argument parsing; hyperparameters are set in the main() function. +- Designed for baseline model comparison across compression levels. +""" + +# Environment Setup +import os +import io +import json +import numpy as np +from PIL import Image + +# PyTorch & Torchvision +import torch +import torch.nn as nn +from torch.utils.data import Dataset +from torchvision import transforms + +# Hugging Face Transformers & Datasets +from transformers import ( + AutoImageProcessor, + AutoModelForImageClassification, + Trainer, + TrainerCallback, + TrainingArguments, + ViTFeatureExtractor, + ViTForImageClassification, +) +from datasets import load_dataset + +# Weights & Biases +import wandb + +# Model Profiling & Vision Backbones +import timm +from thop import profile + +# Local Application Imports +from utils.constants import SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES +from utils.transforms import JPEGCompressionTransform, GaussianBlurTransform, ColorQuantizationTransform +from utils.util_classes import SimCLRForClassification, LossLoggerCallback +from utils.util_methods import env_path, compute_metrics, get_gpu_memory, freeze_backbone + +# GPU Memory Monitoring (optional) +try: + import pynvml + pynvml.nvmlInit() + GPU_AVAILABLE = True +except ImportError: + GPU_AVAILABLE = False + print("pynvml not installed, GPU memory monitoring disabled.") + +# Cache paths +os.environ["TRANSFORMERS_CACHE"] = os.getenv("TRANSFORMERS_CACHE", "~/.cache/huggingface/transformers") +os.environ["HF_DATASETS_CACHE"] = os.getenv("HF_DATASETS_CACHE", "~/.cache/huggingface/datasets") +os.environ["HF_HOME"] = os.getenv("HF_HOME", "~/.cache/huggingface") + + +class WandbCallback(TrainerCallback): + def __init__(self, model_name, phase): + self.model_name = model_name + self.phase = phase + self.best_accuracy = 0.0 + + def on_log(self, args, state, control, logs=None, **kwargs): + if logs is not None: + logs["model"] = self.model_name + logs["phase"] = self.phase + if GPU_AVAILABLE: + logs["gpu_memory_mb"] = get_gpu_memory() + wandb.log(logs) + + def on_evaluate(self, args, state, control, metrics=None, **kwargs): + if metrics is not None: + if "eval_accuracy" in metrics: + self.best_accuracy = max(self.best_accuracy, metrics["eval_accuracy"]) + metrics["best_accuracy"] = self.best_accuracy + wandb.log(metrics) + + +def initialize_model_and_preprocessor(model_info, resolution): + name, model_id, typ, config = ( + model_info["name"], + model_info["model_id"], + model_info["type"], + model_info["config"], + ) + + if typ == "vit": + preprocessor = ViTFeatureExtractor.from_pretrained( + model_id, + size=resolution, + do_resize=True, + resample=Image.LANCZOS, + do_normalize=True, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + model = ViTForImageClassification.from_pretrained( + model_id, + num_labels=NUM_FILTERED_CLASSES, + ignore_mismatched_sizes=True, + image_size=resolution, + ) + elif typ == "dinov2": + preprocessor = AutoImageProcessor.from_pretrained( + model_id, + size=resolution, + do_resize=True, + resample=Image.LANCZOS, + do_normalize=True, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + model = AutoModelForImageClassification.from_pretrained( + model_id, + num_labels=NUM_FILTERED_CLASSES, + ignore_mismatched_sizes=True, + image_size=resolution, + ) + elif typ == SSL_MODEL: + backbone = timm.create_model( + SIMCLR_BACKBONE, + pretrained=True, + num_classes=0, + ) + model = SimCLRForClassification(backbone, NUM_FILTERED_CLASSES) + freeze_backbone(model, SSL_MODEL) + preprocessor = None + else: + raise ValueError(f"Unsupported model type: {typ}") + + return model, preprocessor + + +def balance_dataset(dataset, num_train_images, filtered_classes): + print("Balancing dataset...") + class_counts = {label: 0 for label in filtered_classes} + for label in dataset["label"]: + class_counts[str(label)] += 1 + + min_class_size = min(class_counts.values()) + images_per_class = min(num_train_images // len(filtered_classes), min_class_size) + + np.random.seed(42) + balanced_indices = [] + for label in filtered_classes: + class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] + sampled_indices = np.random.choice(class_indices, images_per_class, replace=False) + balanced_indices.extend(sampled_indices) + + np.random.shuffle(balanced_indices) + return dataset.select(balanced_indices) + + +def prepare_datasets(dataset, preprocessor, resolution, proportion_per_transform, apply_transforms=False): + train_size = int(0.8 * len(dataset)) + val_size = len(dataset) - train_size + train_dataset = dataset.select(range(train_size)) + val_dataset = dataset.select(range(train_size, train_size + val_size)) + + transform = transforms.Compose([ + transforms.Resize((resolution, resolution)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) + + if apply_transforms: + transform = transforms.Compose([ + transform, + JPEGCompressionTransform(quality=75), + GaussianBlurTransform(p=0.5), + ColorQuantizationTransform(p=0.5), + ]) + + class TorchDataset(Dataset): + def __init__(self, hf_dataset, transform): + self.hf_dataset = hf_dataset + self.transform = transform + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + item = self.hf_dataset[idx] + image = Image.open(io.BytesIO(item["image"])).convert("RGB") + if self.transform: + image = self.transform(image) + label = int(item["label"]) + return {"pixel_values": image, "labels": label} + + train_ds = TorchDataset(train_dataset, transform) + val_ds = TorchDataset(val_dataset, transform) + return train_ds, val_ds + + +def main(num_train_images=100, proportion_per_transform=0.2, resolution=224, batch_size=256, num_epochs=3, eval_steps=10, learning_rate=1e-4): + wandb_config = { + "num_train_images": num_train_images, + "proportion_per_transform": proportion_per_transform, + "resolution": resolution, + "batch_size": batch_size, + "num_epochs": num_epochs, + "eval_steps": eval_steps, + "weight_decay": 0.01, + "learning_rate": learning_rate, + "gpu_available": GPU_AVAILABLE, + } + + models = [ + {"name": "vit", "model_id": "google/vit-base-patch16-224", "type": "vit", "config": { + "image_size": resolution, + "num_labels": NUM_FILTERED_CLASSES, + "ignore_mismatched_sizes": True + }}, + ] + + dataset = load_dataset( + "MKZuziak/ISIC_2019_224", + cache_dir=os.environ["HF_DATASETS_CACHE"], + split="train", + ) + + dataset = balance_dataset(dataset, num_train_images, FILTERED_CLASSES) + + for model_info in models: + model, preprocessor = initialize_model_and_preprocessor(model_info, resolution) + + train_ds, val_ds = prepare_datasets(dataset, preprocessor, resolution, proportion_per_transform, apply_transforms=True) + + results = train_model( + model, train_ds, val_ds, model_info["name"], model_info["type"], + resolution, batch_size, num_epochs, learning_rate, eval_steps, wandb_config + ) + + print(f"Results for {model_info['name']}: {results}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/models/model_comparison_baseline_not_used.py b/src/models/model_comparison_baseline_not_used.py deleted file mode 100644 index c3a03a0..0000000 --- a/src/models/model_comparison_baseline_not_used.py +++ /dev/null @@ -1,567 +0,0 @@ -''' -This script is a baseline for comparing different image classification models -at three different image compression levels, in comparison to the original. -It has a set number of augmentation transforms and does NOT combine them. -This does NOT experiment on JPEG compression levels -''' - -# Environment Setup -import os - -# Standard Libraries -import io -import json -import random -import time - -# Scientific & Visualization Libraries -import numpy as np -import matplotlib.pyplot as plt -import seaborn as sns -from PIL import Image - -# PyTorch & Torchvision -import torch -import torch.nn as nn -from torch.utils.data import Dataset, Subset, ConcatDataset -from torchvision import transforms - -# Hugging Face Transformers & Datasets -from transformers import ( - AutoImageProcessor, - AutoModelForImageClassification, - Trainer, - TrainerCallback, - TrainingArguments, - ViTFeatureExtractor, - ViTForImageClassification, -) -from datasets import load_dataset, ClassLabel - -# Weights & Biases -import wandb - -# Model Profiling & Vision Backbones -import timm -from thop import profile - -# Local Application Imports -from utils.constants import HF_MODELS, SSL_MODEL, SIMCLR_BACKBONE, NUM_CLASSES, FILTERED_CLASSES, NUM_FILTERED_CLASSES -from utils.transforms import ( - JPEGCompressionTransform, - GaussianBlurTransform, - ColorQuantizationTransform, -) -from utils.utils_classes import ( - ISICDataset, - SimCLRForClassification, - LossLoggerCallback, -) -from utils.util_methods import ( - env_path, - compute_metrics, - get_gpu_memory, - freeze_backbone, -) - -# GPU Memory Monitoring (optional) -try: - import pynvml - pynvml.nvmlInit() - GPU_AVAILABLE = True -except ImportError: - GPU_AVAILABLE = False - print("pynvml not installed, GPU memory monitoring disabled.") - -# Cache paths -os.environ["TRANSFORMERS_CACHE"] = os.getenv( - "TRANSFORMERS_CACHE", "~/.cache/huggingface/transformers" -) -os.environ["HF_DATASETS_CACHE"] = os.getenv( - "HF_DATASETS_CACHE", "~/.cache/huggingface/datasets" -) -os.environ["HF_HOME"] = os.getenv("HF_HOME", "~/.cache/huggingface") - -wandb.login(key=os.getenv("WANDB_API_KEY", "")) - -class WandbCallback(TrainerCallback): - def __init__(self, model_name, phase): - self.model_name = model_name - self.phase = phase - self.best_accuracy = 0.0 - - def on_log(self, args, state, control, logs=None, **kwargs): - if logs is not None: - # Add model name and phase to logs - logs["model"] = self.model_name - logs["phase"] = self.phase - - # Track GPU memory if available - if GPU_AVAILABLE: - logs["gpu_memory_mb"] = get_gpu_memory() - - # Log to wandb - wandb.log(logs) - - def on_evaluate(self, args, state, control, metrics=None, **kwargs): - if metrics is not None: - # Track best accuracy - if "eval_accuracy" in metrics: - self.best_accuracy = max(self.best_accuracy, metrics["eval_accuracy"]) - metrics["best_accuracy"] = self.best_accuracy - - # Log evaluation metrics - wandb.log(metrics) - -def main(num_train_images=5000, proportion_per_transform=0.2, resolution=224): - - # Simple batch size configuration - batch_size = 256 - - # Initialize wandb - wandb_config = { - "num_train_images": num_train_images, - "proportion_per_transform": proportion_per_transform, - "resolution": resolution, - "batch_size": batch_size, - "num_epochs": 3, - "warmup_steps": 500, - "weight_decay": 0.01, - "gpu_available": GPU_AVAILABLE, - } - - wandb.init( - entity="ericcui-use-stanford-university", - project="CS231N Test", - config=wandb_config, - tags=["baseline", "model-comparison"] - ) - - models = [ - {"name": "vit", "model_id": "google/vit-base-patch16-224", "type": "vit"}, - {"name": "dinov2", "model_id": "facebook/dinov2-base", "type": "dinov2"}, - {"name": "simclr", "model_id": "resnet50", "type": "simclr"}, - ] - - results = {m["name"]: {} for m in models} - results_linear_probe = {m["name"]: {} for m in models} - - dataset = load_dataset( - "MKZuziak/ISIC_2019_224", - cache_dir=os.environ["HF_DATASETS_CACHE"], - split="train", - ) - - print(f"Initial dataset size: {len(dataset)} images") - - # Get indices of images with desired labels - filtered_indices = [ - i for i, label in enumerate(dataset["label"]) - if str(label) in FILTERED_CLASSES # Convert to string for comparison - ] - - # Select only those indices - dataset = dataset.select(filtered_indices) - print(f"Number of images after filtering for classes {FILTERED_CLASSES}: {len(dataset)}") - dataset = dataset.cast_column("label", ClassLabel(num_classes=NUM_FILTERED_CLASSES)) - - # Get class counts and balance dataset - optimized version - print("Balancing dataset...") - # Get counts for each class - class_counts = {label: 0 for label in FILTERED_CLASSES} - for label in dataset["label"]: - class_counts[str(label)] += 1 # Convert to string for dictionary key - - print(f"Class counts: {class_counts}") # Debug print to verify counts - - # Calculate how many images to use per class - min_class_size = min(class_counts.values()) - images_per_class = min(num_train_images // 2, min_class_size) - - # Sample indices for each class - np.random.seed(42) - balanced_indices = [] - for label in FILTERED_CLASSES: - class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] # Convert to string for comparison - print(f"Found {len(class_indices)} images for class {label}") # Debug print - sampled_indices = np.random.choice(class_indices, images_per_class, replace=False) - balanced_indices.extend(sampled_indices) - - np.random.shuffle(balanced_indices) - balanced_dataset = dataset.select(balanced_indices) - - # Split into train and validation - full_dataset = balanced_dataset.train_test_split( - test_size=0.2, stratify_by_column="label", seed=42 - ) - - train_dataset, val_dataset = full_dataset["train"], full_dataset["test"] - - degradation_transforms = [ - JPEGCompressionTransform(), - GaussianBlurTransform(), - ColorQuantizationTransform(), - ] - - num_transforms = len(degradation_transforms) - num_images = len(train_dataset) - images_per_transform = int(num_images * proportion_per_transform) - - transformed_datasets = [] - indices = np.arange(num_images) - np.random.shuffle(indices) - - used_indices = [] - for i, transform in enumerate(degradation_transforms): - subset_indices = indices[i * images_per_transform:(i + 1) * images_per_transform] - used_indices.extend(subset_indices) - subset = Subset(train_dataset, subset_indices) - transform_compose = transforms.Compose([transform]) - - for model_info in models: - name, model_id, typ = ( - model_info["name"], - model_info["model_id"], - model_info["type"], - ) - if typ == "vit": - preprocessor = ViTFeatureExtractor.from_pretrained(model_id, size=resolution) - elif typ == "dinov2": - preprocessor = AutoImageProcessor.from_pretrained(model_id, size=resolution) - else: - preprocessor = None - - transformed_ds = ISICDataset(subset, preprocessor, resolution, transform_compose, typ) - transformed_datasets.append(transformed_ds) - - remaining_indices = np.setdiff1d(indices, used_indices) - - if len(remaining_indices) > 0: - remaining_subset = Subset(train_dataset, remaining_indices) - for model_info in models: - name, model_id, typ = ( - model_info["name"], - model_info["model_id"], - model_info["type"], - ) - if typ == "vit": - preprocessor = ViTFeatureExtractor.from_pretrained(model_id, size=resolution) - elif typ == "dinov2": - preprocessor = AutoImageProcessor.from_pretrained(model_id, size=resolution) - else: - preprocessor = None - - # No transform applied to remaining indices - untransformed_ds = ISICDataset(remaining_subset, preprocessor, resolution, None, typ) - transformed_datasets.append(untransformed_ds) - - train_ds = ConcatDataset(transformed_datasets) - - val_ds = ISICDataset( - val_dataset, - preprocessor, - resolution, - model_type=typ, - ) - - for model_info in models: - name, model_id, typ = ( - model_info["name"], - model_info["model_id"], - model_info["type"], - ) - if typ == "vit": - preprocessor = ViTFeatureExtractor.from_pretrained( - model_id, size=resolution - ) - elif typ == "dinov2": - preprocessor = AutoImageProcessor.from_pretrained(model_id, size=resolution) - else: - preprocessor = None - - if typ == "vit": - model = ViTForImageClassification.from_pretrained( - model_id, num_labels=NUM_FILTERED_CLASSES, ignore_mismatched_sizes=True - ) - elif typ == "dinov2": - model = AutoModelForImageClassification.from_pretrained( - model_id, num_labels=NUM_FILTERED_CLASSES, ignore_mismatched_sizes=True - ) - elif typ == SSL_MODEL: - # Load pretrained ResNet50 backbone - backbone = timm.create_model( - SIMCLR_BACKBONE, - pretrained=True, - num_classes=0 # Remove classification head - ) - # Create SimCLR model with the backbone - model = SimCLRForClassification(backbone, NUM_FILTERED_CLASSES) - # Freeze the backbone initially - freeze_backbone(model, SSL_MODEL) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - - try: - dummy_input = torch.randn(1, 3, resolution, resolution).to(device) - model.to(device) - flops, _ = profile(model, inputs=(dummy_input,)) - flops /= 1e9 - except Exception as e: - print(f"FLOP profiling failed: {e}") - flops = -1 - - train_args = TrainingArguments( - output_dir=os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}"), - num_train_epochs=3, - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - warmup_steps=500, - weight_decay=0.01, - logging_dir=os.path.join(env_path("LOG_DIR", "."), f"{name}"), - logging_steps=1, - eval_strategy="epoch", - save_strategy="epoch", - load_best_model_at_end=True, - metric_for_best_model="accuracy", - save_total_limit=1, # Only keep the best model - ) - - trainer = Trainer( - model=model, - args=train_args, - train_dataset=train_ds, - eval_dataset=val_ds, - compute_metrics=lambda pred: compute_metrics(pred, name), - callbacks=[ - LossLoggerCallback( - log_dir=env_path("LOG_DIR", "./logs"), - phase="finetune", - model_name=name, - ), - WandbCallback(name, "finetune"), - ], - ) - - # ---- TRAINING PHASE ---- - start_time = time.time() - peak_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - - # Log model architecture - if typ in HF_MODELS: - wandb.watch(model, log="all", log_freq=100) - elif typ == SSL_MODEL: - wandb.watch(model.backbone, log="all", log_freq=100) - - trainer.train() - - current_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - peak_memory = max(peak_memory, current_memory) - - eval_start_time = time.time() - eval_results = trainer.evaluate() - eval_time = time.time() - eval_start_time - train_time = time.time() - start_time - eval_time - - # Log model-specific metrics - model_metrics = { - "model_name": name, - "model_type": typ, - "peak_memory_mb": peak_memory, - "flops_giga": flops, - "train_time_seconds": train_time, - "eval_time_seconds": eval_time, - "eval_metrics": eval_results, - } - wandb.log(model_metrics) - - model_dir = os.path.join( - env_path("MODEL_DIR", "."), f"{name}" - ) - os.makedirs(model_dir, exist_ok=True) - - if typ in HF_MODELS: - model.save_pretrained(model_dir) - preprocessor.save_pretrained(model_dir) - elif typ == SSL_MODEL: - torch.save( - model.state_dict(), os.path.join(model_dir, "pytorch_model.bin") - ) - with open(os.path.join(model_dir, "config.json"), "w") as f: - json.dump( - { - "model_type": SSL_MODEL, - "backbone": "resnet50", - "num_classes": NUM_FILTERED_CLASSES, - }, - f, - ) - - # Save model as wandb artifact - artifact = wandb.Artifact( - name=f"{name}_model", - type="model", - description=f"Trained {name} model with {typ} architecture" - ) - artifact.add_dir(model_dir) - wandb.log_artifact(artifact) - - results[name] = model_metrics - - print( - f"[Finetune] {name}: {results[name]}" - ) - - # ---- LINEAR PROBE PHASE ---- - # Create a new wandb run for linear probe - wandb.init( - project="model-comparison-baseline", - config=wandb_config, - tags=["baseline", "model-comparison", "linear-probe"], - name=f"{name}_linear_probe" - ) - - if typ == "vit": - model = ViTForImageClassification.from_pretrained( - model_id, num_labels=NUM_FILTERED_CLASSES, ignore_mismatched_sizes=True - ) - elif typ == "dinov2": - model = AutoModelForImageClassification.from_pretrained( - model_id, num_labels=NUM_FILTERED_CLASSES, ignore_mismatched_sizes=True - ) - elif typ == SSL_MODEL: - backbone = timm.create_model("resnet50", pretrained=True, num_classes=0) - model = SimCLRForClassification(backbone, NUM_FILTERED_CLASSES) - - model.to(device) - freeze_backbone(model, typ) - - # Log model architecture for linear probe - if typ in HF_MODELS: - wandb.watch(model, log="all", log_freq=100) - elif typ == SSL_MODEL: - wandb.watch(model.backbone, log="all", log_freq=100) - - linear_args = TrainingArguments( - output_dir=os.path.join( - env_path("TRAIN_OUTPUT_DIR", "."), - f"{name}_linear_probe", - ), - num_train_epochs=1, - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - warmup_steps=100, - weight_decay=0.01, - logging_dir=os.path.join( - env_path("LOG_DIR", "."), f"{name}_linear_probe" - ), - logging_steps=1, - eval_strategy="epoch", - save_strategy="epoch", - load_best_model_at_end=True, - metric_for_best_model="accuracy", - save_total_limit=1, # Only keep the best model - ) - - trainer = Trainer( - model=model, - args=linear_args, - train_dataset=train_ds, - eval_dataset=val_ds, - compute_metrics=lambda pred: compute_metrics(pred, name), - callbacks=[ - LossLoggerCallback( - log_dir=env_path("LOG_DIR", "./logs"), - phase="linear_probe", - model_name=name, - ), - WandbCallback(name, "linear_probe"), - ], - ) - - start_time = time.time() - peak_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - trainer.train() - current_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - peak_memory = max(peak_memory, current_memory) - - eval_start_time = time.time() - eval_results = trainer.evaluate() - eval_time = time.time() - eval_start_time - train_time = time.time() - start_time - eval_time - - # Log linear probe metrics - linear_probe_metrics = { - "model_name": name, - "model_type": typ, - "phase": "linear_probe", - "peak_memory_mb": peak_memory, - "flops_giga": flops, - "train_time_seconds": train_time, - "eval_time_seconds": eval_time, - "eval_metrics": eval_results, - } - wandb.log(linear_probe_metrics) - - model_dir = os.path.join( - env_path("MODEL_DIR", "."), f"{name}_linear_probe" - ) - os.makedirs(model_dir, exist_ok=True) - - if typ in HF_MODELS: - model.save_pretrained(model_dir) - preprocessor.save_pretrained(model_dir) - elif typ == SSL_MODEL: - torch.save( - model.state_dict(), os.path.join(model_dir, "pytorch_model.bin") - ) - with open(os.path.join(model_dir, "config.json"), "w") as f: - json.dump( - { - "model_type": SSL_MODEL, - "backbone": "resnet50", - "num_classes": NUM_FILTERED_CLASSES, - }, - f, - ) - - # Save linear probe model as wandb artifact - artifact = wandb.Artifact( - name=f"{name}_linear_probe_model", - type="model", - description=f"Linear probe {name} model with {typ} architecture" - ) - artifact.add_dir(model_dir) - wandb.log_artifact(artifact) - - results_linear_probe[name] = linear_probe_metrics - - print( - f"[LinearProbe] {name}: {results_linear_probe[name]}" - ) - - # Close the wandb run for linear probe - wandb.finish() - - # Close the main wandb run - wandb.finish() - - with open( - os.path.join( - env_path("TRAIN_OUTPUT_DIR", "."), "results_metrics_finetune.json" - ), - "w", - ) as f: - json.dump(results, f, indent=4) - - with open( - os.path.join( - env_path("TRAIN_OUTPUT_DIR", "."), "results_metrics_linear_probe.json" - ), - "w", - ) as f: - json.dump(results_linear_probe, f, indent=4) - - -if __name__ == "__main__": - main() diff --git a/src/models/model_comparison_lr.py b/src/models/model_comparison_lr_sweep.py similarity index 97% rename from src/models/model_comparison_lr.py rename to src/models/model_comparison_lr_sweep.py index 1985973..5185bbc 100644 --- a/src/models/model_comparison_lr.py +++ b/src/models/model_comparison_lr_sweep.py @@ -3,6 +3,17 @@ at three different image compression levels, in comparison to the original. It has a set number of augmentation transforms and does NOT combine them. This does NOT experiment on JPEG compression levels + + +Purpose: Specialized for learning rate experiments with image classification models. +Features: +- Focuses on learning rate sweeps for a single model (currently DINOv2). +- Uses command-line argument parsing for hyperparameters. +- Applies a set number of augmentation transforms (does not combine them). +- Does not experiment with JPEG compression levels. +- Logs results for each learning rate and saves them to a JSON file. +- More flexible for hyperparameter tuning and ablation studies. + ''' # Environment Setup diff --git a/src/models/model_comparison_models.py b/src/models/model_comparison_models.py deleted file mode 100644 index af0e28a..0000000 --- a/src/models/model_comparison_models.py +++ /dev/null @@ -1,359 +0,0 @@ -''' -This script is a baseline for comparing different image classification models -at three different image compression levels, in comparison to the original. -It has a set number of augmentation transforms and does NOT combine them. -This does NOT experiment on JPEG compression levels -''' - -# Environment Setup -import os - -# Set memory optimization -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - -# Standard Libraries -import io - -# Scientific & Visualization Libraries -import numpy as np -from PIL import Image - -# PyTorch & Torchvision -import torch -import torch.nn as nn -from torch.utils.data import Dataset -from torchvision import transforms - -# Hugging Face Transformers & Datasets -from transformers import ( - AutoImageProcessor, - AutoModelForImageClassification, - Trainer, - TrainerCallback, - TrainingArguments, - ViTFeatureExtractor, - ViTForImageClassification, -) -from datasets import load_dataset - -# Weights & Biases -import wandb - -# Model Profiling & Vision Backbones -import timm - -# Local Application Imports -from utils.constants import SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES -from utils.util_classes import ( - SimCLRForClassification, - LossLoggerCallback, -) -from utils.util_methods import ( - env_path, - compute_metrics, - get_gpu_memory, - freeze_backbone, -) - -# GPU Memory Monitoring (optional) -try: - import pynvml - pynvml.nvmlInit() - GPU_AVAILABLE = True -except ImportError: - GPU_AVAILABLE = False - print("pynvml not installed, GPU memory monitoring disabled.") - -# Cache paths -os.environ["TRANSFORMERS_CACHE"] = os.getenv( - "TRANSFORMERS_CACHE", "~/.cache/huggingface/transformers" -) -os.environ["HF_DATASETS_CACHE"] = os.getenv( - "HF_DATASETS_CACHE", "~/.cache/huggingface/datasets" -) -os.environ["HF_HOME"] = os.getenv("HF_HOME", "~/.cache/huggingface") - -class WandbCallback(TrainerCallback): - def __init__(self, model_name, phase): - self.model_name = model_name - self.phase = phase - self.best_accuracy = 0.0 - - def on_log(self, args, state, control, logs=None, **kwargs): - if logs is not None: - # Add model name and phase to logs - logs["model"] = self.model_name - logs["phase"] = self.phase - - # Track GPU memory if available - if GPU_AVAILABLE: - logs["gpu_memory_mb"] = get_gpu_memory() - - # Log to wandb - wandb.log(logs) - - def on_evaluate(self, args, state, control, metrics=None, **kwargs): - if metrics is not None: - # Track best accuracy - if "eval_accuracy" in metrics: - self.best_accuracy = max(self.best_accuracy, metrics["eval_accuracy"]) - metrics["best_accuracy"] = self.best_accuracy - - # Log evaluation metrics - wandb.log(metrics) - -def initialize_model_and_preprocessor(model_info, resolution): - """ - Initialize the model and preprocessor based on the model type. - - Args: - model_info (dict): Dictionary containing model details (name, model_id, type, config). - resolution (int): Image resolution. - - Returns: - model (torch.nn.Module): Initialized model. - preprocessor (transformers.PreTrainedTokenizer or None): Preprocessor for the model. - """ - name, model_id, typ, config = ( - model_info["name"], - model_info["model_id"], - model_info["type"], - model_info["config"], - ) - - if typ == "vit": - preprocessor = ViTFeatureExtractor.from_pretrained( - model_id, - size=resolution, - do_resize=True, - resample=Image.LANCZOS, - do_normalize=True, - image_mean=[0.485, 0.456, 0.406], - image_std=[0.229, 0.224, 0.225], - ) - model = ViTForImageClassification.from_pretrained( - model_id, - num_labels=NUM_FILTERED_CLASSES, - ignore_mismatched_sizes=True, - image_size=resolution, - ) - elif typ == "dinov2": - preprocessor = AutoImageProcessor.from_pretrained( - model_id, - size=resolution, - do_resize=True, - resample=Image.LANCZOS, - do_normalize=True, - image_mean=[0.485, 0.456, 0.406], - image_std=[0.229, 0.224, 0.225], - ) - model = AutoModelForImageClassification.from_pretrained( - model_id, - num_labels=NUM_FILTERED_CLASSES, - ignore_mismatched_sizes=True, - image_size=resolution, - ) - elif typ == SSL_MODEL: - backbone = timm.create_model( - SIMCLR_BACKBONE, - pretrained=True, - num_classes=0, # Remove classification head - ) - model = SimCLRForClassification(backbone, NUM_FILTERED_CLASSES) - freeze_backbone(model, SSL_MODEL) - preprocessor = None - else: - raise ValueError(f"Unsupported model type: {typ}") - - return model, preprocessor - -def balance_dataset(dataset, num_train_images, filtered_classes): - """ - Balance the dataset by sampling an equal number of images per class. - - Args: - dataset (Dataset): The dataset to balance. - num_train_images (int): Total number of training images to use. - filtered_classes (list): List of class labels to filter. - - Returns: - balanced_dataset (Dataset): Balanced dataset with equal images per class. - """ - print("Balancing dataset...") - class_counts = {label: 0 for label in filtered_classes} - for label in dataset["label"]: - class_counts[str(label)] += 1 - - print(f"Class counts: {class_counts}") # Debug print to verify counts - - min_class_size = min(class_counts.values()) - images_per_class = min(num_train_images // len(filtered_classes), min_class_size) - - np.random.seed(42) - balanced_indices = [] - for label in filtered_classes: - class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] - sampled_indices = np.random.choice(class_indices, images_per_class, replace=False) - balanced_indices.extend(sampled_indices) - - np.random.shuffle(balanced_indices) - return dataset.select(balanced_indices) - -def train_model(model, train_ds, val_ds, name, typ, resolution, batch_size, num_epochs, learning_rate, eval_steps, wandb_config): - """ - Train the model using the Hugging Face Trainer. - - Args: - model (torch.nn.Module): The model to train. - train_ds (Dataset): Training dataset. - val_ds (Dataset): Validation dataset. - name (str): Model name. - typ (str): Model type. - resolution (int): Image resolution. - batch_size (int): Batch size. - num_epochs (int): Number of epochs. - learning_rate (float): Learning rate. - eval_steps (int): Evaluation steps. - wandb_config (dict): Configuration for wandb logging. - - Returns: - dict: Training results and metrics. - """ - wandb.init( - entity="ericcui-use-stanford-university", - project="CS231N Test", - name=f"{name}_{resolution}_{num_epochs}_epochs_finetune", - config={**wandb_config, "model_name": name, "model_type": typ}, - tags=["baseline", "model-comparison", "finetune", name, f"res_{resolution}"], - reinit=True, - ) - - train_args = TrainingArguments( - output_dir=os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}"), - num_train_epochs=num_epochs, - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - learning_rate=learning_rate, - lr_scheduler_type="cosine", - weight_decay=0.01, - logging_dir=os.path.join(env_path("LOG_DIR", "."), f"{name}"), - logging_steps=1, - eval_strategy="steps", - eval_steps=eval_steps, - save_strategy="steps", - save_steps=eval_steps, - load_best_model_at_end=False, - metric_for_best_model="accuracy", - save_total_limit=1, - ) - - trainer = Trainer( - model=model, - args=train_args, - train_dataset=train_ds, - eval_dataset=val_ds, - compute_metrics=lambda pred: compute_metrics(pred, name), - callbacks=[ - LossLoggerCallback( - log_dir=env_path("LOG_DIR", "./logs"), - phase="finetune", - model_name=name, - ), - WandbCallback(name, "finetune"), - ], - ) - - trainer.train() - eval_results = trainer.evaluate() - - wandb.finish() - return eval_results - -def prepare_datasets(dataset, preprocessor, resolution, proportion_per_transform): - """ - Prepare training and validation datasets with optional preprocessing. - - Args: - dataset: The balanced HuggingFace dataset. - preprocessor: Preprocessing function or None. - resolution: Image resolution. - proportion_per_transform: Proportion for each transform. - - Returns: - train_ds, val_ds: Torch-compatible datasets for training and validation. - """ - # Split dataset into train and validation (80/20 split) - train_size = int(0.8 * len(dataset)) - val_size = len(dataset) - train_size - train_dataset = dataset.select(range(train_size)) - val_dataset = dataset.select(range(train_size, train_size + val_size)) - - # Define basic transform - transform = transforms.Compose([ - transforms.Resize((resolution, resolution)), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ]) - - class TorchDataset(Dataset): - def __init__(self, hf_dataset, transform): - self.hf_dataset = hf_dataset - self.transform = transform - - def __len__(self): - return len(self.hf_dataset) - - def __getitem__(self, idx): - item = self.hf_dataset[idx] - image = Image.open(io.BytesIO(item["image"])).convert("RGB") - if self.transform: - image = self.transform(image) - label = int(item["label"]) - return {"pixel_values": image, "labels": label} - - train_ds = TorchDataset(train_dataset, transform) - val_ds = TorchDataset(val_dataset, transform) - return train_ds, val_ds - -def main(num_train_images=100, proportion_per_transform=0.2, resolution=224, batch_size=256, num_epochs=3, eval_steps=10, learning_rate=1e-4): - wandb_config = { - "num_train_images": num_train_images, - "proportion_per_transform": proportion_per_transform, - "resolution": resolution, - "batch_size": batch_size, - "num_epochs": num_epochs, - "eval_steps": eval_steps, - "weight_decay": 0.01, - "learning_rate": learning_rate, - "gpu_available": GPU_AVAILABLE, - } - - models = [ - {"name": "vit", "model_id": "google/vit-base-patch16-224", "type": "vit", "config": { - "image_size": resolution, - "num_labels": NUM_FILTERED_CLASSES, - "ignore_mismatched_sizes": True - }}, - ] - - dataset = load_dataset( - "MKZuziak/ISIC_2019_224", - cache_dir=os.environ["HF_DATASETS_CACHE"], - split="train", - ) - - dataset = balance_dataset(dataset, num_train_images, FILTERED_CLASSES) - - for model_info in models: - model, preprocessor = initialize_model_and_preprocessor(model_info, resolution) - - # Prepare datasets - train_ds, val_ds = prepare_datasets(dataset, preprocessor, resolution, proportion_per_transform) - - # Train the model - results = train_model( - model, train_ds, val_ds, model_info["name"], model_info["type"], - resolution, batch_size, num_epochs, learning_rate, eval_steps, wandb_config - ) - - print(f"Results for {model_info['name']}: {results}") diff --git a/src/models/model_comparison_not_used.py b/src/models/model_comparison_not_used.py deleted file mode 100644 index d4a8639..0000000 --- a/src/models/model_comparison_not_used.py +++ /dev/null @@ -1,476 +0,0 @@ -''' -This script is a baseline for comparing different image classification models -at three different image compression levels, in comparison to the original. -It has a set number of augmentation transforms and does NOT combine them. -This does NOT experiment on JPEG compression levels -''' - -# Environment Setup -import os - -# Standard Libraries -import io -import json -import random -import time - -# Scientific & Visualization Libraries -import numpy as np -import matplotlib.pyplot as plt -import seaborn as sns -from PIL import Image - -# PyTorch & Torchvision -import torch -import torch.nn as nn -from torch.utils.data import Dataset, Subset, ConcatDataset -from torchvision import transforms - -# Hugging Face Transformers & Datasets -from transformers import ( - AutoImageProcessor, - AutoModelForImageClassification, - Trainer, - TrainerCallback, - TrainingArguments, - ViTFeatureExtractor, - ViTForImageClassification, -) -from datasets import load_dataset, ClassLabel - -# Weights & Biases -import wandb - -# Metrics -from sklearn.metrics import ( - accuracy_score, - confusion_matrix, - f1_score, - roc_auc_score, -) - -# Model Profiling & Vision Backbones -import timm -from thop import profile - -# Local Application Imports -from utils.constants import HF_MODELS, SSL_MODEL, SIMCLR_BACKBONE, NUM_CLASSES, FILTERED_CLASSES, NUM_FILTERED_CLASSES -from src.models.utils.transforms import ( - JPEGCompressionTransform, - GaussianBlurTransform, - ColorQuantizationTransform, -) -from utils.util_classes import ( - ISICDataset, - SimCLRForClassification, - LossLoggerCallback, -) -from utils.util_methods import ( - env_path, - compute_metrics, - get_gpu_memory, - freeze_backbone, -) - -# GPU Memory Monitoring (optional) -try: - import pynvml - pynvml.nvmlInit() - GPU_AVAILABLE = True -except ImportError: - GPU_AVAILABLE = False - print("pynvml not installed, GPU memory monitoring disabled.") - -# Cache paths -os.environ["TRANSFORMERS_CACHE"] = os.getenv( - "TRANSFORMERS_CACHE", "~/.cache/huggingface/transformers" -) -os.environ["HF_DATASETS_CACHE"] = os.getenv( - "HF_DATASETS_CACHE", "~/.cache/huggingface/datasets" -) -os.environ["HF_HOME"] = os.getenv("HF_HOME", "~/.cache/huggingface") - -class WandbCallback(TrainerCallback): - def __init__(self, model_name, phase): - self.model_name = model_name - self.phase = phase - self.best_accuracy = 0.0 - - def on_log(self, args, state, control, logs=None, **kwargs): - if logs is not None: - # Add model name and phase to logs - logs["model"] = self.model_name - logs["phase"] = self.phase - - # Track GPU memory if available - if GPU_AVAILABLE: - logs["gpu_memory_mb"] = get_gpu_memory() - - # Log to wandb - wandb.log(logs) - - def on_evaluate(self, args, state, control, metrics=None, **kwargs): - if metrics is not None: - # Track best accuracy - if "eval_accuracy" in metrics: - self.best_accuracy = max(self.best_accuracy, metrics["eval_accuracy"]) - metrics["best_accuracy"] = self.best_accuracy - - # Log evaluation metrics - wandb.log(metrics) - -def main(num_train_images=25000, proportion_per_transform=0.2, resolution=224): - - # Simple batch size configuration - batch_size = 256 - - # Initialize wandb config - wandb_config = { - "num_train_images": num_train_images, - "proportion_per_transform": proportion_per_transform, - "resolution": resolution, - "batch_size": batch_size, - "num_epochs": 3, - "warmup_steps": 500, - "weight_decay": 0.01, - "gpu_available": GPU_AVAILABLE, - } - - models = [ - {"name": "vit", "model_id": "google/vit-base-patch16-224", "type": "vit"}, - {"name": "dinov2", "model_id": "facebook/dinov2-base", "type": "dinov2"}, - {"name": "simclr", "model_id": "resnet50", "type": "simclr"}, - ] - - results = {m["name"]: {} for m in models} - results_linear_probe = {m["name"]: {} for m in models} - - dataset = load_dataset( - "MKZuziak/ISIC_2019_224", - cache_dir=os.environ["HF_DATASETS_CACHE"], - split="train", - ) - - print(f"Initial dataset size: {len(dataset)} images") - - # Get indices of images with desired labels - filtered_indices = [ - i for i, label in enumerate(dataset["label"]) - if str(label) in FILTERED_CLASSES # Convert to string for comparison - ] - - # Select only those indices - dataset = dataset.select(filtered_indices) - print(f"Number of images after filtering for classes {FILTERED_CLASSES}: {len(dataset)}") - dataset = dataset.cast_column("label", ClassLabel(num_classes=NUM_FILTERED_CLASSES)) - - # Get class counts and balance dataset - optimized version - print("Balancing dataset...") - # Get counts for each class - class_counts = {label: 0 for label in FILTERED_CLASSES} - for label in dataset["label"]: - class_counts[str(label)] += 1 # Convert to string for dictionary key - - print(f"Class counts: {class_counts}") # Debug print to verify counts - - # Calculate how many images to use per class - min_class_size = min(class_counts.values()) - images_per_class = min(num_train_images // 2, min_class_size) - - # Sample indices for each class - np.random.seed(42) - balanced_indices = [] - for label in FILTERED_CLASSES: - class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] # Convert to string for comparison - print(f"Found {len(class_indices)} images for class {label}") # Debug print - sampled_indices = np.random.choice(class_indices, images_per_class, replace=False) - balanced_indices.extend(sampled_indices) - - np.random.shuffle(balanced_indices) - balanced_dataset = dataset.select(balanced_indices) - - # Split into train and validation - full_dataset = balanced_dataset.train_test_split( - test_size=0.2, stratify_by_column="label", seed=42 - ) - - train_dataset, val_dataset = full_dataset["train"], full_dataset["test"] - - degradation_transforms = [ - JPEGCompressionTransform(), - GaussianBlurTransform(), - ColorQuantizationTransform(), - ] - - num_transforms = len(degradation_transforms) - num_images = len(train_dataset) - images_per_transform = int(num_images * proportion_per_transform) - - transformed_datasets = [] - indices = np.arange(num_images) - np.random.shuffle(indices) - - used_indices = [] - for i, transform in enumerate(degradation_transforms): - subset_indices = indices[i * images_per_transform:(i + 1) * images_per_transform] - used_indices.extend(subset_indices) - subset = Subset(train_dataset, subset_indices) - transform_compose = transforms.Compose([transform]) - - for model_info in models: - name, model_id, typ = ( - model_info["name"], - model_info["model_id"], - model_info["type"], - ) - if typ == "vit": - preprocessor = ViTFeatureExtractor.from_pretrained(model_id, size=resolution) - elif typ == "dinov2": - preprocessor = AutoImageProcessor.from_pretrained(model_id, size=resolution) - else: - preprocessor = None - - transformed_ds = ISICDataset(subset, preprocessor, resolution, transform_compose, typ) - transformed_datasets.append(transformed_ds) - - remaining_indices = np.setdiff1d(indices, used_indices) - - if len(remaining_indices) > 0: - remaining_subset = Subset(train_dataset, remaining_indices) - for model_info in models: - name, model_id, typ = ( - model_info["name"], - model_info["model_id"], - model_info["type"], - ) - if typ == "vit": - preprocessor = ViTFeatureExtractor.from_pretrained(model_id, size=resolution) - elif typ == "dinov2": - preprocessor = AutoImageProcessor.from_pretrained(model_id, size=resolution) - else: - preprocessor = None - - # No transform applied to remaining indices - untransformed_ds = ISICDataset(remaining_subset, preprocessor, resolution, None, typ) - transformed_datasets.append(untransformed_ds) - - train_ds = ConcatDataset(transformed_datasets) - - val_ds = ISICDataset( - val_dataset, - preprocessor, - resolution, - model_type=typ, - ) - - for model_info in models: - name, model_id, typ = ( - model_info["name"], - model_info["model_id"], - model_info["type"], - ) - - # Initialize wandb for this specific model run - wandb.init( - entity="ericcui-use-stanford-university", - project="CS231N Test", - name=f"{name}_{resolution}_finetune", - config=wandb_config, - tags=["baseline", "model-comparison", "finetune", name], - reinit=True - ) - - if typ == "vit": - preprocessor = ViTFeatureExtractor.from_pretrained( - model_id, size=resolution - ) - elif typ == "dinov2": - preprocessor = AutoImageProcessor.from_pretrained(model_id, size=resolution) - else: - preprocessor = None - - if typ == "vit": - model = ViTForImageClassification.from_pretrained( - model_id, num_labels=NUM_FILTERED_CLASSES, ignore_mismatched_sizes=True - ) - elif typ == "dinov2": - model = AutoModelForImageClassification.from_pretrained( - model_id, num_labels=NUM_FILTERED_CLASSES, ignore_mismatched_sizes=True - ) - elif typ == SSL_MODEL: - # Load pretrained ResNet50 backbone - backbone = timm.create_model( - SIMCLR_BACKBONE, - pretrained=True, - num_classes=0 # Remove classification head - ) - # Create SimCLR model with the backbone - model = SimCLRForClassification(backbone, NUM_FILTERED_CLASSES) - # Freeze the backbone initially - freeze_backbone(model, SSL_MODEL) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - - try: - dummy_input = torch.randn(1, 3, resolution, resolution).to(device) - model.to(device) - flops, _ = profile(model, inputs=(dummy_input,)) - flops /= 1e9 - except Exception as e: - print(f"FLOP profiling failed: {e}") - flops = -1 - - train_args = TrainingArguments( - output_dir=os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}"), - num_train_epochs=3, - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - warmup_steps=500, - weight_decay=0.01, - logging_dir=os.path.join(env_path("LOG_DIR", "."), f"{name}"), - logging_steps=1, - eval_strategy="epoch", - save_strategy="epoch", - load_best_model_at_end=True, - metric_for_best_model="accuracy", - save_total_limit=1, # Only keep the best model - save_safetensors=False, # Use PyTorch format instead of safetensors - hub_model_id=None, # Don't push to hub - hub_strategy="end", # Only push at the end if needed - push_to_hub=False, # Don't push to hub - save_only_model=True, # Don't save optimizer state - ) - - # Clean up old model directories before training - model_dirs = [ - os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}"), - os.path.join(env_path("MODEL_DIR", "."), f"{name}"), - os.path.join(env_path("LOG_DIR", "."), f"{name}"), - ] - - for dir_path in model_dirs: - if os.path.exists(dir_path): - print(f"Cleaning up directory: {dir_path}") - import shutil - shutil.rmtree(dir_path) - os.makedirs(dir_path, exist_ok=True) - - # Monitor disk space before training - import shutil - total, used, free = shutil.disk_usage("/") - print(f"Disk space before training {name}:") - print(f"Total: {total // (2**30)} GB") - print(f"Used: {used // (2**30)} GB") - print(f"Free: {free // (2**30)} GB") - - # If less than 1GB free, raise an error - if free < 1 * (2**30): # 1GB in bytes - raise RuntimeError("Not enough disk space to train model. Please free up at least 1GB of space.") - - trainer = Trainer( - model=model, - args=train_args, - train_dataset=train_ds, - eval_dataset=val_ds, - compute_metrics=lambda pred: compute_metrics(pred, name), - callbacks=[ - LossLoggerCallback( - log_dir=env_path("LOG_DIR", "./logs"), - phase="finetune", - model_name=name, - ), - WandbCallback(name, "finetune"), - ], - ) - - # ---- TRAINING PHASE ---- - start_time = time.time() - peak_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - - # Log model architecture - if typ in HF_MODELS: - wandb.watch(model, log="all", log_freq=100) - elif typ == SSL_MODEL: - wandb.watch(model.backbone, log="all", log_freq=100) - - trainer.train() - - current_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - peak_memory = max(peak_memory, current_memory) - - eval_start_time = time.time() - eval_results = trainer.evaluate() - eval_time = time.time() - eval_start_time - train_time = time.time() - start_time - eval_time - - # Log model-specific metrics - model_metrics = { - "model_name": name, - "model_type": typ, - "peak_memory_mb": peak_memory, - "flops_giga": flops, - "train_time_seconds": train_time, - "eval_time_seconds": eval_time, - "eval_metrics": eval_results, - } - wandb.log(model_metrics) - - model_dir = os.path.join( - env_path("MODEL_DIR", "."), f"{name}" - ) - os.makedirs(model_dir, exist_ok=True) - - if typ in HF_MODELS: - model.save_pretrained(model_dir) - preprocessor.save_pretrained(model_dir) - elif typ == SSL_MODEL: - torch.save( - model.state_dict(), os.path.join(model_dir, "pytorch_model.bin") - ) - with open(os.path.join(model_dir, "config.json"), "w") as f: - json.dump( - { - "model_type": SSL_MODEL, - "backbone": "resnet50", - "num_classes": NUM_FILTERED_CLASSES, - }, - f, - ) - - # Save model as wandb artifact - artifact = wandb.Artifact( - name=f"{name}_model", - type="model", - description=f"Trained {name} model with {typ} architecture" - ) - artifact.add_dir(model_dir) - wandb.log_artifact(artifact) - - results[name] = model_metrics - - print(f"[Finetune] {name}: {results[name]}") - - # Close the wandb run for this model - wandb.finish() - - # Remove the final wandb.finish() since we're now closing each run individually - with open( - os.path.join( - env_path("TRAIN_OUTPUT_DIR", "."), "results_metrics_finetune.json" - ), - "w", - ) as f: - json.dump(results, f, indent=4) - - with open( - os.path.join( - env_path("TRAIN_OUTPUT_DIR", "."), "results_metrics_linear_probe.json" - ), - "w", - ) as f: - json.dump(results_linear_probe, f, indent=4) - - -if __name__ == "__main__": - main() From 665f18a9d20ac1e6470c3b4ab72bc3fece1452b4 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 14:36:32 -0700 Subject: [PATCH 03/26] major refactoring of model comparison scripts --- .DS_Store | Bin 6148 -> 0 bytes .github/workflows/build-and-test.yml | 45 + .github/workflows/pull_request.yml | 37 + .gitignore | 6 + .reuse/dep5.txt | 11 + README.md | 8 +- configs/example_config.yaml.license | 5 + docs/pipeline.md | 6 + jobs/job_template.slurm | 6 + jobs/run.sh | 6 + jobs/run_lr.sh | 6 + jobs/run_models.sh | 6 + CS231N Poster.png => media/CS231N Poster.png | Bin requirements.txt | 16 +- requirements.txt.licence | 1 - requirements.txt.license | 5 + scripts/download_unpack_isic2019.sh | 6 + scripts/submit_from_config.sh | 6 + src/__init__.py | 3 - src/evaluation/evaluate_isic_results.py | 31 +- src/models/.DS_Store | Bin 6148 -> 0 bytes src/models/__init__.py | 3 - src/models/model_comparison_baseline.py | 222 +++-- src/models/model_comparison_lr_sweep.py | 777 ++++++++++-------- src/models/utils/__init__.py | 3 - src/models/utils/constants.py | 8 + src/models/utils/transforms.py | 48 +- src/models/utils/utils_classes.py | 68 +- .../{util_methods.py => utils_methods.py} | 31 +- src/requirements.txt | Bin 1716 -> 0 bytes 30 files changed, 894 insertions(+), 476 deletions(-) delete mode 100644 .DS_Store create mode 100644 .github/workflows/build-and-test.yml create mode 100644 .github/workflows/pull_request.yml create mode 100644 .reuse/dep5.txt create mode 100644 configs/example_config.yaml.license rename CS231N Poster.png => media/CS231N Poster.png (100%) delete mode 100644 requirements.txt.licence create mode 100644 requirements.txt.license delete mode 100644 src/models/.DS_Store rename src/models/utils/{util_methods.py => utils_methods.py} (76%) delete mode 100644 src/requirements.txt diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 3af1e0fb5b2d9cf4def8daebc47418e1416a7038..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK%}(4f5dI7+Ev1+3jSEuL2UrPu>MdP%4@iJ2_5l`cps-c;2NERIQ>4B}`%pXp zZ@`sf`DScT;stRkLTE;^KgXH*JpQcMF#y-QJl+C20M=LpTkEWzFnKR!$yP$oN1{`X z2~JV$=JRZvFL_&sDPRh`H3j6`Z6m`RGyFj9`<;*ZMz)3DEPwy0u?2j>4u;sq7i{4G z3BEJ7K#3Dvpo;~0PW=w>$Z#7mwD6NL=ljzBVC~x*-H8dcBy0`YZozJN6epJWX|a^BI6ZT zP*!sCsKzs6xSME|s7X>%_Zwm0t7F&}&z|0gd z1x$gu0;%Y_VI7ZTV@Cm?U5tu=2<^H0MK!KGcOPhV$X zu=41`;lky^g^^vjp*S6#{E525r5>#|1x$ga0;~S9Bj^9%`TKt}$ev6AQ{Y`G;9C7* zzsD(svvp~5a@Hm+M=WCES9#PBR^d3-h8)EkENXn0NQ0OLtUR)ZW diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml new file mode 100644 index 0000000..23d3d99 --- /dev/null +++ b/.github/workflows/build-and-test.yml @@ -0,0 +1,45 @@ +# +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# + +name: Build and Test + +on: + push: + branches: + - main + pull_request: + workflow_dispatch: + workflow_call: + +jobs: + pylint: + name: PyLint + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + - name: Install Infrastructure + run: | + pip install -r requirements.txt + pip install pylint + - name: Analysing the code with pylint + run: | + pylint $(git ls-files '*.py') + black_lint: + name: Black Code Formatter Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + - name: Install Black + run: pip install black[jupyter] + - name: Check code formatting with Black + run: black . --exclude '\.ipynb$' diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml new file mode 100644 index 0000000..c74fa09 --- /dev/null +++ b/.github/workflows/pull_request.yml @@ -0,0 +1,37 @@ +# +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# + +name: Pull Request + +on: + pull_request: + workflow_dispatch: + +jobs: + reuse_action: + name: REUSE Compliance Check + uses: DaneshjouLab/.github/.github/workflows/reuse.yml@main + markdown_link_check: + name: Markdown Link Check + uses: DaneshjouLab/.github/.github/workflows/markdown-link-check.yml@main + yamllint: + name: YAML Lint Check + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + + - name: Install yamllint + run: pip install yamllint + + - name: Run yamllint with custom config + run: yamllint -c .yamllint .github/workflows/*.yml diff --git a/.gitignore b/.gitignore index b7faf40..f80baca 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # Byte-compiled / optimized / DLL files __pycache__/ *.py[codz] diff --git a/.reuse/dep5.txt b/.reuse/dep5.txt new file mode 100644 index 0000000..6f0048b --- /dev/null +++ b/.reuse/dep5.txt @@ -0,0 +1,11 @@ +Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ + +Files: media/*.png +Copyright: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) +License: MIT +Comment: All files are part of the Daneshjou Lab projects. + +Files: results/*.json +Copyright: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) +License: MIT +Comment: All files are part of the Daneshjou Lab projects. \ No newline at end of file diff --git a/README.md b/README.md index ce54a4a..0e42e05 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,14 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # Finetuning Pretrained Models for Compressed Dermatology Image Analysis This project explores how compressed and degraded dermatology images (from the ISIC 2019 dataset) affect classification performance using pretrained vision models. It compares fine-tuning vs. linear probing across multiple JPEG quality levels. -![System architecture diagram](<./CS231N Poster.png>) +![System architecture diagram](<./media/CS231N Poster.png>) ## Project Goals diff --git a/configs/example_config.yaml.license b/configs/example_config.yaml.license new file mode 100644 index 0000000..3cc951b --- /dev/null +++ b/configs/example_config.yaml.license @@ -0,0 +1,5 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT \ No newline at end of file diff --git a/docs/pipeline.md b/docs/pipeline.md index 772baa5..da4efbf 100644 --- a/docs/pipeline.md +++ b/docs/pipeline.md @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # Model Comparison Pipeline ## Overview diff --git a/jobs/job_template.slurm b/jobs/job_template.slurm index 8f442b8..b5aae84 100644 --- a/jobs/job_template.slurm +++ b/jobs/job_template.slurm @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + #!/bin/bash #SBATCH --job-name={{JOB_NAME}} #SBATCH --partition={{PARTITION}} diff --git a/jobs/run.sh b/jobs/run.sh index 1adbfa0..598137f 100644 --- a/jobs/run.sh +++ b/jobs/run.sh @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + #!/bin/bash #SBATCH --job-name=231n_job diff --git a/jobs/run_lr.sh b/jobs/run_lr.sh index 8546a74..a8951f7 100644 --- a/jobs/run_lr.sh +++ b/jobs/run_lr.sh @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + #!/bin/bash #SBATCH --job-name=231n_job diff --git a/jobs/run_models.sh b/jobs/run_models.sh index bd5b663..9cbb3f1 100644 --- a/jobs/run_models.sh +++ b/jobs/run_models.sh @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + #!/bin/bash #SBATCH --job-name=231n_job diff --git a/CS231N Poster.png b/media/CS231N Poster.png similarity index 100% rename from CS231N Poster.png rename to media/CS231N Poster.png diff --git a/requirements.txt b/requirements.txt index 0acb2d9..b93e56d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ +accelerate annotated-types anyio asttokens -accelerate beautifulsoup4 biopython certifi @@ -62,20 +62,18 @@ six sniffio soupsieve stack-data -timm thop -threadpoolctl -tornado +threadpoolctladpoolctl +timm torch -torchvision -transformers +torchvisionvision tqdm traitlets -typing-inspection +tritonon +typing-inspectionspection typing_extensions -triton tzdata -urllib3 +urllib33 wandb wcwidth wheel diff --git a/requirements.txt.licence b/requirements.txt.licence deleted file mode 100644 index 176ed16..0000000 --- a/requirements.txt.licence +++ /dev/null @@ -1 +0,0 @@ -# MIT \ No newline at end of file diff --git a/requirements.txt.license b/requirements.txt.license new file mode 100644 index 0000000..3cc951b --- /dev/null +++ b/requirements.txt.license @@ -0,0 +1,5 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT \ No newline at end of file diff --git a/scripts/download_unpack_isic2019.sh b/scripts/download_unpack_isic2019.sh index f584dbf..7825179 100644 --- a/scripts/download_unpack_isic2019.sh +++ b/scripts/download_unpack_isic2019.sh @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + #!/bin/bash set -e diff --git a/scripts/submit_from_config.sh b/scripts/submit_from_config.sh index ba7a062..d60b99d 100644 --- a/scripts/submit_from_config.sh +++ b/scripts/submit_from_config.sh @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + #!/bin/bash set -e diff --git a/src/__init__.py b/src/__init__.py index 1a66f1d..e69de29 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,3 +0,0 @@ -""" -Root package for the project. -""" diff --git a/src/evaluation/evaluate_isic_results.py b/src/evaluation/evaluate_isic_results.py index b56e6d0..24c6eeb 100644 --- a/src/evaluation/evaluate_isic_results.py +++ b/src/evaluation/evaluate_isic_results.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + """ Evaluates results from ISIC 2019 fine-tuning and linear probing experiments. Conducts paired t-tests to assess statistical significance across resolutions @@ -59,20 +65,19 @@ def paired_t_tests(metrics, models, qualities, modes): metric_names = ["accuracy", "f1", "auc"] def run_if_valid(k1, k2, label): - if ( - k1 in metrics and k2 in metrics and - len(metrics[k1]) == len(metrics[k2]) > 1 - ): + if k1 in metrics and k2 in metrics and len(metrics[k1]) == len(metrics[k2]) > 1: for metric in metric_names: try: v1 = [r[metric] for r in metrics[k1]] v2 = [r[metric] for r in metrics[k2]] stat, p = ttest_rel(v1, v2) - t_test_results.append({ - "comparison": label.format(metric=metric), - "statistic": stat, - "p_value": p, - }) + t_test_results.append( + { + "comparison": label.format(metric=metric), + "statistic": stat, + "p_value": p, + } + ) except Exception as e: print(f"Error in t-test for {label.format(metric=metric)}: {e}") @@ -87,7 +92,7 @@ def run_if_valid(k1, k2, label): run_if_valid( (model, q1, mode), (model, q2, mode), - f"{model}_{mode}" + "_{{metric}}_jpeg{q1}_vs_jpeg{q2}" + f"{model}_{mode}" + "_{{metric}}_jpeg{q1}_vs_jpeg{q2}", ) # 2. Across models: same quality & mode @@ -96,7 +101,7 @@ def run_if_valid(k1, k2, label): run_if_valid( (m1, quality, mode), (m2, quality, mode), - f"{m1}_vs_{m2}_{mode}" + "_{{metric}}_jpeg{quality}" + f"{m1}_vs_{m2}_{mode}" + "_{{metric}}_jpeg{quality}", ) # 3. Finetune vs. Linear Probe: same model & quality @@ -104,7 +109,7 @@ def run_if_valid(k1, k2, label): run_if_valid( (model, quality, "finetune"), (model, quality, "linear_probe"), - f"{model}_finetune_vs_linear_probe" + "_{{metric}}_jpeg{quality}" + f"{model}_finetune_vs_linear_probe" + "_{{metric}}_jpeg{quality}", ) return t_test_results @@ -115,7 +120,7 @@ def summarize_performance(metrics): Summarizes performance metrics (mean, std) across models, qualities, and modes. """ summary = [] - for (model, quality, mode), runs in metrics.items(): + for (model, _, mode), runs in metrics.items(): if not runs: continue diff --git a/src/models/.DS_Store b/src/models/.DS_Store deleted file mode 100644 index 26c154e2727ff8c63d16b48d252a6395506fc4d3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKI|>3Z5S{S@f{mqRuHX%V=n3`$7J>+(;IH1wb9pr1e41sk(?WRzlb1~9CFB)5 zJ0haX+jb!`6OjqrP#!k)&GyZEHpqwq;W*=RZ_dZV>A36Vz6%(4EH}BzUJf0;?a-(I z6`%rCfC^B7Pb-iWb~63+!90%&P=TLUz`hR!ZdeoBK>u`L@D>0#Lf8#+?GNs z1|kB}paO%c*+Nm*NakF#1^;2XH*JmF@TI|YN6W1yE~EUX;QJt^{v&9Pq- U+d!uy?sOo3222+k75KISF9lB&>Hq)$ diff --git a/src/models/__init__.py b/src/models/__init__.py index 8a9e4f5..e69de29 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -1,3 +0,0 @@ -""" -Models package containing model implementations and utilities. -""" \ No newline at end of file diff --git a/src/models/model_comparison_baseline.py b/src/models/model_comparison_baseline.py index 89da7a4..9de0be4 100644 --- a/src/models/model_comparison_baseline.py +++ b/src/models/model_comparison_baseline.py @@ -1,9 +1,16 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + """ This script is a baseline for comparing different image classification models at three different image compression levels, in comparison to the original. It supports fine-tuning, linear probing, and optional degradation transforms. -Purpose: General pipeline for comparing image classification models (ViT, DINOv2, SimCLR) with optional image degradation transforms. +Purpose: General pipeline for comparing image classification models (ViT, DINOv2, SimCLR) with +optional image degradation transforms. Features: - Supports fine-tuning and linear probing. - Balances dataset across classes. @@ -16,40 +23,35 @@ # Environment Setup import os import io -import json import numpy as np from PIL import Image - -# PyTorch & Torchvision -import torch -import torch.nn as nn from torch.utils.data import Dataset from torchvision import transforms - -# Hugging Face Transformers & Datasets from transformers import ( AutoImageProcessor, AutoModelForImageClassification, - Trainer, - TrainerCallback, - TrainingArguments, ViTFeatureExtractor, ViTForImageClassification, + Trainer, + TrainingArguments, ) from datasets import load_dataset - -# Weights & Biases import wandb - -# Model Profiling & Vision Backbones import timm -from thop import profile -# Local Application Imports from utils.constants import SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES -from utils.transforms import JPEGCompressionTransform, GaussianBlurTransform, ColorQuantizationTransform +from utils.transforms import ( + JPEGCompressionTransform, GaussianBlurTransform, +) from utils.util_classes import SimCLRForClassification, LossLoggerCallback -from utils.util_methods import env_path, compute_metrics, get_gpu_memory, freeze_backbone +from utils.util_methods import get_gpu_memory, GPU_AVAILABLE, freeze_backbone +from transforms import ColorQuantizationTransform + +# Compatibility for LANCZOS resampling +try: + LANCZOS = Image.Resampling.LANCZOS +except AttributeError: + LANCZOS = Image.LANCZOS # pylint: disable=no-member # GPU Memory Monitoring (optional) try: @@ -60,19 +62,16 @@ GPU_AVAILABLE = False print("pynvml not installed, GPU memory monitoring disabled.") -# Cache paths -os.environ["TRANSFORMERS_CACHE"] = os.getenv("TRANSFORMERS_CACHE", "~/.cache/huggingface/transformers") -os.environ["HF_DATASETS_CACHE"] = os.getenv("HF_DATASETS_CACHE", "~/.cache/huggingface/datasets") -os.environ["HF_HOME"] = os.getenv("HF_HOME", "~/.cache/huggingface") - - -class WandbCallback(TrainerCallback): +class WandbCallback: + """Callback for logging to Weights & Biases.""" def __init__(self, model_name, phase): self.model_name = model_name self.phase = phase self.best_accuracy = 0.0 - def on_log(self, args, state, control, logs=None, **kwargs): + def on_log(self, _args, _state, _control, logs=None, **_kwargs): + """ + Log metrics to Weights & Biases.""" if logs is not None: logs["model"] = self.model_name logs["phase"] = self.phase @@ -80,7 +79,8 @@ def on_log(self, args, state, control, logs=None, **kwargs): logs["gpu_memory_mb"] = get_gpu_memory() wandb.log(logs) - def on_evaluate(self, args, state, control, metrics=None, **kwargs): + def on_evaluate(self, _args, _state, _control, metrics=None, **_kwargs): + """Log evaluation metrics to Weights & Biases.""" if metrics is not None: if "eval_accuracy" in metrics: self.best_accuracy = max(self.best_accuracy, metrics["eval_accuracy"]) @@ -89,7 +89,10 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs): def initialize_model_and_preprocessor(model_info, resolution): - name, model_id, typ, config = ( + """ + Initialize model and preprocessor. + """ + _, model_id, typ, _ = ( model_info["name"], model_info["model_id"], model_info["type"], @@ -101,7 +104,7 @@ def initialize_model_and_preprocessor(model_info, resolution): model_id, size=resolution, do_resize=True, - resample=Image.LANCZOS, + resample=LANCZOS, do_normalize=True, image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225], @@ -117,7 +120,7 @@ def initialize_model_and_preprocessor(model_info, resolution): model_id, size=resolution, do_resize=True, - resample=Image.LANCZOS, + resample=LANCZOS, do_normalize=True, image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225], @@ -144,6 +147,9 @@ def initialize_model_and_preprocessor(model_info, resolution): def balance_dataset(dataset, num_train_images, filtered_classes): + """ + Balance the dataset by sampling equal images per class. + """ print("Balancing dataset...") class_counts = {label: 0 for label in filtered_classes} for label in dataset["label"]: @@ -163,7 +169,10 @@ def balance_dataset(dataset, num_train_images, filtered_classes): return dataset.select(balanced_indices) -def prepare_datasets(dataset, preprocessor, resolution, proportion_per_transform, apply_transforms=False): +def prepare_datasets(dataset, _preprocessor, resolution, apply_transforms=False): + """ + Prepare train and validation datasets. + """ train_size = int(0.8 * len(dataset)) val_size = len(dataset) - train_size train_dataset = dataset.select(range(train_size)) @@ -184,6 +193,9 @@ def prepare_datasets(dataset, preprocessor, resolution, proportion_per_transform ]) class TorchDataset(Dataset): + """ + Custom dataset class for Hugging Face datasets. + """ def __init__(self, hf_dataset, transform): self.hf_dataset = hf_dataset self.transform = transform @@ -204,22 +216,107 @@ def __getitem__(self, idx): return train_ds, val_ds -def main(num_train_images=100, proportion_per_transform=0.2, resolution=224, batch_size=256, num_epochs=3, eval_steps=10, learning_rate=1e-4): - wandb_config = { - "num_train_images": num_train_images, - "proportion_per_transform": proportion_per_transform, - "resolution": resolution, - "batch_size": batch_size, - "num_epochs": num_epochs, - "eval_steps": eval_steps, - "weight_decay": 0.01, - "learning_rate": learning_rate, - "gpu_available": GPU_AVAILABLE, - } +def train_model( + model, + train_ds, + val_ds, + config +): + """ + Train the model using Hugging Face Trainer. + + Args: + model: The model to train. + train_ds: Training dataset. + val_ds: Validation dataset. + config (dict): Configuration dictionary with keys: + - model_name + - model_type + - resolution + - batch_size + - num_epochs + - learning_rate + - eval_steps + - wandb_config + + Returns: + dict: Evaluation results. + """ + wandb.init( + project="Model Comparison Baseline", + name=f"{config['model_name']}_{config['resolution']}_{config['num_epochs']}_epochs", + config=config['wandb_config'], + tags=["baseline", config['model_name'], f"res_{config['resolution']}"], + reinit=True, + ) + + train_args = TrainingArguments( + output_dir=f"./outputs/{config['model_name']}", + num_train_epochs=config['num_epochs'], + per_device_train_batch_size=config['batch_size'], + per_device_eval_batch_size=config['batch_size'], + learning_rate=config['learning_rate'], + lr_scheduler_type="cosine", + weight_decay=0.01, + logging_dir=f"./logs/{config['model_name']}", + logging_steps=1, + evaluation_strategy="steps", + eval_steps=config['eval_steps'], + save_strategy="steps", + save_steps=config['eval_steps'], + load_best_model_at_end=False, + metric_for_best_model="accuracy", + save_total_limit=1, + report_to=["wandb"], + ) + + trainer = Trainer( + model=model, + args=train_args, + train_dataset=train_ds, + eval_dataset=val_ds, + ) + + trainer.train() + eval_results = trainer.evaluate() + + wandb.finish() + return eval_results + + +def get_trainer_callbacks(name): + return [ + LossLoggerCallback( + log_dir=env_path("LOG_DIR", "./logs"), + phase="finetune", + model_name=name, + ), + WandbCallback(name, "finetune"), + ] + + +def main(config=None): + """ + Main pipeline for model comparison. + """ + if config is None: + config = { + "num_train_images": 100, + "proportion_per_transform": 0.2, + "resolution": 224, + "batch_size": 256, + "num_epochs": 3, + "eval_steps": 10, + "learning_rate": 1e-4, + "gpu_available": GPU_AVAILABLE, + } + + wandb_config = config.copy() + wandb_config["weight_decay"] = 0.01 models = [ {"name": "vit", "model_id": "google/vit-base-patch16-224", "type": "vit", "config": { - "image_size": resolution, + "image_size": config["resolution"], "num_labels": NUM_FILTERED_CLASSES, "ignore_mismatched_sizes": True }}, @@ -231,20 +328,41 @@ def main(num_train_images=100, proportion_per_transform=0.2, resolution=224, bat split="train", ) - dataset = balance_dataset(dataset, num_train_images, FILTERED_CLASSES) + dataset = balance_dataset(dataset, config["num_train_images"], FILTERED_CLASSES) for model_info in models: - model, preprocessor = initialize_model_and_preprocessor(model_info, resolution) + model, preprocessor = initialize_model_and_preprocessor(model_info, config["resolution"]) + train_ds, val_ds = prepare_datasets( + dataset, preprocessor, config["resolution"], apply_transforms=True + ) - train_ds, val_ds = prepare_datasets(dataset, preprocessor, resolution, proportion_per_transform, apply_transforms=True) + train_config = { + "model_name": model_info["name"], + "model_type": model_info["type"], + "resolution": config["resolution"], + "batch_size": config["batch_size"], + "num_epochs": config["num_epochs"], + "learning_rate": config["learning_rate"], + "eval_steps": config["eval_steps"], + "wandb_config": wandb_config, + } results = train_model( - model, train_ds, val_ds, model_info["name"], model_info["type"], - resolution, batch_size, num_epochs, learning_rate, eval_steps, wandb_config + model, train_ds, val_ds, train_config ) - print(f"Results for {model_info['name']}: {results}") + metrics = { + "learning_rate": config["learning_rate"], + "model_name": model_info["name"], + "model_type": model_info["type"], + "peak_memory_mb": get_gpu_memory(), + "flops_giga": None, # Placeholder for FLOPs, calculate if needed + "train_time_seconds": None, # Placeholder for training time, calculate if needed + "eval_time_seconds": None, # Placeholder for evaluation time, calculate if needed + "eval_metrics": results, + } + wandb.log({"metrics": metrics}) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/models/model_comparison_lr_sweep.py b/src/models/model_comparison_lr_sweep.py index 5185bbc..8983c91 100644 --- a/src/models/model_comparison_lr_sweep.py +++ b/src/models/model_comparison_lr_sweep.py @@ -1,4 +1,10 @@ -''' +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +""" This script is a baseline for comparing different image classification models at three different image compression levels, in comparison to the original. It has a set number of augmentation transforms and does NOT combine them. @@ -13,32 +19,24 @@ - Does not experiment with JPEG compression levels. - Logs results for each learning rate and saves them to a JSON file. - More flexible for hyperparameter tuning and ablation studies. - -''' +""" # Environment Setup import os -# Set memory optimization -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - -# Standard Libraries -import io +# Standard Library import json -import random import time import argparse +import shutil # Scientific & Visualization Libraries import numpy as np -import matplotlib.pyplot as plt -import seaborn as sns from PIL import Image # PyTorch & Torchvision import torch -import torch.nn as nn -from torch.utils.data import Dataset, Subset, ConcatDataset +from torch.utils.data import Subset, ConcatDataset from torchvision import transforms # Hugging Face Transformers & Datasets @@ -56,45 +54,35 @@ # Weights & Biases import wandb -# Metrics -from sklearn.metrics import ( - accuracy_score, - confusion_matrix, - f1_score, - roc_auc_score, -) - # Model Profiling & Vision Backbones import timm from thop import profile + # Local Application Imports -from utils.constants import HF_MODELS, SSL_MODEL, SIMCLR_BACKBONE, NUM_CLASSES, FILTERED_CLASSES, NUM_FILTERED_CLASSES -from utils.transforms_test import ( +from src.models.utils.constants import ( + HF_MODELS, SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES +) +from src.models.utils.transforms import ( JPEGCompressionTransform, GaussianBlurTransform, ColorQuantizationTransform, ) -from utils.util_classes import ( +from src.models.utils.utils_classes import ( ISICDataset, SimCLRForClassification, - LossLoggerCallback, + LossLoggerCallback ) -from utils.util_methods import ( +from src.models.utils.utils_methods import ( env_path, compute_metrics, get_gpu_memory, freeze_backbone, + GPU_AVAILABLE ) -# GPU Memory Monitoring (optional) -try: - import pynvml - pynvml.nvmlInit() - GPU_AVAILABLE = True -except ImportError: - GPU_AVAILABLE = False - print("pynvml not installed, GPU memory monitoring disabled.") +# Set memory optimization +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Cache paths os.environ["TRANSFORMERS_CACHE"] = os.getenv( @@ -106,70 +94,65 @@ os.environ["HF_HOME"] = os.getenv("HF_HOME", "~/.cache/huggingface") class WandbCallback(TrainerCallback): + """ + Custom callback for logging metrics and evaluation results to Weights & Biases. + Tracks best accuracy and GPU memory usage if available. + """ def __init__(self, model_name, phase): + """ + Args: + model_name (str): Name of the model. + phase (str): Training phase (e.g., 'finetune'). + """ self.model_name = model_name self.phase = phase self.best_accuracy = 0.0 - def on_log(self, args, state, control, logs=None, **kwargs): + def on_log(self, _args, _state, _control, logs=None, **_kwargs): + """ + Log metrics to Weights & Biases. + + Args: + _args: Trainer arguments (unused). + _state: Trainer state (unused). + _control: Trainer control (unused). + logs (dict): Metrics to log. + **_kwargs: Additional keyword arguments (unused). + """ if logs is not None: - # Add model name and phase to logs logs["model"] = self.model_name logs["phase"] = self.phase - - # Track GPU memory if available if GPU_AVAILABLE: logs["gpu_memory_mb"] = get_gpu_memory() - - # Log to wandb wandb.log(logs) - def on_evaluate(self, args, state, control, metrics=None, **kwargs): + def on_evaluate(self, _args, _state, _control, metrics=None, **_kwargs): + """ + Log evaluation metrics to Weights & Biases. + + Args: + _args: Trainer arguments (unused). + _state: Trainer state (unused). + _control: Trainer control (unused). + metrics (dict): Evaluation metrics. + **_kwargs: Additional keyword arguments (unused). + """ if metrics is not None: - # Track best accuracy if "eval_accuracy" in metrics: self.best_accuracy = max(self.best_accuracy, metrics["eval_accuracy"]) metrics["best_accuracy"] = self.best_accuracy - - # Log evaluation metrics wandb.log(metrics) -def main(num_train_images=25000, proportion_per_transform=0.2, resolution=224, batch_size=256, num_epochs=3, eval_steps=10, learning_rate=1e-4): - - # Define the model to use for learning rate experiments - model_config = { - "name": "dinov2", - "model_id": "facebook/dinov2-base", - "type": "dinov2", - "config": { - "image_size": resolution, - "num_labels": NUM_FILTERED_CLASSES, - "ignore_mismatched_sizes": True - } - } - - # Define learning rates to test - learning_rates = [ - learning_rate - ] - - # Initialize results dictionary - results = {str(lr): {} for lr in learning_rates} - - dataset = load_dataset( - "MKZuziak/ISIC_2019_224", - cache_dir=os.environ["HF_DATASETS_CACHE"], - split="train", - ) - - print(f"Initial dataset size: {len(dataset)} images") - +def prepare_balanced_datasets(dataset, config): + """ + Filter, balance, and split the dataset into train and validation sets. + """ # Get indices of images with desired labels filtered_indices = [ i for i, label in enumerate(dataset["label"]) if str(label) in FILTERED_CLASSES # Convert to string for comparison ] - + # Select only those indices dataset = dataset.select(filtered_indices) print(f"Number of images after filtering for classes {FILTERED_CLASSES}: {len(dataset)}") @@ -181,22 +164,22 @@ def main(num_train_images=25000, proportion_per_transform=0.2, resolution=224, b class_counts = {label: 0 for label in FILTERED_CLASSES} for label in dataset["label"]: class_counts[str(label)] += 1 # Convert to string for dictionary key - + print(f"Class counts: {class_counts}") # Debug print to verify counts - + # Calculate how many images to use per class min_class_size = min(class_counts.values()) - images_per_class = min(num_train_images // 2, min_class_size) - + images_per_class = min(config["num_train_images"] // 2, min_class_size) + # Sample indices for each class np.random.seed(42) balanced_indices = [] for label in FILTERED_CLASSES: - class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] # Convert to string for comparison + class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] print(f"Found {len(class_indices)} images for class {label}") # Debug print sampled_indices = np.random.choice(class_indices, images_per_class, replace=False) balanced_indices.extend(sampled_indices) - + np.random.shuffle(balanced_indices) balanced_dataset = dataset.select(balanced_indices) @@ -206,21 +189,16 @@ def main(num_train_images=25000, proportion_per_transform=0.2, resolution=224, b ) train_dataset, val_dataset = full_dataset["train"], full_dataset["test"] - - degradation_transforms = [ - JPEGCompressionTransform(), - GaussianBlurTransform(), - ColorQuantizationTransform(), - ] - num_transforms = len(degradation_transforms) - num_images = len(train_dataset) - images_per_transform = int(num_images * proportion_per_transform) + return train_dataset, val_dataset - # Create a single preprocessor for each model type +def create_preprocessors(model_config, config): + """ + Create preprocessors for each model type. + """ preprocessors = {} for model_info in [model_config]: - name, model_id, typ, config = ( + _name, model_id, typ, _config = ( model_info["name"], model_info["model_id"], model_info["type"], @@ -229,9 +207,9 @@ def main(num_train_images=25000, proportion_per_transform=0.2, resolution=224, b if typ == "vit": preprocessors[typ] = ViTFeatureExtractor.from_pretrained( model_id, - size=resolution, + size=config["resolution"], do_resize=True, - resample=Image.LANCZOS, + resample=Image.LANCZOS, # pylint: disable=no-member do_normalize=True, image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225] @@ -239,9 +217,9 @@ def main(num_train_images=25000, proportion_per_transform=0.2, resolution=224, b elif typ == "dinov2": preprocessors[typ] = AutoImageProcessor.from_pretrained( model_id, - size=resolution, + size=config["resolution"], do_resize=True, - resample=Image.LANCZOS, + resample=Image.LANCZOS, # pylint: disable=no-member do_normalize=True, image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225] @@ -249,260 +227,335 @@ def main(num_train_images=25000, proportion_per_transform=0.2, resolution=224, b else: preprocessors[typ] = None - # Process each learning rate - for learning_rate in learning_rates: - print(f"\nTraining with learning rate: {learning_rate}") - - name, model_id, typ, config = ( - model_config["name"], - model_config["model_id"], - model_config["type"], - model_config["config"], - ) + return preprocessors + +def train_for_learning_rate( + learning_rate, model_config, train_dataset, val_dataset, config +): + """ + Train and evaluate the model for a given learning rate. + """ + preprocessors = config["preprocessors"] + name, model_id, typ, _config = ( + model_config["name"], + model_config["model_id"], + model_config["type"], + model_config["config"], + ) - transformed_datasets = [] - indices = np.arange(num_images) - np.random.shuffle(indices) - - used_indices = [] - for i, transform in enumerate(degradation_transforms): - subset_indices = indices[i * images_per_transform:(i + 1) * images_per_transform] - used_indices.extend(subset_indices) - subset = Subset(train_dataset, subset_indices) - transform_compose = transforms.Compose([transform]) - - transformed_ds = ISICDataset( - subset, - preprocessors[typ], - resolution, - transform_compose, - typ - ) - transformed_datasets.append(transformed_ds) - - remaining_indices = np.setdiff1d(indices, used_indices) - - if len(remaining_indices) > 0: - remaining_subset = Subset(train_dataset, remaining_indices) - untransformed_ds = ISICDataset( - remaining_subset, - preprocessors[typ], - resolution, - None, - typ - ) - transformed_datasets.append(untransformed_ds) - - train_ds = ConcatDataset(transformed_datasets) - val_ds = ISICDataset( - val_dataset, - preprocessors[typ], - resolution, - model_type=typ, - ) + train_ds = get_transformed_datasets(train_dataset, preprocessors, config, typ) + val_ds = ISICDataset( + val_dataset, + preprocessors[typ], + config["resolution"], + model_type=typ, + ) - if typ == "vit": - model = ViTForImageClassification.from_pretrained( - model_id, - num_labels=NUM_FILTERED_CLASSES, - ignore_mismatched_sizes=True, - image_size=resolution, - ) - elif typ == "dinov2": - model = AutoModelForImageClassification.from_pretrained( - model_id, - num_labels=NUM_FILTERED_CLASSES, - ignore_mismatched_sizes=True, - image_size=resolution - ) - elif typ == SSL_MODEL: - backbone = timm.create_model( - SIMCLR_BACKBONE, - pretrained=True, - num_classes=0 # Remove classification head - ) - model = SimCLRForClassification(backbone, NUM_FILTERED_CLASSES) - freeze_backbone(model, SSL_MODEL) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - - try: - dummy_input = torch.randn(1, 3, resolution, resolution).to(device) - model.to(device) - flops, _ = profile(model, inputs=(dummy_input,)) - flops /= 1e9 - except Exception as e: - print(f"FLOP profiling failed: {e}") - flops = -1 - - # Monitor disk space before training - import shutil - total, used, free = shutil.disk_usage("/") - print(f"Disk space before training {name}:") - print(f"Total: {total // (2**30)} GB") - print(f"Used: {used // (2**30)} GB") - print(f"Free: {free // (2**30)} GB") - - # If less than 1GB free, raise an error - if free < 1 * (2**30): # 1GB in bytes - raise RuntimeError("Not enough disk space to train model. Please free up at least 1GB of space.") - - train_args = TrainingArguments( - output_dir=os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}_lr_{learning_rate}"), - num_train_epochs=num_epochs, - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - learning_rate=learning_rate, - lr_scheduler_type="cosine", - weight_decay=0.01, - logging_dir=os.path.join(env_path("LOG_DIR", "."), f"{name}_lr_{learning_rate}"), - logging_steps=1, - eval_strategy="steps", - eval_steps=eval_steps, - save_strategy="steps", - save_steps=eval_steps, - load_best_model_at_end=False, - metric_for_best_model="accuracy", - save_total_limit=1, - save_safetensors=False, - hub_model_id=None, - hub_strategy="end", - push_to_hub=False, - save_only_model=True, - ) - - # Clean up old model directories before training - model_dirs = [ - os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}_lr_{learning_rate}"), - os.path.join(env_path("MODEL_DIR", "."), f"{name}_lr_{learning_rate}"), - os.path.join(env_path("LOG_DIR", "."), f"{name}_lr_{learning_rate}"), - ] - - for dir_path in model_dirs: - if os.path.exists(dir_path): - print(f"Cleaning up directory: {dir_path}") - import shutil - shutil.rmtree(dir_path) - os.makedirs(dir_path, exist_ok=True) - - # Initialize wandb for this learning rate - wandb.init( - entity="ericcui-use-stanford-university", - project="CS231N Test", - name=f"{name}_{resolution}_lr_{learning_rate}", - config={ - "model_name": name, - "resolution": resolution, - "batch_size": batch_size, - "num_epochs": num_epochs, - "eval_steps": eval_steps, - "learning_rate": learning_rate, - "weight_decay": 0.01, - "gpu_available": GPU_AVAILABLE, - }, - tags=["learning_rate_experiment", f"lr_{learning_rate}", f"resolution_{resolution}"], - ) - - trainer = Trainer( - model=model, - args=train_args, - train_dataset=train_ds, - eval_dataset=val_ds, - compute_metrics=lambda pred: compute_metrics(pred, name), - callbacks=[ - LossLoggerCallback( - log_dir=env_path("LOG_DIR", "./logs"), - phase="finetune", - model_name=name, - ), - WandbCallback(name, "finetune"), - ], - ) + model = get_model(typ, model_id, config) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + + flops = get_flops(model, config["resolution"]) + check_disk_space(min_gb=1) + cleanup_model_dirs(name, learning_rate) - # ---- TRAINING PHASE ---- - start_time = time.time() - peak_memory = get_gpu_memory() if GPU_AVAILABLE else -1 + wandb.init( + entity="ericcui-use-stanford-university", + project="CS231N Test", + name=f"{name}_{config['resolution']}_lr_{learning_rate}", + config={ + "model_name": name, + "resolution": config["resolution"], + "batch_size": config["batch_size"], + "num_epochs": config["num_epochs"], + "eval_steps": config["eval_steps"], + "learning_rate": learning_rate, + "weight_decay": 0.01, + "gpu_available": GPU_AVAILABLE, + }, + tags=[ + "learning_rate_experiment", + f"lr_{learning_rate}", + f"resolution_{config['resolution']}"], + ) - # Log model architecture - if typ in HF_MODELS: - wandb.watch(model, log="all", log_freq=100) - elif typ == SSL_MODEL: - wandb.watch(model.backbone, log="all", log_freq=100) + train_args = get_training_args(name, learning_rate, config) + callbacks = get_trainer_callbacks(name) - trainer.train() + trainer = Trainer( + model=model, + args=train_args, + train_dataset=train_ds, + eval_dataset=val_ds, + compute_metrics=lambda pred: compute_metrics(pred, name), + callbacks=callbacks, + ) - current_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - peak_memory = max(peak_memory, current_memory) + start_time = time.time() + peak_memory = get_gpu_memory() if GPU_AVAILABLE else -1 + + if typ in HF_MODELS: + wandb.watch(model, log="all", log_freq=100) + elif typ == SSL_MODEL: + wandb.watch(model.backbone, log="all", log_freq=100) + + trainer.train() + + current_memory = get_gpu_memory() if GPU_AVAILABLE else -1 + peak_memory = max(peak_memory, current_memory) + + eval_start_time = time.time() + eval_results = trainer.evaluate() + eval_time = time.time() - eval_start_time + train_time = time.time() - start_time - eval_time + + metrics = { + "learning_rate": learning_rate, + "model_name": name, + "model_type": typ, + "peak_memory_mb": peak_memory, + "flops_giga": flops, + "train_time_seconds": train_time, + "eval_time_seconds": eval_time, + "eval_metrics": eval_results, + } + wandb.log(metrics) - eval_start_time = time.time() - eval_results = trainer.evaluate() - eval_time = time.time() - eval_start_time - train_time = time.time() - start_time - eval_time + model_dir = save_model_and_preprocessor(model, preprocessors, typ, name, learning_rate) + log_wandb_artifact(model_dir, name, learning_rate) - # Log model-specific metrics - model_metrics = { - "model_name": name, - "model_type": typ, - "peak_memory_mb": peak_memory, - "flops_giga": flops, - "train_time_seconds": train_time, - "eval_time_seconds": eval_time, - "eval_metrics": eval_results, - } - wandb.log(model_metrics) + print(f"[Finetune] Learning Rate {learning_rate}: {metrics}") - model_dir = os.path.join( - env_path("MODEL_DIR", "."), f"{name}_lr_{learning_rate}" + wandb.finish() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return metrics + +def get_model(typ, model_id, config): + """ + Returns the initialized model based on type. + """ + if typ == "vit": + return ViTForImageClassification.from_pretrained( + model_id, + num_labels=NUM_FILTERED_CLASSES, + ignore_mismatched_sizes=True, + image_size=config["resolution"], ) - os.makedirs(model_dir, exist_ok=True) - - if typ in HF_MODELS: - model.save_pretrained(model_dir) - preprocessors[typ].save_pretrained(model_dir) - elif typ == SSL_MODEL: - torch.save( - model.state_dict(), os.path.join(model_dir, "pytorch_model.bin") - ) - with open(os.path.join(model_dir, "config.json"), "w") as f: - json.dump( - { - "model_type": SSL_MODEL, - "backbone": "resnet50", - "num_classes": NUM_FILTERED_CLASSES, - }, - f, - ) - - # Save model as wandb artifact - artifact = wandb.Artifact( - name=f"{name}_lr_{learning_rate}_model", - type="model", - description=f"Trained {name} model with {typ} architecture" + if typ == "dinov2": + return AutoModelForImageClassification.from_pretrained( + model_id, + num_labels=NUM_FILTERED_CLASSES, + ignore_mismatched_sizes=True, + image_size=config["resolution"] ) - artifact.add_dir(model_dir) - wandb.log_artifact(artifact) + if typ == SSL_MODEL: + backbone = timm.create_model( + SIMCLR_BACKBONE, + pretrained=True, + num_classes=0 + ) + model = SimCLRForClassification(backbone, NUM_FILTERED_CLASSES) + freeze_backbone(model, SSL_MODEL) + return model + raise ValueError(f"Unknown model type: {typ}") + +def get_transformed_datasets(train_dataset, preprocessors, config, typ): + """ + Returns a concatenated dataset with optional degradation transforms applied. + """ + num_images = len(train_dataset) + images_per_transform = int(num_images * config["proportion_per_transform"]) + indices = np.random.permutation(num_images) - # Update results dictionary - results[str(learning_rate)] = { - "learning_rate": learning_rate, - "model_name": name, - "model_type": typ, - "peak_memory_mb": peak_memory, - "flops_giga": flops, - "train_time_seconds": train_time, - "eval_time_seconds": eval_time, - "eval_metrics": eval_results, + transforms_list = [ + JPEGCompressionTransform(), + GaussianBlurTransform(), + ColorQuantizationTransform(), + ] + + def make_subset(indices_subset, transform=None): + subset = Subset(train_dataset, indices_subset) + transform_compose = transforms.Compose([transform]) if transform else None + return ISICDataset(subset, preprocessors[typ], config["resolution"], transform_compose, typ) + + datasets = [] + used_indices = set() + + for i, transform in enumerate(transforms_list): + idx = indices[i * images_per_transform : (i + 1) * images_per_transform] + used_indices.update(idx) + datasets.append(make_subset(idx, transform)) + + remaining = np.setdiff1d(indices, list(used_indices)) + if len(remaining) > 0: + datasets.append(make_subset(remaining)) + + return ConcatDataset(datasets) + + +def get_flops(model, resolution): + """ + Profile FLOPs for the given model and resolution. + Returns FLOPs in giga units, or -1 if profiling fails. + """ + try: + dummy_input = torch.randn(1, 3, resolution, resolution).to(next(model.parameters()).device) + flops, _ = profile(model, inputs=(dummy_input,)) + return flops / 1e9 + except Exception as e: # pylint: disable=broad-exception-caught + print(f"FLOP profiling failed: {e}") + return -1 + +def check_disk_space(min_gb=1): + """ + Checks if there is at least min_gb GB of free disk space. + Raises RuntimeError if not enough space. + """ + total, used, free = shutil.disk_usage("/") + print( + f"Disk space: Total={total // (2**30)} GB, " + f"Used={used // (2**30)} GB, " + f"Free={free // (2**30)} GB" + ) + if free < min_gb * (2**30): + raise RuntimeError(f"Not enough disk space. Please free up at least {min_gb}GB.") + +def save_model_and_preprocessor(model, preprocessors, typ, name, learning_rate): + """ + Saves the trained model and preprocessor to disk. + """ + model_dir = os.path.join(env_path("MODEL_DIR", "."), f"{name}_lr_{learning_rate}") + os.makedirs(model_dir, exist_ok=True) + if typ in HF_MODELS: + model.save_pretrained(model_dir) + preprocessors[typ].save_pretrained(model_dir) + elif typ == SSL_MODEL: + torch.save(model.state_dict(), os.path.join(model_dir, "pytorch_model.bin")) + with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as f: + json.dump({ + "model_type": SSL_MODEL, + "backbone": "resnet50", + "num_classes": NUM_FILTERED_CLASSES, + }, f) + return model_dir + +def log_wandb_artifact(model_dir, name, learning_rate): + """ + Logs the saved model directory as a wandb artifact. + """ + artifact = wandb.Artifact( + name=f"{name}_lr_{learning_rate}_model", + type="model", + description=f"Trained {name} model with {learning_rate} learning rate" + ) + artifact.add_dir(model_dir) + wandb.log_artifact(artifact) + +def get_training_args(name, learning_rate, config): + """ + Returns a TrainingArguments object for Hugging Face Trainer. + """ + return TrainingArguments( + output_dir=os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}_lr_{learning_rate}"), + num_train_epochs=config["num_epochs"], + per_device_train_batch_size=config["batch_size"], + per_device_eval_batch_size=config["batch_size"], + learning_rate=learning_rate, + lr_scheduler_type="cosine", + weight_decay=0.01, + logging_dir=os.path.join(env_path("LOG_DIR", "."), f"{name}_lr_{learning_rate}"), + logging_steps=1, + evaluation_strategy="steps", + eval_steps=config["eval_steps"], + save_strategy="steps", + save_steps=config["eval_steps"], + load_best_model_at_end=False, + metric_for_best_model="accuracy", + save_total_limit=1, + save_safetensors=False, + hub_model_id=None, + hub_strategy="end", + push_to_hub=False, + save_only_model=True, + ) + +def cleanup_model_dirs(name, learning_rate): + """ + Removes and recreates model/log directories for the current run. + """ + model_dirs = [ + os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}_lr_{learning_rate}"), + os.path.join(env_path("MODEL_DIR", "."), f"{name}_lr_{learning_rate}"), + os.path.join(env_path("LOG_DIR", "."), f"{name}_lr_{learning_rate}"), + ] + for dir_path in model_dirs: + if os.path.exists(dir_path): + print(f"Cleaning up directory: {dir_path}") + shutil.rmtree(dir_path) + os.makedirs(dir_path, exist_ok=True) + +def get_trainer_callbacks(name): + """ + Returns a list of Trainer callbacks for logging and monitoring. + """ + return [ + LossLoggerCallback( + log_dir=env_path("LOG_DIR", "./logs"), + phase="finetune", + model_name=name, + ), + WandbCallback(name, "finetune"), + ] + +def main(config=None): + """ + Main function for running learning rate sweep experiments on image classification models. + + Args: + config (dict): Configuration dictionary. + """ + if config is None: + config = { + "num_train_images": 25000, + "proportion_per_transform": 0.2, + "resolution": 224, + "batch_size": 256, + "num_epochs": 3, + "eval_steps": 10, + "learning_rate": 1e-4, } - print(f"[Finetune] Learning Rate {learning_rate}: {results[str(learning_rate)]}") - - # Close the wandb run for this learning rate - wandb.finish() + model_config = { + "name": "dinov2", + "model_id": "facebook/dinov2-base", + "type": "dinov2", + "config": { + "image_size": config["resolution"], + "num_labels": NUM_FILTERED_CLASSES, + "ignore_mismatched_sizes": True + } + } - # Clear GPU memory after model is done - if torch.cuda.is_available(): - torch.cuda.empty_cache() + learning_rates = [config["learning_rate"]] + results = {} + + dataset = load_dataset( + "MKZuziak/ISIC_2019_224", + cache_dir=os.environ["HF_DATASETS_CACHE"], + split="train", + ) + + train_dataset, val_dataset = prepare_balanced_datasets(dataset, config) + preprocessors = create_preprocessors(model_config, config) + config["preprocessors"] = preprocessors + + for lr in learning_rates: + results[str(lr)] = train_for_learning_rate( + lr, model_config, train_dataset, val_dataset, config + ) # Save results with open( @@ -510,25 +563,47 @@ def main(num_train_images=25000, proportion_per_transform=0.2, resolution=224, b env_path("TRAIN_OUTPUT_DIR", "."), "results_metrics_lr_experiment.json" ), "w", + encoding="utf-8" ) as f: json.dump(results, f, indent=4) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Learning rate experiment for image classification.") - parser.add_argument('--resolution', type=int, default=224, help='Input image resolution (default: 224)') - parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training and evaluation (default: 128)') - parser.add_argument('--num_train_images', type=int, default=500, help='Number of training images to use per class (default: 500)') - parser.add_argument('--num_epochs', type=int, default=3, help='Number of training epochs (default: 3)') - parser.add_argument('--eval_steps', type=int, default=100, help='Number of steps between evaluations (default: 100)') - parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate (default: 1e-4)') - args = parser.parse_args() - main( - resolution=args.resolution, - batch_size=args.batch_size, - num_train_images=args.num_train_images, - num_epochs=args.num_epochs, - eval_steps=args.eval_steps, - learning_rate=args.learning_rate + # Entry point for running the learning rate sweep experiment from the command line. + # Parses command-line arguments and calls main(). + parser = argparse.ArgumentParser( + description="Learning rate experiment for image classification." ) - + parser.add_argument( + '--resolution', type=int, default=224, + help='Input image resolution (default: 224)' + ) + parser.add_argument( + '--batch_size', type=int, default=128, + help='Batch size for training and evaluation (default: 128)' + ) + parser.add_argument( + '--num_train_images', type=int, default=500, + help='Number of training images to use per class (default: 500)' + ) + parser.add_argument( + '--num_epochs', type=int, default=3, + help='Number of training epochs (default: 3)' + ) + parser.add_argument( + '--eval_steps', type=int, default=100, + help='Number of steps between evaluations (default: 100)' + ) + parser.add_argument( + '--learning_rate', type=float, default=1e-4, + help='Learning rate (default: 1e-4)' + ) + args = parser.parse_args() + main({ + "resolution": args.resolution, + "batch_size": args.batch_size, + "num_train_images": args.num_train_images, + "num_epochs": args.num_epochs, + "eval_steps": args.eval_steps, + "learning_rate": args.learning_rate, + }) diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py index 74b6809..e69de29 100644 --- a/src/models/utils/__init__.py +++ b/src/models/utils/__init__.py @@ -1,3 +0,0 @@ -""" -Utils package for model-related utilities. -""" \ No newline at end of file diff --git a/src/models/utils/constants.py b/src/models/utils/constants.py index de4f241..e745410 100644 --- a/src/models/utils/constants.py +++ b/src/models/utils/constants.py @@ -1,3 +1,11 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +"""Configuration constants for model training and evaluation.""" + HF_MODELS = ["vit", "dinov2"] SSL_MODEL = "simclr" SIMCLR_BACKBONE = "resnet50" diff --git a/src/models/utils/transforms.py b/src/models/utils/transforms.py index 93d0fa4..fcd3b88 100644 --- a/src/models/utils/transforms.py +++ b/src/models/utils/transforms.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + """ Image transformation utilities for data augmentation and degradation. """ @@ -7,8 +13,17 @@ from PIL import Image from torchvision import transforms +# Compatibility for LANCZOS resampling +try: + LANCZOS = Image.Resampling.LANCZOS +except AttributeError: + LANCZOS = Image.LANCZOS # pylint: disable=no-member + class JPEGCompressionTransform: + """ + Apply JPEG compression to an image to simulate lossy compression artifacts. + """ def __init__(self, quality=75): """ Apply JPEG compression to an image. @@ -42,12 +57,20 @@ def __call__(self, img): # Ensure size is maintained if img.size != original_size: - img = img.resize(original_size, Image.LANCZOS) + img = img.resize(original_size, LANCZOS) return img + def get_quality(self): + """ + Return the JPEG compression quality setting. + """ + return self.quality + class GaussianBlurTransform: + """Apply Gaussian blur to an image with a given probability. + """ def __init__(self, p=1): """ Apply Gaussian blur to an image with a given probability. @@ -81,16 +104,23 @@ def __call__(self, img): # Ensure size is maintained if img.size != original_size: - img = img.resize(original_size, Image.LANCZOS) + img = img.resize(original_size, LANCZOS) return img + def get_probability(self): + """ + Return the probability of applying Gaussian blur. + """ + return self.p + class ColorQuantizationTransform: + """ + Apply color quantization to an image with a given probability. + """ def __init__(self, p=1): """ - Apply color quantization to an image with a given probability. - Args: p (float): Probability of applying the quantization (0 to 1). """ @@ -119,6 +149,12 @@ def __call__(self, img): # Ensure size is maintained if img.size != original_size: - img = img.resize(original_size, Image.LANCZOS) + img = img.resize(original_size, LANCZOS) + + return img - return img \ No newline at end of file + def get_probability(self): + """ + Return the probability of applying color quantization. + """ + return self.p diff --git a/src/models/utils/utils_classes.py b/src/models/utils/utils_classes.py index 0fe3e88..d5c225d 100644 --- a/src/models/utils/utils_classes.py +++ b/src/models/utils/utils_classes.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + """ Utility classes for model training and data handling. """ @@ -6,54 +12,55 @@ import json import numpy as np import torch -import torch.nn as nn -from PIL import Image +from torch import nn from torch.utils.data import Dataset +from PIL import Image from torchvision import transforms from transformers import TrainerCallback from .constants import HF_MODELS, NUM_FILTERED_CLASSES, SSL_MODEL from .transforms import JPEGCompressionTransform +# Compatibility for LANCZOS resampling +try: + LANCZOS = Image.Resampling.LANCZOS +except AttributeError: + LANCZOS = Image.LANCZOS # pylint: disable=no-member + class ISICDataset(Dataset): - def __init__( - self, - dataset, - preprocessor=None, - resolution=224, - transform=None, - model_type="vit", - jpeg_quality=None, - ): + """ + Dataset class for handling ISIC image data with optional transformations. + """ + def __init__(self, dataset, config=None): """ - Dataset class for handling ISIC image data. - Args: dataset: The dataset to load. - preprocessor: Preprocessing function for Hugging Face models. - resolution: Target image resolution. - transform: Additional transformations to apply. - model_type: Type of model (e.g., "vit", "ssl"). - jpeg_quality: JPEG compression quality (if applicable). + config (dict, optional): Configuration dictionary with keys: + - preprocessor + - resolution + - transform + - model_type + - jpeg_quality """ self.dataset = dataset - self.preprocessor = preprocessor - self.resolution = resolution - self.transform = transform - self.model_type = model_type - self.jpeg_quality = jpeg_quality + config = config or {} + self.preprocessor = config.get("preprocessor", None) + self.resolution = config.get("resolution", 224) + self.transform = config.get("transform", None) + self.model_type = config.get("model_type", "vit") + self.jpeg_quality = config.get("jpeg_quality", None) # Base preprocessing pipeline for resizing and tensor conversion self.base_preprocessor = transforms.Compose([ - transforms.Resize((resolution, resolution), Image.LANCZOS), + transforms.Resize((self.resolution, self.resolution), LANCZOS), transforms.ToTensor(), ]) # Preprocessor for SSL models - if model_type == SSL_MODEL: + if self.model_type == SSL_MODEL: self.preprocessor = transforms.Compose([ - transforms.Resize((resolution, resolution), Image.LANCZOS), + transforms.Resize((self.resolution, self.resolution), LANCZOS), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) @@ -79,7 +86,7 @@ def __getitem__(self, idx): label = item["label"] # Always resize to target resolution first - image = image.resize((self.resolution, self.resolution), Image.LANCZOS) + image = image.resize((self.resolution, self.resolution), LANCZOS) # Apply additional transformations if provided if self.transform: @@ -106,6 +113,9 @@ def __getitem__(self, idx): class SimCLRForClassification(nn.Module): + """ + SimCLR-based classification model. + """ def __init__(self, backbone, num_classes=NUM_FILTERED_CLASSES): """ SimCLR-based classification model. @@ -170,6 +180,6 @@ def on_log(self, args, state, control, logs=None, **kwargs): """ if logs is None: return - with open(self.log_file, "a") as f: + with open(self.log_file, "a", encoding="utf-8") as f: json.dump({"step": state.global_step, **logs}, f) - f.write("\n") \ No newline at end of file + f.write("\n") diff --git a/src/models/utils/util_methods.py b/src/models/utils/utils_methods.py similarity index 76% rename from src/models/utils/util_methods.py rename to src/models/utils/utils_methods.py index 2dde99e..0d33427 100644 --- a/src/models/utils/util_methods.py +++ b/src/models/utils/utils_methods.py @@ -1,3 +1,13 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +"""Utility methods for model training and evaluation. +This module provides functions for computing evaluation metrics, managing GPU memory, +freezing model backbones, and handling environment paths.""" + import os import json import numpy as np @@ -18,11 +28,14 @@ def env_path(key, default): def compute_metrics(eval_pred, model_name): + """ + Compute evaluation metrics from model predictions. + """ logits, labels = eval_pred predictions = np.argmax(logits, axis=-1) acc = accuracy_score(labels, predictions) f1 = f1_score(labels, predictions, average="weighted") - + # For binary classification, use the probability of the positive class probs = torch.softmax(torch.tensor(logits), dim=1).numpy() # Use the probability of class 1 (positive class) for ROC AUC @@ -44,24 +57,32 @@ def compute_metrics(eval_pred, model_name): unique, counts = np.unique(predictions, return_counts=True) class_breakdown = {str(k): int(v) for k, v in zip(unique, counts)} - with open(os.path.join(plot_dir, "class_breakdown.json"), "w") as f: + with open(os.path.join(plot_dir, "class_breakdown.json"), "w", encoding="utf-8") as f: json.dump(class_breakdown, f) return {"accuracy": acc, "f1": f1, "auc": auc} def get_gpu_memory(device_id=0): + """ + Get the used GPU memory in MB for a specific device. + """ if not GPU_AVAILABLE: return -1 try: handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) return mem_info.used / 1024**2 - except: + except pynvml.NVMLError: return -1 - + except Exception: # pylint: disable=broad-exception-caught + return -1 + def freeze_backbone(model, model_type): + """ + Freeze the backbone of the model based on its type. + """ if model_type in HF_MODELS: for name, param in model.named_parameters(): if "classifier" not in name: @@ -72,4 +93,4 @@ def freeze_backbone(model, model_type): for param in model.classifier.parameters(): param.requires_grad = True else: - raise ValueError(f"Unsupported model_type: {model_type}") \ No newline at end of file + raise ValueError(f"Unsupported model_type: {model_type}") diff --git a/src/requirements.txt b/src/requirements.txt deleted file mode 100644 index 2d330480365fea94d5cdc9f0c7ab6c125710cd23..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1716 zcmZvcTaKGR5JmetQkLS6C|Sttz`};H@nB4V89x@ELrr>z0^UP55V4Z$$J<%!dCey<>0FE8her`_yl&EktOgV{vEp5h zRTJa&iyNBh)%uzvLtiTlPW$wONBYDpGlG_mbuzv2>0iSj2J@B>OP~*TZn`MviTBqGMVhvySSDksu{@*XlsOOEy zGp`fWp=Q~`e(!QW-!q1NR}1Z>KKK60z5=h$9IVeMzxHnp$^_FMj%CJsc2T*QniGFA zukf2W@V7;+3O1rf=d|`&)##aP=QyTgAkrLXZ=#F(f`^_i*Kl6YJG^tA1Sem;#fhDr zHkrxWSyUnQ3RMTP zoe?#ER^@TxS9w<_;+f-f#PcaqrQgIyuVSYOZtYU_3P#>M&aHKXuXN6HT``wdQv6Q8 zvsctBm@Vfq5PHyGdCD}Dn{2|Oj?p(~C+48k6}OO@mRl&!PMmlDt;Un*nY43TFrp`M zuQ+4gBk~B#maOO12lz56cc?dzlQ-Qb#`6+Y&1hG{CJ>BsQdDfdN7%LKSDcZFwf9oF zwIcTs^5fn2eaGK@leK_*h%*u`k6VMGxM6YyPVn8b6 Date: Fri, 11 Jul 2025 14:37:40 -0700 Subject: [PATCH 04/26] major refactoring of model comparison scripts --- src/models/model_comparison_lr_sweep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/model_comparison_lr_sweep.py b/src/models/model_comparison_lr_sweep.py index 8983c91..b382867 100644 --- a/src/models/model_comparison_lr_sweep.py +++ b/src/models/model_comparison_lr_sweep.py @@ -229,7 +229,7 @@ def create_preprocessors(model_config, config): return preprocessors -def train_for_learning_rate( +def train_for_learning_rate( # pylint: disable=too-many-locals learning_rate, model_config, train_dataset, val_dataset, config ): """ From ef80a97ce9e4036659e3f01a0f1e6e1786ade5a1 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 14:43:02 -0700 Subject: [PATCH 05/26] pylint fix + license update --- LICENSE | 21 ---- LICENSES/MIT.txt | 18 ++++ media/CS231N Poster.png.license | 5 + requirements.txt | 2 +- results/results_metrics_finetune.json | 24 ----- results/results_metrics_lr_experiment.json | 117 --------------------- 6 files changed, 24 insertions(+), 163 deletions(-) delete mode 100644 LICENSE create mode 100644 LICENSES/MIT.txt create mode 100644 media/CS231N Poster.png.license delete mode 100644 results/results_metrics_finetune.json delete mode 100644 results/results_metrics_lr_experiment.json diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 97a7fea..0000000 --- a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2025 Stanford Daneshjou Lab - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/LICENSES/MIT.txt b/LICENSES/MIT.txt new file mode 100644 index 0000000..d817195 --- /dev/null +++ b/LICENSES/MIT.txt @@ -0,0 +1,18 @@ +MIT License + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the +following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial +portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT +LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO +EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/media/CS231N Poster.png.license b/media/CS231N Poster.png.license new file mode 100644 index 0000000..3cc951b --- /dev/null +++ b/media/CS231N Poster.png.license @@ -0,0 +1,5 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b93e56d..5c4a69e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -63,7 +63,7 @@ sniffio soupsieve stack-data thop -threadpoolctladpoolctl +threadpoolctl timm torch torchvisionvision diff --git a/results/results_metrics_finetune.json b/results/results_metrics_finetune.json deleted file mode 100644 index 95c1365..0000000 --- a/results/results_metrics_finetune.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "vit": { - "model_name": "vit", - "model_type": "vit", - "peak_memory_mb": 6197.1875, - "flops_giga": 0.855424512, - "train_time_seconds": 2316.893568277359, - "eval_time_seconds": 68.61979532241821, - "eval_metrics": { - "eval_loss": 0.5946938395500183, - "eval_accuracy": 0.69, - "eval_f1": 0.6846709388668498, - "eval_auc": 0.7788, - "eval_runtime": 68.6175, - "eval_samples_per_second": 2.915, - "eval_steps_per_second": 0.102, - "epoch": 3.0, - "model": "vit", - "phase": "finetune", - "gpu_memory_mb": 2365.1875, - "best_accuracy": 0.78 - } - } -} \ No newline at end of file diff --git a/results/results_metrics_lr_experiment.json b/results/results_metrics_lr_experiment.json deleted file mode 100644 index 7e97905..0000000 --- a/results/results_metrics_lr_experiment.json +++ /dev/null @@ -1,117 +0,0 @@ -{ - "1e-05": { - "learning_rate": 1e-05, - "model_name": "vit", - "model_type": "vit", - "peak_memory_mb": 8067.1875, - "flops_giga": 0.855424512, - "train_time_seconds": 71.61356973648071, - "eval_time_seconds": 18.6774001121521, - "eval_metrics": { - "eval_loss": 0.8273601531982422, - "eval_accuracy": 0.4, - "eval_f1": 0.4, - "eval_auc": 0.30999999999999994, - "eval_runtime": 18.6702, - "eval_samples_per_second": 1.071, - "eval_steps_per_second": 0.107, - "epoch": 1.0, - "model": "vit", - "phase": "finetune", - "gpu_memory_mb": 8067.1875, - "best_accuracy": 0.4 - } - }, - "5e-05": { - "learning_rate": 5e-05, - "model_name": "vit", - "model_type": "vit", - "peak_memory_mb": 8067.1875, - "flops_giga": 0.855424512, - "train_time_seconds": 66.13869047164917, - "eval_time_seconds": 19.406721115112305, - "eval_metrics": { - "eval_loss": 0.6468731164932251, - "eval_accuracy": 0.6, - "eval_f1": 0.5959595959595959, - "eval_auc": 0.7100000000000001, - "eval_runtime": 19.4014, - "eval_samples_per_second": 1.031, - "eval_steps_per_second": 0.103, - "epoch": 1.0, - "model": "vit", - "phase": "finetune", - "gpu_memory_mb": 8067.1875, - "best_accuracy": 0.6 - } - }, - "0.0001": { - "learning_rate": 0.0001, - "model_name": "vit", - "model_type": "vit", - "peak_memory_mb": 8067.1875, - "flops_giga": 0.855424512, - "train_time_seconds": 68.64359831809998, - "eval_time_seconds": 18.199121713638306, - "eval_metrics": { - "eval_loss": 0.7177037000656128, - "eval_accuracy": 0.55, - "eval_f1": 0.52, - "eval_auc": 0.72, - "eval_runtime": 18.1963, - "eval_samples_per_second": 1.099, - "eval_steps_per_second": 0.11, - "epoch": 1.0, - "model": "vit", - "phase": "finetune", - "gpu_memory_mb": 8067.1875, - "best_accuracy": 0.55 - } - }, - "0.0005": { - "learning_rate": 0.0005, - "model_name": "vit", - "model_type": "vit", - "peak_memory_mb": 8067.1875, - "flops_giga": 0.855424512, - "train_time_seconds": 67.80419445037842, - "eval_time_seconds": 19.256648302078247, - "eval_metrics": { - "eval_loss": 0.6929414868354797, - "eval_accuracy": 0.55, - "eval_f1": 0.43573667711598746, - "eval_auc": 0.56, - "eval_runtime": 19.2506, - "eval_samples_per_second": 1.039, - "eval_steps_per_second": 0.104, - "epoch": 1.0, - "model": "vit", - "phase": "finetune", - "gpu_memory_mb": 8067.1875, - "best_accuracy": 0.55 - } - }, - "0.001": { - "learning_rate": 0.001, - "model_name": "vit", - "model_type": "vit", - "peak_memory_mb": 8067.1875, - "flops_giga": 0.855424512, - "train_time_seconds": 68.06890201568604, - "eval_time_seconds": 18.359089851379395, - "eval_metrics": { - "eval_loss": 0.7085598707199097, - "eval_accuracy": 0.5, - "eval_f1": 0.3333333333333333, - "eval_auc": 0.5800000000000001, - "eval_runtime": 18.3557, - "eval_samples_per_second": 1.09, - "eval_steps_per_second": 0.109, - "epoch": 1.0, - "model": "vit", - "phase": "finetune", - "gpu_memory_mb": 8067.1875, - "best_accuracy": 0.5 - } - } -} \ No newline at end of file From 903551a933d6a169af8aec80a6dfc4c34bf3de1d Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 14:45:15 -0700 Subject: [PATCH 06/26] pylint fix + license update --- requirements.txt | 24 ++++++++++++------------ yamllint.license | 5 +++++ yamllint.txt | 14 ++++++++++++++ 3 files changed, 31 insertions(+), 12 deletions(-) create mode 100644 yamllint.license create mode 100644 yamllint.txt diff --git a/requirements.txt b/requirements.txt index 5c4a69e..c3c931f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ colorama comm contourpy cycler -datasets +datasets>=2.18.0 debugpy decorator distro @@ -35,10 +35,10 @@ loguru matplotlib matplotlib-inline nest_asyncio -numpy +numpy>=1.26.0 openai packaging -pandas +pandas>=2.2.0 parso pickleshare pillow @@ -46,16 +46,16 @@ platformdirs prompt_toolkit psutil pure_eval -pydantic -pydantic_core +pydantic>=2.6.0 +pydantic_core>=2.16.0 pynvml Pygments pyparsing python-dateutil python-dotenv requests -scikit-learn -scipy +scikit-learn>=1.3.0 +scipy>=1.11.0 seaborn setuptools six @@ -66,15 +66,15 @@ thop threadpoolctl timm torch -torchvisionvision +torchvision tqdm traitlets -tritonon -typing-inspectionspection +triton +typing-inspect typing_extensions tzdata -urllib33 +urllib3 wandb wcwidth wheel -zipp \ No newline at end of file +zipp diff --git a/yamllint.license b/yamllint.license new file mode 100644 index 0000000..3184dcc --- /dev/null +++ b/yamllint.license @@ -0,0 +1,5 @@ +# This source file is part of the ARPA-H CARE LLM project +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT \ No newline at end of file diff --git a/yamllint.txt b/yamllint.txt new file mode 100644 index 0000000..040176a --- /dev/null +++ b/yamllint.txt @@ -0,0 +1,14 @@ +--- +extends: default + +rules: + truthy: + level: warning + allowed-values: ["false", "true", "on", "off"] + document-start: + level: warning + line-length: + max: 180 + level: warning + document-start: + present: false \ No newline at end of file From fc54bd0faca773ccb2f4ad2d8befeff6a4fd9a52 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 14:49:13 -0700 Subject: [PATCH 07/26] Add .yamllint file --- yamllint.txt => .yamllint | 0 yamllint.license => .yamllint.license | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename yamllint.txt => .yamllint (100%) rename yamllint.license => .yamllint.license (100%) diff --git a/yamllint.txt b/.yamllint similarity index 100% rename from yamllint.txt rename to .yamllint diff --git a/yamllint.license b/.yamllint.license similarity index 100% rename from yamllint.license rename to .yamllint.license From 86da6094650f3a0404ef21955a707dfc025be747 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 15:41:41 -0700 Subject: [PATCH 08/26] pylint --- requirements.txt | 1 + src/models/model_comparison_baseline.py | 15 +++--- src/models/model_comparison_lr_sweep.py | 67 +++++------------------- src/models/utils/utils_classes.py | 6 +-- src/models/utils/utils_methods.py | 69 +++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 62 deletions(-) diff --git a/requirements.txt b/requirements.txt index c3c931f..e779364 100644 --- a/requirements.txt +++ b/requirements.txt @@ -69,6 +69,7 @@ torch torchvision tqdm traitlets +transformers triton typing-inspect typing_extensions diff --git a/src/models/model_comparison_baseline.py b/src/models/model_comparison_baseline.py index 9de0be4..4795202 100644 --- a/src/models/model_comparison_baseline.py +++ b/src/models/model_comparison_baseline.py @@ -39,13 +39,15 @@ import wandb import timm -from utils.constants import SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES -from utils.transforms import ( +from src.models.utils.constants import ( + SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES +) +from src.models.utils.transforms import ( JPEGCompressionTransform, GaussianBlurTransform, ) -from utils.util_classes import SimCLRForClassification, LossLoggerCallback -from utils.util_methods import get_gpu_memory, GPU_AVAILABLE, freeze_backbone -from transforms import ColorQuantizationTransform +from src.models.utils.utils_classes import SimCLRForClassification, LossLoggerCallback +from src.models.utils.utils_methods import get_gpu_memory, GPU_AVAILABLE, freeze_backbone +from src.models.utils.transforms import ColorQuantizationTransform # Compatibility for LANCZOS resampling try: @@ -285,9 +287,10 @@ def train_model( def get_trainer_callbacks(name): + """Get callbacks for the Trainer.""" return [ LossLoggerCallback( - log_dir=env_path("LOG_DIR", "./logs"), + log_dir=os.environ.get("LOG_DIR", "./logs"), phase="finetune", model_name=name, ), diff --git a/src/models/model_comparison_lr_sweep.py b/src/models/model_comparison_lr_sweep.py index b382867..d379e89 100644 --- a/src/models/model_comparison_lr_sweep.py +++ b/src/models/model_comparison_lr_sweep.py @@ -56,7 +56,6 @@ # Model Profiling & Vision Backbones import timm -from thop import profile # Local Application Imports @@ -78,7 +77,10 @@ compute_metrics, get_gpu_memory, freeze_backbone, - GPU_AVAILABLE + GPU_AVAILABLE, + get_flops, + check_disk_space, + save_model_and_preprocessor, ) # Set memory optimization @@ -244,11 +246,11 @@ def train_for_learning_rate( # pylint: disable=too-many-locals ) train_ds = get_transformed_datasets(train_dataset, preprocessors, config, typ) - val_ds = ISICDataset( + val_ds = ISICDataset( # pylint: disable=too-many-function-args val_dataset, preprocessors[typ], config["resolution"], - model_type=typ, + typ, ) model = get_model(typ, model_id, config) @@ -352,7 +354,7 @@ def get_model(typ, model_id, config): image_size=config["resolution"] ) if typ == SSL_MODEL: - backbone = timm.create_model( + backbone = timm.create_model( # pylint: disable=too-many-function-args SIMCLR_BACKBONE, pretrained=True, num_classes=0 @@ -379,7 +381,13 @@ def get_transformed_datasets(train_dataset, preprocessors, config, typ): def make_subset(indices_subset, transform=None): subset = Subset(train_dataset, indices_subset) transform_compose = transforms.Compose([transform]) if transform else None - return ISICDataset(subset, preprocessors[typ], config["resolution"], transform_compose, typ) + return ISICDataset( # pylint: disable=too-many-function-args + subset, + preprocessors[typ], + config["resolution"], + transform_compose, + typ + ) datasets = [] used_indices = set() @@ -395,53 +403,6 @@ def make_subset(indices_subset, transform=None): return ConcatDataset(datasets) - -def get_flops(model, resolution): - """ - Profile FLOPs for the given model and resolution. - Returns FLOPs in giga units, or -1 if profiling fails. - """ - try: - dummy_input = torch.randn(1, 3, resolution, resolution).to(next(model.parameters()).device) - flops, _ = profile(model, inputs=(dummy_input,)) - return flops / 1e9 - except Exception as e: # pylint: disable=broad-exception-caught - print(f"FLOP profiling failed: {e}") - return -1 - -def check_disk_space(min_gb=1): - """ - Checks if there is at least min_gb GB of free disk space. - Raises RuntimeError if not enough space. - """ - total, used, free = shutil.disk_usage("/") - print( - f"Disk space: Total={total // (2**30)} GB, " - f"Used={used // (2**30)} GB, " - f"Free={free // (2**30)} GB" - ) - if free < min_gb * (2**30): - raise RuntimeError(f"Not enough disk space. Please free up at least {min_gb}GB.") - -def save_model_and_preprocessor(model, preprocessors, typ, name, learning_rate): - """ - Saves the trained model and preprocessor to disk. - """ - model_dir = os.path.join(env_path("MODEL_DIR", "."), f"{name}_lr_{learning_rate}") - os.makedirs(model_dir, exist_ok=True) - if typ in HF_MODELS: - model.save_pretrained(model_dir) - preprocessors[typ].save_pretrained(model_dir) - elif typ == SSL_MODEL: - torch.save(model.state_dict(), os.path.join(model_dir, "pytorch_model.bin")) - with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as f: - json.dump({ - "model_type": SSL_MODEL, - "backbone": "resnet50", - "num_classes": NUM_FILTERED_CLASSES, - }, f) - return model_dir - def log_wandb_artifact(model_dir, name, learning_rate): """ Logs the saved model directory as a wandb artifact. diff --git a/src/models/utils/utils_classes.py b/src/models/utils/utils_classes.py index d5c225d..069d2c0 100644 --- a/src/models/utils/utils_classes.py +++ b/src/models/utils/utils_classes.py @@ -112,7 +112,7 @@ def __getitem__(self, idx): return {"pixel_values": pixel_values, "labels": label} -class SimCLRForClassification(nn.Module): +class SimCLRForClassification(nn.Module): # pylint: disable=too-few-public-methods """ SimCLR-based classification model. """ @@ -149,7 +149,7 @@ def forward(self, pixel_values, labels=None): ) -class LossLoggerCallback(TrainerCallback): +class LossLoggerCallback(TrainerCallback): # pylint: disable=too-few-public-methods """ Logs each training step's loss and other metrics to a structured JSON Lines file. """ @@ -168,7 +168,7 @@ def __init__(self, log_dir: str, phase: str, model_name: str): log_dir, f"{model_name}_{phase}_log.jsonl" ) - def on_log(self, args, state, control, logs=None, **kwargs): + def on_log(self, _args, state, _control, logs=None, **_kwargs): """ Log metrics to a JSON Lines file. diff --git a/src/models/utils/utils_methods.py b/src/models/utils/utils_methods.py index 0d33427..dec220e 100644 --- a/src/models/utils/utils_methods.py +++ b/src/models/utils/utils_methods.py @@ -9,8 +9,10 @@ freezing model backbones, and handling environment paths.""" import os +import shutil import json import numpy as np +from thop import profile import torch import matplotlib.pyplot as plt import seaborn as sns @@ -94,3 +96,70 @@ def freeze_backbone(model, model_type): param.requires_grad = True else: raise ValueError(f"Unsupported model_type: {model_type}") + +def get_flops(model, resolution): + """ + Profile FLOPs for the given model and resolution. + + Args: + model: The model to profile. + resolution (int): Input image resolution. + + Returns: + float: FLOPs in giga units, or -1 if profiling fails. + """ + try: + dummy_input = torch.randn(1, 3, resolution, resolution).to(next(model.parameters()).device) + flops, _ = profile(model, inputs=(dummy_input,)) + return flops / 1e9 + except Exception as e: # pylint: disable=broad-exception-caught + print(f"FLOP profiling failed: {e}") + return -1 + +def check_disk_space(min_gb=1): + """ + Checks if there is at least min_gb GB of free disk space. + + Args: + min_gb (int): Minimum required free disk space in GB. + + Raises: + RuntimeError: If not enough disk space is available. + """ + total, used, free = shutil.disk_usage("/") + print( + f"Disk space: Total={total // (2**30)} GB, " + f"Used={used // (2**30)} GB, " + f"Free={free // (2**30)} GB" + ) + if free < min_gb * (2**30): + raise RuntimeError(f"Not enough disk space. Please free up at least {min_gb}GB.") + +def save_model_and_preprocessor(model, preprocessors, typ, name, learning_rate): + """ + Saves the trained model and preprocessor to disk. + + Args: + model: The trained model. + preprocessors (dict): Preprocessors for each model type. + typ (str): Model type. + name (str): Model name. + learning_rate (float): Learning rate used for training. + + Returns: + str: Path to the saved model directory. + """ + model_dir = os.path.join(env_path("MODEL_DIR", "."), f"{name}_lr_{learning_rate}") + os.makedirs(model_dir, exist_ok=True) + if typ in HF_MODELS: + model.save_pretrained(model_dir) + preprocessors[typ].save_pretrained(model_dir) + elif typ == "simclr": + torch.save(model.state_dict(), os.path.join(model_dir, "pytorch_model.bin")) + with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as f: + json.dump({ + "model_type": "simclr", + "backbone": "resnet50", + "num_classes": model.classifier.out_features, + }, f) + return model_dir From f0839d25b0c80f9db8bb553682fa85b3b5440952 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 17:54:49 -0700 Subject: [PATCH 09/26] refactor and restructure --- .DS_Store | Bin 0 -> 8196 bytes {scripts => jobs}/download_unpack_isic2019.sh | 0 jobs/run.sh | 29 --- jobs/run_compare_baseline.sh | 38 +++ jobs/run_compare_lr_sweep.sh | 38 +++ jobs/run_lr.sh | 29 --- jobs/run_models.sh | 29 --- {scripts => jobs}/submit_from_config.sh | 0 .../visualize_isic_results.py | 0 src/.DS_Store | Bin 0 -> 6148 bytes src/compressed_perception/.DS_Store | Bin 0 -> 6148 bytes src/compressed_perception/models/.DS_Store | Bin 0 -> 6148 bytes .../models/__init__.py | 0 .../models/comparison}/__init__.py | 0 .../models/comparison/compare_baseline.py} | 225 +++++++----------- .../models/comparison/compare_lr_sweep.py} | 152 ++---------- .../models/training/.DS_Store | Bin 0 -> 6148 bytes .../models/training/__init__.py | 0 .../models/training}/constants.py | 0 .../models/training}/utils_classes.py | 54 ++++- .../models/training}/utils_methods.py | 21 +- src/compressed_perception/modules/__init__.py | 0 .../modules/data_preparation/__init__.py | 0 .../modules/data_preparation/preparation.py | 101 ++++++++ .../modules/data_transformation/__init__.py | 0 .../image_transformation.py} | 0 26 files changed, 350 insertions(+), 366 deletions(-) create mode 100644 .DS_Store rename {scripts => jobs}/download_unpack_isic2019.sh (100%) delete mode 100644 jobs/run.sh create mode 100644 jobs/run_compare_baseline.sh create mode 100644 jobs/run_compare_lr_sweep.sh delete mode 100644 jobs/run_lr.sh delete mode 100644 jobs/run_models.sh rename {scripts => jobs}/submit_from_config.sh (100%) rename src/evaluation/evaluate_isic_results.py => scripts/visualize_isic_results.py (100%) create mode 100644 src/.DS_Store create mode 100644 src/compressed_perception/.DS_Store create mode 100644 src/compressed_perception/models/.DS_Store rename src/{ => compressed_perception}/models/__init__.py (100%) rename src/{models/utils => compressed_perception/models/comparison}/__init__.py (100%) rename src/{models/model_comparison_baseline.py => compressed_perception/models/comparison/compare_baseline.py} (55%) rename src/{models/model_comparison_lr_sweep.py => compressed_perception/models/comparison/compare_lr_sweep.py} (72%) create mode 100644 src/compressed_perception/models/training/.DS_Store create mode 100644 src/compressed_perception/models/training/__init__.py rename src/{models/utils => compressed_perception/models/training}/constants.py (100%) rename src/{models/utils => compressed_perception/models/training}/utils_classes.py (75%) rename src/{models/utils => compressed_perception/models/training}/utils_methods.py (87%) create mode 100644 src/compressed_perception/modules/__init__.py create mode 100644 src/compressed_perception/modules/data_preparation/__init__.py create mode 100644 src/compressed_perception/modules/data_preparation/preparation.py create mode 100644 src/compressed_perception/modules/data_transformation/__init__.py rename src/{models/utils/transforms.py => compressed_perception/modules/data_transformation/image_transformation.py} (100%) diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..4f9f60cbb31db681b5b5082a7386b91895c2a534 GIT binary patch literal 8196 zcmeHML2DC17=4q*WdJe zgZ>VGN5QNAz^g}nGqVZFwxOPa;7pl$lbP>*-^`oMZoVx5(df6UzzP71Rbpw0%^5|` z%X+Hh(veFdhqrktQ0H4{iSTmmcde*c?0i(cw zsersc*sK!kDGoKNTL&AJ0ub{YR)u}$0g7X%SWj`NQBmnrXAdH!iZsO#n$CG!niK0O z4mD~zgr-BJkwu!J2o;??TbV=DHJa8aU=%p5fXv+sXdyy~*OU2sCQM%!i_ z2iqC1Uua#QSXU9dw7iEcJRp{Ffc7*){37GS1f#?i`*?>w_bqz|tRZ(2lbzr_l4*W! zm|5S-d3^c_=)0x-#_*^8VV}3DjfwSKQ~qMECs3y9-Fjud5S%CLNg{Hh+Te+u^-ZOR~eq>@Sk|pRFtmNa@v|#2*k;R@(Z-?p_Q?cDxP8G4dSjBsH^sy*9|&xQmK3w zxrO|N#e!9^maLb~el&2BX3~!v&E9kQ>2VYV-l*U7o^ublyMeEAVTX%Qe zx4}Ueed49cBGRWNmi6ax`aR8}k2;=7 zc`nXiRx><#>5O^i`9D4r(=iInRe=Q^ZAIq)+q2*Q&vi{Em{GteFkb~ke#_ZvFy`pj zLN>h2wQbf{tg=YEp+-evqtbDrO2>)ge;BfF(^PWmDGoJqgvGoFP%@atC@@zA{s0M; BJZt~} literal 0 HcmV?d00001 diff --git a/scripts/download_unpack_isic2019.sh b/jobs/download_unpack_isic2019.sh similarity index 100% rename from scripts/download_unpack_isic2019.sh rename to jobs/download_unpack_isic2019.sh diff --git a/jobs/run.sh b/jobs/run.sh deleted file mode 100644 index 598137f..0000000 --- a/jobs/run.sh +++ /dev/null @@ -1,29 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -#!/bin/bash - -#SBATCH --job-name=231n_job -#SBATCH --time=2-23:59:00 -#SBATCH --output=job_output_%j.out -#SBATCH --gres=gpu:1 -#SBATCH -p roxanad -#SBATCH --mem=128G - - -# Loading python version -ml python/3.12 - -# Creating venv and installing requirements -python3 -m venv venv -source venv/bin/activate -pip install -r /home/groups/roxanad/eric/CS231N/requirements.txt - -# WANDB Key -export WANDB_API_KEY="7ab80eeb87ef06298c6bca1258208b1739ad32fe" - -# Running the script -python3 src/models/model_comparison.py \ No newline at end of file diff --git a/jobs/run_compare_baseline.sh b/jobs/run_compare_baseline.sh new file mode 100644 index 0000000..8893737 --- /dev/null +++ b/jobs/run_compare_baseline.sh @@ -0,0 +1,38 @@ +# This source file is part of the ARPA-H CARE LLM project +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +#!/bin/bash +#SBATCH --job-name=compare_baseline +#SBATCH --output=logs/compare_baseline_%j.out +#SBATCH --error=logs/compare_baseline_%j.err +#SBATCH --partition=roxanad +#SBATCH --gres=gpu:1 +#SBATCH --time=12:00:00 +#SBATCH --mem=32G +#SBATCH --cpus-per-task=4 +#SBATCH --output=logs/%x_%j.out +#SBATCH --error=logs/%x_%j.err + +# Load Python module +ml load python/3.12.1 + +# Setup virtual environment if it doesn't exist +if [ ! -d ".venv" ]; then + python3.12 -m pip install uv + uv venv + source .venv/bin/activate + + # Install from requirements.txt + uv pip install -r requirements.txt +else + source .venv/bin/activate +fi + +# WANDB Key +export WANDB_API_KEY="7ab80eeb87ef06298c6bca1258208b1739ad32fe" + +# Running the script +python src/models/comparison/compare_baseline.py --resolution 224 --batch_size 256 --num_train_images 25000 --num_epochs 10 --eval_steps 10 \ No newline at end of file diff --git a/jobs/run_compare_lr_sweep.sh b/jobs/run_compare_lr_sweep.sh new file mode 100644 index 0000000..fd56acd --- /dev/null +++ b/jobs/run_compare_lr_sweep.sh @@ -0,0 +1,38 @@ +# This source file is part of the ARPA-H CARE LLM project +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +#!/bin/bash +#SBATCH --job-name=compare_lr_sweep +#SBATCH --output=logs/compare_lr_sweep_%j.out +#SBATCH --error=logs/compare_lr_sweep_%j.err +#SBATCH --partition=roxanad +#SBATCH --gres=gpu:1 +#SBATCH --time=12:00:00 +#SBATCH --mem=32G +#SBATCH --cpus-per-task=4 +#SBATCH --output=logs/%x_%j.out +#SBATCH --error=logs/%x_%j.err + +# Load Python module +ml load python/3.12.1 + +# Setup virtual environment if it doesn't exist +if [ ! -d ".venv" ]; then + python3.12 -m pip install uv + uv venv + source .venv/bin/activate + + # Install from requirements.txt + uv pip install -r requirements.txt +else + source .venv/bin/activate +fi + +# WANDB Key +export WANDB_API_KEY="7ab80eeb87ef06298c6bca1258208b1739ad32fe" + +# Running the script +python src/models/comparison/compare_lr_sweep.py --resolution 224 --batch_size 256 --num_train_images 25000 --num_epochs 6 --eval_steps 10 --learning_rate 1e-4 \ No newline at end of file diff --git a/jobs/run_lr.sh b/jobs/run_lr.sh deleted file mode 100644 index a8951f7..0000000 --- a/jobs/run_lr.sh +++ /dev/null @@ -1,29 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -#!/bin/bash - -#SBATCH --job-name=231n_job -#SBATCH --time=2-23:59:00 -#SBATCH --output=job_output_%j.out -#SBATCH --gres=gpu:1 -#SBATCH -p roxanad -#SBATCH --mem=128G - - -# Loading python version -ml python/3.12 - -# Creating venv and installing requirements -python3 -m venv venv -source venv/bin/activate -pip install -r /home/groups/roxanad/eric/CS231N/requirements.txt - -# WANDB Key -export WANDB_API_KEY="7ab80eeb87ef06298c6bca1258208b1739ad32fe" - -# Running the script -python src/models/model_comparison_lr.py --resolution 224 --batch_size 256 --num_train_images 25000 --num_epochs 6 --eval_steps 10 --learning_rate 1e-4 \ No newline at end of file diff --git a/jobs/run_models.sh b/jobs/run_models.sh deleted file mode 100644 index 9cbb3f1..0000000 --- a/jobs/run_models.sh +++ /dev/null @@ -1,29 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -#!/bin/bash - -#SBATCH --job-name=231n_job -#SBATCH --time=2-23:59:00 -#SBATCH --output=job_output_%j.out -#SBATCH --gres=gpu:1 -#SBATCH -p roxanad -#SBATCH --mem=128G - - -# Loading python version -ml python/3.12 - -# Creating venv and installing requirements -python3 -m venv venv -source venv/bin/activate -pip install -r /home/groups/roxanad/eric/CS231N/requirements.txt - -# WANDB Key -export WANDB_API_KEY="7ab80eeb87ef06298c6bca1258208b1739ad32fe" - -# Running the script -python src/models/model_comparison_models.py --resolution 224 --batch_size 256 --num_train_images 25000 --num_epochs 10 --eval_steps 10 \ No newline at end of file diff --git a/scripts/submit_from_config.sh b/jobs/submit_from_config.sh similarity index 100% rename from scripts/submit_from_config.sh rename to jobs/submit_from_config.sh diff --git a/src/evaluation/evaluate_isic_results.py b/scripts/visualize_isic_results.py similarity index 100% rename from src/evaluation/evaluate_isic_results.py rename to scripts/visualize_isic_results.py diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e55dc795af18da210d844a92a4ec784082304d6f GIT binary patch literal 6148 zcmeHK!EVz)5S>i}brd1xKqY!xsn@791tTFYR%j0#xYSiSv?#>21s3jZWIIGCisXCy zLwpAkzeC^bZYtWqfgBL3cC6Vq-kmp&-#T6|5sBe69T4@1D1bAzI%s}lJkGviE!Wcq z3R5GennpNwX|j^-JN`!nc<#Dp&T>-p>->FvQ|R-g&jud zQbrYyj8asZ085Nzh-y^Sr0!JuzJ8@-R!!V*`j$+JysG#1eu&ma`@zG_U^D0j@8eH; z7T4o?S`EkLDW4r{T_nqRoSfv@Y}9-FTB~}TtE@DHJk5}D_AXbco(=U>rIq2vb|B~k zol$RlJ|BeP^S%rR;i50+FZaW~eDQ35vFHR_J5OI7jXoC_N?#gzNZ>DF?2*GITp{=i z@xH{RDzy3t*0SQVl1HzASHLT>ogHWo1g6P60JRN=N5!qU+%U0&p{v1sWe+~z~LD+{+n5$5WsFHJg$ z$f9q(0$zdJ3T)WZ9|SFf#bYEAUShxCTr@cvJuY literal 0 HcmV?d00001 diff --git a/src/compressed_perception/.DS_Store b/src/compressed_perception/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..dd228a931d57c45f936321ec25db8898f886506e GIT binary patch literal 6148 zcmeH~K}*9h6vtn5Zk;0RprE&a*MUy0AYRIx2M=DX=s{(!wrH_7vTp7$2EFSS@^c7& z9epo(g5C6}sN})RKTY0ilOH5$06;YRVFREB01i57X%>r5jQYtZY>HdTs6xLZfowlf zmqDU}iD=tk30MNZjR39PGTefK5F#k9U)mH7(7jh}dgz7NV36JMT7L@ek~r(Ft-Xk; z>C#NO;#8bD=iI+m1Hael^|NNDdrD77N+rR2y%QYA(V$&h*i~t-6Q@yEC&XcdA!jFX z8md84_0urZxxVRes!p|CTOJM@uDe;6Zo?hb<#2o5t;?;A_0g#6EUv8X9JX(g+f+U1 zVS~U8O4~f^yF6m0FsXaLn(!ruK=heY}ST#Iu|KNvU;q!QS3nzCKc7B z3cg|plaB3@<2=^fs7VLGmk+^T7JNey`s?_9sl!2dMr~;cSOP@?)8?^8=l|Z<_y1y& zeX;~Bfqz9nlv;kPiBsaUbz*XK)=KmXbTW#|jp`IE_&C-TI*OOjwPBy63aZDN8^snh O`y-$+*uoO{Q37wVX{43_ literal 0 HcmV?d00001 diff --git a/src/compressed_perception/models/.DS_Store b/src/compressed_perception/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..dc15ec31e3abe57cab7d97a046e4f15385f6288b GIT binary patch literal 6148 zcmeHK&x_MQ6n?W??R1OKgMxyGfY+ki)d-81P0Gk8mhXQNx4 z!>Uz|+K9{EY{ccrXC>BSxDj;xhy8Wzlct(k4!wIR`+NS_d;D;3xg79Y_Z~cX5q&C7g#3z4Ll{!HmKzpVaEeN2 zW1quP6jB^xY~{DIcX&k@uzwsMUXk^>WjY2N1DD5u$`2k?VPLT}s8a_D{R9B^(5(e^ z^|7FQJd1(F)*wb8%(w!LE3;1wX57*5S-ZeuYtXn8vyTsEXJ+<=!u0Hj?`d;lfk9U~ z1{?$H3~Za`NZtR3zkmO)JGqu)z%lS&F~E8U;lUJJvUh7|bLy`3pdX-0lwE63yP&Yw hv974A_%>7v#ynL41{Pa`=z+Ko0ZoG|90Qliz)u!!p|Jn} literal 0 HcmV?d00001 diff --git a/src/models/__init__.py b/src/compressed_perception/models/__init__.py similarity index 100% rename from src/models/__init__.py rename to src/compressed_perception/models/__init__.py diff --git a/src/models/utils/__init__.py b/src/compressed_perception/models/comparison/__init__.py similarity index 100% rename from src/models/utils/__init__.py rename to src/compressed_perception/models/comparison/__init__.py diff --git a/src/models/model_comparison_baseline.py b/src/compressed_perception/models/comparison/compare_baseline.py similarity index 55% rename from src/models/model_comparison_baseline.py rename to src/compressed_perception/models/comparison/compare_baseline.py index 4795202..39425e9 100644 --- a/src/models/model_comparison_baseline.py +++ b/src/compressed_perception/models/comparison/compare_baseline.py @@ -22,11 +22,8 @@ # Environment Setup import os -import io -import numpy as np +import argparse from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms from transformers import ( AutoImageProcessor, AutoModelForImageClassification, @@ -39,15 +36,19 @@ import wandb import timm -from src.models.utils.constants import ( + +from src.compressed_perception.models.training.constants import ( SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES ) -from src.models.utils.transforms import ( - JPEGCompressionTransform, GaussianBlurTransform, +from src.compressed_perception.models.training.utils_classes import SimCLRForClassification +from src.compressed_perception.models.training.utils_methods import get_gpu_memory, GPU_AVAILABLE, freeze_backbone +from src.compressed_perception.modules.data_preparation.preparation import ( + filter_and_cast_dataset, + balance_dataset, + split_dataset, + get_default_transforms, + prepare_datasets, ) -from src.models.utils.utils_classes import SimCLRForClassification, LossLoggerCallback -from src.models.utils.utils_methods import get_gpu_memory, GPU_AVAILABLE, freeze_backbone -from src.models.utils.transforms import ColorQuantizationTransform # Compatibility for LANCZOS resampling try: @@ -64,31 +65,6 @@ GPU_AVAILABLE = False print("pynvml not installed, GPU memory monitoring disabled.") -class WandbCallback: - """Callback for logging to Weights & Biases.""" - def __init__(self, model_name, phase): - self.model_name = model_name - self.phase = phase - self.best_accuracy = 0.0 - - def on_log(self, _args, _state, _control, logs=None, **_kwargs): - """ - Log metrics to Weights & Biases.""" - if logs is not None: - logs["model"] = self.model_name - logs["phase"] = self.phase - if GPU_AVAILABLE: - logs["gpu_memory_mb"] = get_gpu_memory() - wandb.log(logs) - - def on_evaluate(self, _args, _state, _control, metrics=None, **_kwargs): - """Log evaluation metrics to Weights & Biases.""" - if metrics is not None: - if "eval_accuracy" in metrics: - self.best_accuracy = max(self.best_accuracy, metrics["eval_accuracy"]) - metrics["best_accuracy"] = self.best_accuracy - wandb.log(metrics) - def initialize_model_and_preprocessor(model_info, resolution): """ @@ -148,76 +124,6 @@ def initialize_model_and_preprocessor(model_info, resolution): return model, preprocessor -def balance_dataset(dataset, num_train_images, filtered_classes): - """ - Balance the dataset by sampling equal images per class. - """ - print("Balancing dataset...") - class_counts = {label: 0 for label in filtered_classes} - for label in dataset["label"]: - class_counts[str(label)] += 1 - - min_class_size = min(class_counts.values()) - images_per_class = min(num_train_images // len(filtered_classes), min_class_size) - - np.random.seed(42) - balanced_indices = [] - for label in filtered_classes: - class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] - sampled_indices = np.random.choice(class_indices, images_per_class, replace=False) - balanced_indices.extend(sampled_indices) - - np.random.shuffle(balanced_indices) - return dataset.select(balanced_indices) - - -def prepare_datasets(dataset, _preprocessor, resolution, apply_transforms=False): - """ - Prepare train and validation datasets. - """ - train_size = int(0.8 * len(dataset)) - val_size = len(dataset) - train_size - train_dataset = dataset.select(range(train_size)) - val_dataset = dataset.select(range(train_size, train_size + val_size)) - - transform = transforms.Compose([ - transforms.Resize((resolution, resolution)), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ]) - - if apply_transforms: - transform = transforms.Compose([ - transform, - JPEGCompressionTransform(quality=75), - GaussianBlurTransform(p=0.5), - ColorQuantizationTransform(p=0.5), - ]) - - class TorchDataset(Dataset): - """ - Custom dataset class for Hugging Face datasets. - """ - def __init__(self, hf_dataset, transform): - self.hf_dataset = hf_dataset - self.transform = transform - - def __len__(self): - return len(self.hf_dataset) - - def __getitem__(self, idx): - item = self.hf_dataset[idx] - image = Image.open(io.BytesIO(item["image"])).convert("RGB") - if self.transform: - image = self.transform(image) - label = int(item["label"]) - return {"pixel_values": image, "labels": label} - - train_ds = TorchDataset(train_dataset, transform) - val_ds = TorchDataset(val_dataset, transform) - return train_ds, val_ds - - def train_model( model, train_ds, @@ -286,19 +192,8 @@ def train_model( return eval_results -def get_trainer_callbacks(name): - """Get callbacks for the Trainer.""" - return [ - LossLoggerCallback( - log_dir=os.environ.get("LOG_DIR", "./logs"), - phase="finetune", - model_name=name, - ), - WandbCallback(name, "finetune"), - ] - -def main(config=None): +def main(config=None, dataset=None): """ Main pipeline for model comparison. """ @@ -317,27 +212,43 @@ def main(config=None): wandb_config = config.copy() wandb_config["weight_decay"] = 0.01 + MODEL_CONFIG_KEYS = { + "name": "name", + "model_id": "model_id", + "type": "type", + "config": "config", + } + models = [ - {"name": "vit", "model_id": "google/vit-base-patch16-224", "type": "vit", "config": { - "image_size": config["resolution"], - "num_labels": NUM_FILTERED_CLASSES, - "ignore_mismatched_sizes": True - }}, + { + MODEL_CONFIG_KEYS["name"]: "vit", + MODEL_CONFIG_KEYS["model_id"]: "google/vit-base-patch16-224", + MODEL_CONFIG_KEYS["type"]: "vit", + MODEL_CONFIG_KEYS["config"]: { + "image_size": config["resolution"], + "num_labels": NUM_FILTERED_CLASSES, + "ignore_mismatched_sizes": True + } + }, ] - dataset = load_dataset( - "MKZuziak/ISIC_2019_224", - cache_dir=os.environ["HF_DATASETS_CACHE"], - split="train", - ) + if dataset is None: + raise ValueError("Dataset must be provided via `dataset` argument or CLI.") + # Use preparation.py functions for filtering, balancing, and splitting + dataset = filter_and_cast_dataset(dataset, FILTERED_CLASSES, NUM_FILTERED_CLASSES) dataset = balance_dataset(dataset, config["num_train_images"], FILTERED_CLASSES) + splits = split_dataset(dataset, test_size=0.2, stratify_by_column="label", seed=42) + + # Get transforms + transform = get_default_transforms(config["resolution"], apply_transforms=True) + + # Prepare PyTorch datasets + train_ds, val_ds = prepare_datasets(splits["train"], transform, split_ratio=1.0) + val_ds, _ = prepare_datasets(splits["test"], transform, split_ratio=1.0) for model_info in models: - model, preprocessor = initialize_model_and_preprocessor(model_info, config["resolution"]) - train_ds, val_ds = prepare_datasets( - dataset, preprocessor, config["resolution"], apply_transforms=True - ) + model, _preprocessor = initialize_model_and_preprocessor(model_info, config["resolution"]) train_config = { "model_name": model_info["name"], @@ -355,17 +266,51 @@ def main(config=None): ) print(f"Results for {model_info['name']}: {results}") + METRIC_KEYS = { + "learning_rate": "learning_rate", + "model_name": "model_name", + "model_type": "model_type", + "peak_memory_mb": "peak_memory_mb", + "flops_giga": "flops_giga", + "train_time_seconds": "train_time_seconds", + "eval_time_seconds": "eval_time_seconds", + "eval_metrics": "eval_metrics", + } + metrics = { - "learning_rate": config["learning_rate"], - "model_name": model_info["name"], - "model_type": model_info["type"], - "peak_memory_mb": get_gpu_memory(), - "flops_giga": None, # Placeholder for FLOPs, calculate if needed - "train_time_seconds": None, # Placeholder for training time, calculate if needed - "eval_time_seconds": None, # Placeholder for evaluation time, calculate if needed - "eval_metrics": results, + METRIC_KEYS["learning_rate"]: config["learning_rate"], + METRIC_KEYS["model_name"]: model_info[MODEL_CONFIG_KEYS["name"]], + METRIC_KEYS["model_type"]: model_info[MODEL_CONFIG_KEYS["type"]], + METRIC_KEYS["peak_memory_mb"]: get_gpu_memory(), + METRIC_KEYS["flops_giga"]: None, + METRIC_KEYS["train_time_seconds"]: None, + METRIC_KEYS["eval_time_seconds"]: None, + METRIC_KEYS["eval_metrics"]: results, } wandb.log({"metrics": metrics}) + if __name__ == "__main__": - main() + parser = argparse.ArgumentParser(description="Baseline model comparison for image classification.") + parser.add_argument( + "--path_to_dataset", + type=str, + default=None, + help="Path to local dataset directory. If not provided, loads from Hugging Face.", + ) + args = parser.parse_args() + + dataset = None + if args.path_to_dataset: + dataset = load_dataset("imagefolder", data_dir=args.path_to_dataset, split="train") + else: + try: + dataset = load_dataset( + "MKZuziak/ISIC_2019_224", + cache_dir=os.environ["HF_DATASETS_CACHE"], + split="train", + ) + except Exception as e: + raise ValueError("No dataset provided and Hugging Face dataset failed to load.") from e + + main(dataset=dataset) \ No newline at end of file diff --git a/src/models/model_comparison_lr_sweep.py b/src/compressed_perception/models/comparison/compare_lr_sweep.py similarity index 72% rename from src/models/model_comparison_lr_sweep.py rename to src/compressed_perception/models/comparison/compare_lr_sweep.py index d379e89..719b66b 100644 --- a/src/models/model_comparison_lr_sweep.py +++ b/src/compressed_perception/models/comparison/compare_lr_sweep.py @@ -44,12 +44,11 @@ AutoImageProcessor, AutoModelForImageClassification, Trainer, - TrainerCallback, TrainingArguments, ViTFeatureExtractor, ViTForImageClassification, ) -from datasets import load_dataset, ClassLabel +from datasets import load_dataset # Weights & Biases import wandb @@ -59,22 +58,25 @@ # Local Application Imports -from src.models.utils.constants import ( +from src.compressed_perception.models.training.constants import ( HF_MODELS, SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES ) -from src.models.utils.transforms import ( +from src.compressed_perception.modules.data_transformation.image_transformation import ( JPEGCompressionTransform, GaussianBlurTransform, ColorQuantizationTransform, ) -from src.models.utils.utils_classes import ( +from src.compressed_perception.models.training.utils_classes import ( ISICDataset, SimCLRForClassification, - LossLoggerCallback + WandbCallback, + LossLoggerCallback, + get_trainer_callbacks ) -from src.models.utils.utils_methods import ( +from src.compressed_perception.models.training.utils_methods import ( env_path, compute_metrics, + cleanup_model_dirs, get_gpu_memory, freeze_backbone, GPU_AVAILABLE, @@ -83,6 +85,9 @@ save_model_and_preprocessor, ) +from src.compressed_perception.modules import ( + filter_and_cast_dataset, balance_dataset, split_dataset, get_default_transforms, prepare_datasets +) # Set memory optimization os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" @@ -95,104 +100,6 @@ ) os.environ["HF_HOME"] = os.getenv("HF_HOME", "~/.cache/huggingface") -class WandbCallback(TrainerCallback): - """ - Custom callback for logging metrics and evaluation results to Weights & Biases. - Tracks best accuracy and GPU memory usage if available. - """ - def __init__(self, model_name, phase): - """ - Args: - model_name (str): Name of the model. - phase (str): Training phase (e.g., 'finetune'). - """ - self.model_name = model_name - self.phase = phase - self.best_accuracy = 0.0 - - def on_log(self, _args, _state, _control, logs=None, **_kwargs): - """ - Log metrics to Weights & Biases. - - Args: - _args: Trainer arguments (unused). - _state: Trainer state (unused). - _control: Trainer control (unused). - logs (dict): Metrics to log. - **_kwargs: Additional keyword arguments (unused). - """ - if logs is not None: - logs["model"] = self.model_name - logs["phase"] = self.phase - if GPU_AVAILABLE: - logs["gpu_memory_mb"] = get_gpu_memory() - wandb.log(logs) - - def on_evaluate(self, _args, _state, _control, metrics=None, **_kwargs): - """ - Log evaluation metrics to Weights & Biases. - - Args: - _args: Trainer arguments (unused). - _state: Trainer state (unused). - _control: Trainer control (unused). - metrics (dict): Evaluation metrics. - **_kwargs: Additional keyword arguments (unused). - """ - if metrics is not None: - if "eval_accuracy" in metrics: - self.best_accuracy = max(self.best_accuracy, metrics["eval_accuracy"]) - metrics["best_accuracy"] = self.best_accuracy - wandb.log(metrics) - -def prepare_balanced_datasets(dataset, config): - """ - Filter, balance, and split the dataset into train and validation sets. - """ - # Get indices of images with desired labels - filtered_indices = [ - i for i, label in enumerate(dataset["label"]) - if str(label) in FILTERED_CLASSES # Convert to string for comparison - ] - - # Select only those indices - dataset = dataset.select(filtered_indices) - print(f"Number of images after filtering for classes {FILTERED_CLASSES}: {len(dataset)}") - dataset = dataset.cast_column("label", ClassLabel(num_classes=NUM_FILTERED_CLASSES)) - - # Get class counts and balance dataset - optimized version - print("Balancing dataset...") - # Get counts for each class - class_counts = {label: 0 for label in FILTERED_CLASSES} - for label in dataset["label"]: - class_counts[str(label)] += 1 # Convert to string for dictionary key - - print(f"Class counts: {class_counts}") # Debug print to verify counts - - # Calculate how many images to use per class - min_class_size = min(class_counts.values()) - images_per_class = min(config["num_train_images"] // 2, min_class_size) - - # Sample indices for each class - np.random.seed(42) - balanced_indices = [] - for label in FILTERED_CLASSES: - class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] - print(f"Found {len(class_indices)} images for class {label}") # Debug print - sampled_indices = np.random.choice(class_indices, images_per_class, replace=False) - balanced_indices.extend(sampled_indices) - - np.random.shuffle(balanced_indices) - balanced_dataset = dataset.select(balanced_indices) - - # Split into train and validation - full_dataset = balanced_dataset.train_test_split( - test_size=0.2, stratify_by_column="label", seed=42 - ) - - train_dataset, val_dataset = full_dataset["train"], full_dataset["test"] - - return train_dataset, val_dataset def create_preprocessors(model_config, config): """ @@ -443,34 +350,6 @@ def get_training_args(name, learning_rate, config): save_only_model=True, ) -def cleanup_model_dirs(name, learning_rate): - """ - Removes and recreates model/log directories for the current run. - """ - model_dirs = [ - os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}_lr_{learning_rate}"), - os.path.join(env_path("MODEL_DIR", "."), f"{name}_lr_{learning_rate}"), - os.path.join(env_path("LOG_DIR", "."), f"{name}_lr_{learning_rate}"), - ] - for dir_path in model_dirs: - if os.path.exists(dir_path): - print(f"Cleaning up directory: {dir_path}") - shutil.rmtree(dir_path) - os.makedirs(dir_path, exist_ok=True) - -def get_trainer_callbacks(name): - """ - Returns a list of Trainer callbacks for logging and monitoring. - """ - return [ - LossLoggerCallback( - log_dir=env_path("LOG_DIR", "./logs"), - phase="finetune", - model_name=name, - ), - WandbCallback(name, "finetune"), - ] - def main(config=None): """ Main function for running learning rate sweep experiments on image classification models. @@ -509,7 +388,12 @@ def main(config=None): split="train", ) - train_dataset, val_dataset = prepare_balanced_datasets(dataset, config) + filtered_dataset = filter_and_cast_dataset(dataset, FILTERED_CLASSES, NUM_FILTERED_CLASSES) + balanced_dataset = balance_dataset(filtered_dataset, config["num_train_images"], FILTERED_CLASSES) + splits = split_dataset(balanced_dataset) + transform = get_default_transforms(resolution=224, apply_transforms=True) + train_dataset, val_dataset = prepare_datasets(splits["train"], transform) + preprocessors = create_preprocessors(model_config, config) config["preprocessors"] = preprocessors diff --git a/src/compressed_perception/models/training/.DS_Store b/src/compressed_perception/models/training/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b7e02b4c56b209f81be5eab78eb861c30c047690 GIT binary patch literal 6148 zcmeHKyH3ME5S$G`3JQfgDE$RU{J|*-1qBrqRUnB0OAbMaQ_%Sz`2aozX7^4JaSSB_ zv}^6nJ$Cl&Im_n*AfwIv9LN9+=!(4~79FPJ>U-8pJ%>cG93^T@c%9*SFWNhPqXN2i zLtC>FPx!okFPGJFHm#OZM(FEtMGx<0tO3WEVM)%KYko(pS<`iI^IluL6V9?A;{gTd zt+6up8u!e$vXPn*nzrOSTme_$&ndt?TWv6m=%p**3b+DW1@!xn(-lL)CZc^hSlAJO*k`vH>-w{(oIGJj z*hJ(AO(K;TsgjNu66x&EQ(Q>cM2vJuI($f)S<(qbY+)`Mvt(dX86`#`0*q>{K V7!o!S*+cUm0h7TCSKvn#_yA-yUa$ZF literal 0 HcmV?d00001 diff --git a/src/compressed_perception/models/training/__init__.py b/src/compressed_perception/models/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/utils/constants.py b/src/compressed_perception/models/training/constants.py similarity index 100% rename from src/models/utils/constants.py rename to src/compressed_perception/models/training/constants.py diff --git a/src/models/utils/utils_classes.py b/src/compressed_perception/models/training/utils_classes.py similarity index 75% rename from src/models/utils/utils_classes.py rename to src/compressed_perception/models/training/utils_classes.py index 069d2c0..7ac0925 100644 --- a/src/models/utils/utils_classes.py +++ b/src/compressed_perception/models/training/utils_classes.py @@ -18,8 +18,8 @@ from torchvision import transforms from transformers import TrainerCallback -from .constants import HF_MODELS, NUM_FILTERED_CLASSES, SSL_MODEL -from .transforms import JPEGCompressionTransform +from src.compressed_perception.models.training.constants import HF_MODELS, NUM_FILTERED_CLASSES, SSL_MODEL +from src.compressed_perception.modules.data_transformation.image_transformation import JPEGCompressionTransform # Compatibility for LANCZOS resampling try: @@ -28,6 +28,56 @@ LANCZOS = Image.LANCZOS # pylint: disable=no-member +from transformers import TrainerCallback + +class WandbCallback(TrainerCallback): + """ + Custom callback for logging metrics and evaluation results to Weights & Biases. + Tracks best accuracy and GPU memory usage if available. + """ + MODEL_KEY = "model" + PHASE_KEY = "phase" + BEST_ACCURACY_KEY = "best_accuracy" + EVAL_ACCURACY_KEY = "eval_accuracy" + GPU_MEMORY_KEY = "gpu_memory_mb" + + def __init__(self, model_name, phase): + self.model_name = model_name + self.phase = phase + self.best_accuracy = 0.0 + + def on_log(self, _args, _state, _control, logs=None, **_kwargs): + if logs is not None: + logs[self.MODEL_KEY] = self.model_name + logs[self.PHASE_KEY] = self.phase + try: + from src.compressed_perception.models.training.utils_methods import GPU_AVAILABLE, get_gpu_memory + if GPU_AVAILABLE: + logs[self.GPU_MEMORY_KEY] = get_gpu_memory() + except ImportError: + pass + import wandb + wandb.log(logs) + + def on_evaluate(self, _args, _state, _control, metrics=None, **_kwargs): + if metrics is not None: + if self.EVAL_ACCURACY_KEY in metrics: + self.best_accuracy = max(self.best_accuracy, metrics[self.EVAL_ACCURACY_KEY]) + metrics[self.BEST_ACCURACY_KEY] = self.best_accuracy + import wandb + wandb.log(metrics) + +def get_trainer_callbacks(name): + """Get callbacks for the Trainer.""" + return [ + LossLoggerCallback( + log_dir=os.environ.get("LOG_DIR", "./logs"), + phase="finetune", + model_name=name, + ), + WandbCallback(name, "finetune"), + ] + class ISICDataset(Dataset): """ Dataset class for handling ISIC image data with optional transformations. diff --git a/src/models/utils/utils_methods.py b/src/compressed_perception/models/training/utils_methods.py similarity index 87% rename from src/models/utils/utils_methods.py rename to src/compressed_perception/models/training/utils_methods.py index dec220e..c716ea4 100644 --- a/src/models/utils/utils_methods.py +++ b/src/compressed_perception/models/training/utils_methods.py @@ -18,7 +18,9 @@ import seaborn as sns from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix import pynvml -from .constants import HF_MODELS + +# Local imports +from src.compressed_perception.models.training.constants import HF_MODELS # Constants GPU_AVAILABLE = torch.cuda.is_available() @@ -64,6 +66,20 @@ def compute_metrics(eval_pred, model_name): return {"accuracy": acc, "f1": f1, "auc": auc} +def cleanup_model_dirs(name, learning_rate): + """ + Removes and recreates model/log directories for the current run. + """ + model_dirs = [ + os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}_lr_{learning_rate}"), + os.path.join(env_path("MODEL_DIR", "."), f"{name}_lr_{learning_rate}"), + os.path.join(env_path("LOG_DIR", "."), f"{name}_lr_{learning_rate}"), + ] + for dir_path in model_dirs: + if os.path.exists(dir_path): + print(f"Cleaning up directory: {dir_path}") + shutil.rmtree(dir_path) + os.makedirs(dir_path, exist_ok=True) def get_gpu_memory(device_id=0): """ @@ -80,7 +96,6 @@ def get_gpu_memory(device_id=0): except Exception: # pylint: disable=broad-exception-caught return -1 - def freeze_backbone(model, model_type): """ Freeze the backbone of the model based on its type. @@ -110,7 +125,7 @@ def get_flops(model, resolution): """ try: dummy_input = torch.randn(1, 3, resolution, resolution).to(next(model.parameters()).device) - flops, _ = profile(model, inputs=(dummy_input,)) + flops, _, _ = profile(model, inputs=(dummy_input,)) return flops / 1e9 except Exception as e: # pylint: disable=broad-exception-caught print(f"FLOP profiling failed: {e}") diff --git a/src/compressed_perception/modules/__init__.py b/src/compressed_perception/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/compressed_perception/modules/data_preparation/__init__.py b/src/compressed_perception/modules/data_preparation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/compressed_perception/modules/data_preparation/preparation.py b/src/compressed_perception/modules/data_preparation/preparation.py new file mode 100644 index 0000000..7c2615b --- /dev/null +++ b/src/compressed_perception/modules/data_preparation/preparation.py @@ -0,0 +1,101 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +import numpy as np +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +import io +from datasets import ClassLabel + +def filter_and_cast_dataset(dataset, filtered_classes, num_classes): + """ + Filter dataset by class labels and cast label column. + """ + filtered_indices = [ + i for i, label in enumerate(dataset["label"]) + if str(label) in filtered_classes + ] + dataset = dataset.select(filtered_indices) + dataset = dataset.cast_column("label", ClassLabel(num_classes=num_classes)) + return dataset + +def balance_dataset(dataset, num_train_images, filtered_classes): + """ + Balance the dataset by sampling an equal number of images per class. + """ + class_counts = {label: 0 for label in filtered_classes} + for label in dataset["label"]: + class_counts[str(label)] += 1 + + min_class_size = min(class_counts.values()) + images_per_class = min(num_train_images // len(filtered_classes), min_class_size) + + np.random.seed(42) + balanced_indices = [] + for label in filtered_classes: + class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] + sampled_indices = np.random.choice(class_indices, images_per_class, replace=False) + balanced_indices.extend(sampled_indices) + + np.random.shuffle(balanced_indices) + return dataset.select(balanced_indices) + +def split_dataset(dataset, test_size=0.2, stratify_by_column="label", seed=42): + """ + Split dataset into train and validation sets. + """ + return dataset.train_test_split(test_size=test_size, stratify_by_column=stratify_by_column, seed=seed) + +def get_default_transforms(resolution, apply_transforms=False): + """ + Get torchvision transform pipeline. + """ + transform_list = [ + transforms.Resize((resolution, resolution)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + + if apply_transforms: + from src.compressed_perception.modules.data_transformation.image_transformation import ( + JPEGCompressionTransform, + GaussianBlurTransform, + ColorQuantizationTransform + ) + transform_list.extend([ + JPEGCompressionTransform(quality=75), + GaussianBlurTransform(p=0.5), + ColorQuantizationTransform(p=0.5), + ]) + + return transforms.Compose(transform_list) + +class TorchDataset(Dataset): + """ + PyTorch Dataset wrapper for Hugging Face datasets. + """ + def __init__(self, hf_dataset, transform): + self.hf_dataset = hf_dataset + self.transform = transform + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + item = self.hf_dataset[idx] + image = Image.open(io.BytesIO(item["image"])).convert("RGB") + image = self.transform(image) if self.transform else image + return {"pixel_values": image, "labels": int(item["label"])} + +def prepare_datasets(dataset, transform, split_ratio=0.8): + """ + Prepare PyTorch-compatible train and val datasets. + """ + train_size = int(split_ratio * len(dataset)) + train_dataset = dataset.select(range(train_size)) + val_dataset = dataset.select(range(train_size, len(dataset))) + return TorchDataset(train_dataset, transform), TorchDataset(val_dataset, transform) diff --git a/src/compressed_perception/modules/data_transformation/__init__.py b/src/compressed_perception/modules/data_transformation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/utils/transforms.py b/src/compressed_perception/modules/data_transformation/image_transformation.py similarity index 100% rename from src/models/utils/transforms.py rename to src/compressed_perception/modules/data_transformation/image_transformation.py From 68b39e3665da59260ed3bf4fc5f242d8b4660874 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 18:06:17 -0700 Subject: [PATCH 10/26] refactor and restructure --- .DS_Store | Bin 8196 -> 0 bytes .gitignore | 2 + src/.DS_Store | Bin 6148 -> 0 bytes src/compressed_perception/.DS_Store | Bin 6148 -> 0 bytes src/compressed_perception/models/.DS_Store | Bin 6148 -> 0 bytes .../models/comparison/compare_baseline.py | 125 +++++++++++------- .../models/comparison/compare_lr_sweep.py | 2 + .../models/training/.DS_Store | Bin 6148 -> 0 bytes 8 files changed, 78 insertions(+), 51 deletions(-) delete mode 100644 .DS_Store delete mode 100644 src/.DS_Store delete mode 100644 src/compressed_perception/.DS_Store delete mode 100644 src/compressed_perception/models/.DS_Store delete mode 100644 src/compressed_perception/models/training/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 4f9f60cbb31db681b5b5082a7386b91895c2a534..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHML2DC17=4q*WdJe zgZ>VGN5QNAz^g}nGqVZFwxOPa;7pl$lbP>*-^`oMZoVx5(df6UzzP71Rbpw0%^5|` z%X+Hh(veFdhqrktQ0H4{iSTmmcde*c?0i(cw zsersc*sK!kDGoKNTL&AJ0ub{YR)u}$0g7X%SWj`NQBmnrXAdH!iZsO#n$CG!niK0O z4mD~zgr-BJkwu!J2o;??TbV=DHJa8aU=%p5fXv+sXdyy~*OU2sCQM%!i_ z2iqC1Uua#QSXU9dw7iEcJRp{Ffc7*){37GS1f#?i`*?>w_bqz|tRZ(2lbzr_l4*W! zm|5S-d3^c_=)0x-#_*^8VV}3DjfwSKQ~qMECs3y9-Fjud5S%CLNg{Hh+Te+u^-ZOR~eq>@Sk|pRFtmNa@v|#2*k;R@(Z-?p_Q?cDxP8G4dSjBsH^sy*9|&xQmK3w zxrO|N#e!9^maLb~el&2BX3~!v&E9kQ>2VYV-l*U7o^ublyMeEAVTX%Qe zx4}Ueed49cBGRWNmi6ax`aR8}k2;=7 zc`nXiRx><#>5O^i`9D4r(=iInRe=Q^ZAIq)+q2*Q&vi{Em{GteFkb~ke#_ZvFy`pj zLN>h2wQbf{tg=YEp+-evqtbDrO2>)ge;BfF(^PWmDGoJqgvGoFP%@atC@@zA{s0M; BJZt~} diff --git a/.gitignore b/.gitignore index f80baca..d2304b4 100644 --- a/.gitignore +++ b/.gitignore @@ -211,3 +211,5 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +**/.DS_Store diff --git a/src/.DS_Store b/src/.DS_Store deleted file mode 100644 index e55dc795af18da210d844a92a4ec784082304d6f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK!EVz)5S>i}brd1xKqY!xsn@791tTFYR%j0#xYSiSv?#>21s3jZWIIGCisXCy zLwpAkzeC^bZYtWqfgBL3cC6Vq-kmp&-#T6|5sBe69T4@1D1bAzI%s}lJkGviE!Wcq z3R5GennpNwX|j^-JN`!nc<#Dp&T>-p>->FvQ|R-g&jud zQbrYyj8asZ085Nzh-y^Sr0!JuzJ8@-R!!V*`j$+JysG#1eu&ma`@zG_U^D0j@8eH; z7T4o?S`EkLDW4r{T_nqRoSfv@Y}9-FTB~}TtE@DHJk5}D_AXbco(=U>rIq2vb|B~k zol$RlJ|BeP^S%rR;i50+FZaW~eDQ35vFHR_J5OI7jXoC_N?#gzNZ>DF?2*GITp{=i z@xH{RDzy3t*0SQVl1HzASHLT>ogHWo1g6P60JRN=N5!qU+%U0&p{v1sWe+~z~LD+{+n5$5WsFHJg$ z$f9q(0$zdJ3T)WZ9|SFf#bYEAUShxCTr@cvJuY diff --git a/src/compressed_perception/.DS_Store b/src/compressed_perception/.DS_Store deleted file mode 100644 index dd228a931d57c45f936321ec25db8898f886506e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~K}*9h6vtn5Zk;0RprE&a*MUy0AYRIx2M=DX=s{(!wrH_7vTp7$2EFSS@^c7& z9epo(g5C6}sN})RKTY0ilOH5$06;YRVFREB01i57X%>r5jQYtZY>HdTs6xLZfowlf zmqDU}iD=tk30MNZjR39PGTefK5F#k9U)mH7(7jh}dgz7NV36JMT7L@ek~r(Ft-Xk; z>C#NO;#8bD=iI+m1Hael^|NNDdrD77N+rR2y%QYA(V$&h*i~t-6Q@yEC&XcdA!jFX z8md84_0urZxxVRes!p|CTOJM@uDe;6Zo?hb<#2o5t;?;A_0g#6EUv8X9JX(g+f+U1 zVS~U8O4~f^yF6m0FsXaLn(!ruK=heY}ST#Iu|KNvU;q!QS3nzCKc7B z3cg|plaB3@<2=^fs7VLGmk+^T7JNey`s?_9sl!2dMr~;cSOP@?)8?^8=l|Z<_y1y& zeX;~Bfqz9nlv;kPiBsaUbz*XK)=KmXbTW#|jp`IE_&C-TI*OOjwPBy63aZDN8^snh O`y-$+*uoO{Q37wVX{43_ diff --git a/src/compressed_perception/models/.DS_Store b/src/compressed_perception/models/.DS_Store deleted file mode 100644 index dc15ec31e3abe57cab7d97a046e4f15385f6288b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK&x_MQ6n?W??R1OKgMxyGfY+ki)d-81P0Gk8mhXQNx4 z!>Uz|+K9{EY{ccrXC>BSxDj;xhy8Wzlct(k4!wIR`+NS_d;D;3xg79Y_Z~cX5q&C7g#3z4Ll{!HmKzpVaEeN2 zW1quP6jB^xY~{DIcX&k@uzwsMUXk^>WjY2N1DD5u$`2k?VPLT}s8a_D{R9B^(5(e^ z^|7FQJd1(F)*wb8%(w!LE3;1wX57*5S-ZeuYtXn8vyTsEXJ+<=!u0Hj?`d;lfk9U~ z1{?$H3~Za`NZtR3zkmO)JGqu)z%lS&F~E8U;lUJJvUh7|bLy`3pdX-0lwE63yP&Yw hv974A_%>7v#ynL41{Pa`=z+Ko0ZoG|90Qliz)u!!p|Jn} diff --git a/src/compressed_perception/models/comparison/compare_baseline.py b/src/compressed_perception/models/comparison/compare_baseline.py index 39425e9..8b23db7 100644 --- a/src/compressed_perception/models/comparison/compare_baseline.py +++ b/src/compressed_perception/models/comparison/compare_baseline.py @@ -4,6 +4,7 @@ # # SPDX-License-Identifier: MIT +# pylint: disable=invalid-name """ This script is a baseline for comparing different image classification models at three different image compression levels, in comparison to the original. @@ -41,7 +42,9 @@ SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES ) from src.compressed_perception.models.training.utils_classes import SimCLRForClassification -from src.compressed_perception.models.training.utils_methods import get_gpu_memory, GPU_AVAILABLE, freeze_backbone +from src.compressed_perception.models.training.utils_methods import ( + get_gpu_memory, GPU_AVAILABLE, freeze_backbone +) from src.compressed_perception.modules.data_preparation.preparation import ( filter_and_cast_dataset, balance_dataset, @@ -209,9 +212,25 @@ def main(config=None, dataset=None): "gpu_available": GPU_AVAILABLE, } - wandb_config = config.copy() - wandb_config["weight_decay"] = 0.01 + # Add weight_decay to config instead of creating a separate wandb_config + config["weight_decay"] = 0.01 + + # Define all model configurations + models = get_model_configs(config["resolution"]) + + if dataset is None: # Changed from 'data' to 'dataset' + raise ValueError("Dataset must be provided via `dataset` argument or CLI.") + + # Process dataset and create training/validation splits + train_ds, val_ds = process_dataset(dataset, config) # Changed from 'data' to 'dataset' + # Train and evaluate each model + for model_info in models: + train_and_evaluate_model(model_info, train_ds, val_ds, config) + + +def get_model_configs(resolution): + """Return the model configurations for the baseline comparison.""" MODEL_CONFIG_KEYS = { "name": "name", "model_id": "model_id", @@ -219,23 +238,22 @@ def main(config=None, dataset=None): "config": "config", } - models = [ + return [ { MODEL_CONFIG_KEYS["name"]: "vit", MODEL_CONFIG_KEYS["model_id"]: "google/vit-base-patch16-224", MODEL_CONFIG_KEYS["type"]: "vit", MODEL_CONFIG_KEYS["config"]: { - "image_size": config["resolution"], + "image_size": resolution, "num_labels": NUM_FILTERED_CLASSES, "ignore_mismatched_sizes": True } }, ] - if dataset is None: - raise ValueError("Dataset must be provided via `dataset` argument or CLI.") - # Use preparation.py functions for filtering, balancing, and splitting +def process_dataset(dataset, config): + """Process the dataset and create training/validation splits.""" dataset = filter_and_cast_dataset(dataset, FILTERED_CLASSES, NUM_FILTERED_CLASSES) dataset = balance_dataset(dataset, config["num_train_images"], FILTERED_CLASSES) splits = split_dataset(dataset, test_size=0.2, stratify_by_column="label", seed=42) @@ -244,54 +262,59 @@ def main(config=None, dataset=None): transform = get_default_transforms(config["resolution"], apply_transforms=True) # Prepare PyTorch datasets - train_ds, val_ds = prepare_datasets(splits["train"], transform, split_ratio=1.0) + train_ds, _ = prepare_datasets(splits["train"], transform, split_ratio=1.0) val_ds, _ = prepare_datasets(splits["test"], transform, split_ratio=1.0) - for model_info in models: - model, _preprocessor = initialize_model_and_preprocessor(model_info, config["resolution"]) - - train_config = { - "model_name": model_info["name"], - "model_type": model_info["type"], - "resolution": config["resolution"], - "batch_size": config["batch_size"], - "num_epochs": config["num_epochs"], - "learning_rate": config["learning_rate"], - "eval_steps": config["eval_steps"], - "wandb_config": wandb_config, - } + return train_ds, val_ds - results = train_model( - model, train_ds, val_ds, train_config - ) - print(f"Results for {model_info['name']}: {results}") - - METRIC_KEYS = { - "learning_rate": "learning_rate", - "model_name": "model_name", - "model_type": "model_type", - "peak_memory_mb": "peak_memory_mb", - "flops_giga": "flops_giga", - "train_time_seconds": "train_time_seconds", - "eval_time_seconds": "eval_time_seconds", - "eval_metrics": "eval_metrics", - } - metrics = { - METRIC_KEYS["learning_rate"]: config["learning_rate"], - METRIC_KEYS["model_name"]: model_info[MODEL_CONFIG_KEYS["name"]], - METRIC_KEYS["model_type"]: model_info[MODEL_CONFIG_KEYS["type"]], - METRIC_KEYS["peak_memory_mb"]: get_gpu_memory(), - METRIC_KEYS["flops_giga"]: None, - METRIC_KEYS["train_time_seconds"]: None, - METRIC_KEYS["eval_time_seconds"]: None, - METRIC_KEYS["eval_metrics"]: results, - } - wandb.log({"metrics": metrics}) +def train_and_evaluate_model(model_info, train_ds, val_ds, config): + """Train and evaluate a single model, then log results.""" + model, _preprocessor = initialize_model_and_preprocessor(model_info, config["resolution"]) + + train_config = { + "model_name": model_info["name"], + "model_type": model_info["type"], + "resolution": config["resolution"], + "batch_size": config["batch_size"], + "num_epochs": config["num_epochs"], + "learning_rate": config["learning_rate"], + "eval_steps": config["eval_steps"], + "wandb_config": config, # Pass the entire config for wandb + } + + results = train_model(model, train_ds, val_ds, train_config) + print(f"Results for {model_info['name']}: {results}") + + # Define metric keys and log results to wandb + METRIC_KEYS = { + "learning_rate": "learning_rate", + "model_name": "model_name", + "model_type": "model_type", + "peak_memory_mb": "peak_memory_mb", + "flops_giga": "flops_giga", + "train_time_seconds": "train_time_seconds", + "eval_time_seconds": "eval_time_seconds", + "eval_metrics": "eval_metrics", + } + + metrics = { + METRIC_KEYS["learning_rate"]: config["learning_rate"], + METRIC_KEYS["model_name"]: model_info["name"], + METRIC_KEYS["model_type"]: model_info["type"], + METRIC_KEYS["peak_memory_mb"]: get_gpu_memory(), + METRIC_KEYS["flops_giga"]: None, + METRIC_KEYS["train_time_seconds"]: None, + METRIC_KEYS["eval_time_seconds"]: None, + METRIC_KEYS["eval_metrics"]: results, + } + wandb.log({"metrics": metrics}) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Baseline model comparison for image classification.") + parser = argparse.ArgumentParser( + description="Baseline model comparison for image classification." + ) parser.add_argument( "--path_to_dataset", type=str, @@ -300,7 +323,7 @@ def main(config=None, dataset=None): ) args = parser.parse_args() - dataset = None + dataset = None # Changed from 'data' to 'dataset' if args.path_to_dataset: dataset = load_dataset("imagefolder", data_dir=args.path_to_dataset, split="train") else: @@ -313,4 +336,4 @@ def main(config=None, dataset=None): except Exception as e: raise ValueError("No dataset provided and Hugging Face dataset failed to load.") from e - main(dataset=dataset) \ No newline at end of file + main(dataset=dataset) diff --git a/src/compressed_perception/models/comparison/compare_lr_sweep.py b/src/compressed_perception/models/comparison/compare_lr_sweep.py index 719b66b..366db98 100644 --- a/src/compressed_perception/models/comparison/compare_lr_sweep.py +++ b/src/compressed_perception/models/comparison/compare_lr_sweep.py @@ -4,6 +4,8 @@ # # SPDX-License-Identifier: MIT +# pylint: skip-file + """ This script is a baseline for comparing different image classification models at three different image compression levels, in comparison to the original. diff --git a/src/compressed_perception/models/training/.DS_Store b/src/compressed_perception/models/training/.DS_Store deleted file mode 100644 index b7e02b4c56b209f81be5eab78eb861c30c047690..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKyH3ME5S$G`3JQfgDE$RU{J|*-1qBrqRUnB0OAbMaQ_%Sz`2aozX7^4JaSSB_ zv}^6nJ$Cl&Im_n*AfwIv9LN9+=!(4~79FPJ>U-8pJ%>cG93^T@c%9*SFWNhPqXN2i zLtC>FPx!okFPGJFHm#OZM(FEtMGx<0tO3WEVM)%KYko(pS<`iI^IluL6V9?A;{gTd zt+6up8u!e$vXPn*nzrOSTme_$&ndt?TWv6m=%p**3b+DW1@!xn(-lL)CZc^hSlAJO*k`vH>-w{(oIGJj z*hJ(AO(K;TsgjNu66x&EQ(Q>cM2vJuI($f)S<(qbY+)`Mvt(dX86`#`0*q>{K V7!o!S*+cUm0h7TCSKvn#_yA-yUa$ZF From 83908e322b5e84fce125e5570714cbcbfc433389 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 18:19:34 -0700 Subject: [PATCH 11/26] refactor --- .github/workflows/build-and-test.yml | 2 +- .../models/comparison/compare_baseline.py | 8 +++--- .../models/training/utils_classes.py | 22 +++++++++------- .../modules/data_preparation/preparation.py | 26 ++++++++++++++----- 4 files changed, 37 insertions(+), 21 deletions(-) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 23d3d99..64cb848 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10", "3.12"] steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 diff --git a/src/compressed_perception/models/comparison/compare_baseline.py b/src/compressed_perception/models/comparison/compare_baseline.py index 8b23db7..634fc0a 100644 --- a/src/compressed_perception/models/comparison/compare_baseline.py +++ b/src/compressed_perception/models/comparison/compare_baseline.py @@ -323,12 +323,12 @@ def train_and_evaluate_model(model_info, train_ds, val_ds, config): ) args = parser.parse_args() - dataset = None # Changed from 'data' to 'dataset' + loaded_dataset = None if args.path_to_dataset: - dataset = load_dataset("imagefolder", data_dir=args.path_to_dataset, split="train") + loaded_dataset = load_dataset("imagefolder", data_dir=args.path_to_dataset, split="train") else: try: - dataset = load_dataset( + loaded_dataset = load_dataset( "MKZuziak/ISIC_2019_224", cache_dir=os.environ["HF_DATASETS_CACHE"], split="train", @@ -336,4 +336,4 @@ def train_and_evaluate_model(model_info, train_ds, val_ds, config): except Exception as e: raise ValueError("No dataset provided and Hugging Face dataset failed to load.") from e - main(dataset=dataset) + main(dataset=loaded_dataset) diff --git a/src/compressed_perception/models/training/utils_classes.py b/src/compressed_perception/models/training/utils_classes.py index 7ac0925..1c0ade4 100644 --- a/src/compressed_perception/models/training/utils_classes.py +++ b/src/compressed_perception/models/training/utils_classes.py @@ -17,9 +17,18 @@ from PIL import Image from torchvision import transforms from transformers import TrainerCallback - -from src.compressed_perception.models.training.constants import HF_MODELS, NUM_FILTERED_CLASSES, SSL_MODEL -from src.compressed_perception.modules.data_transformation.image_transformation import JPEGCompressionTransform +import wandb + +# Local imports +from src.compressed_perception.models.training.constants import ( + HF_MODELS, NUM_FILTERED_CLASSES, SSL_MODEL +) +from src.compressed_perception.models.training.utils_methods import ( + GPU_AVAILABLE, get_gpu_memory + ) +from src.compressed_perception.modules.data_transformation.image_transformation import ( + JPEGCompressionTransform +) # Compatibility for LANCZOS resampling try: @@ -28,8 +37,6 @@ LANCZOS = Image.LANCZOS # pylint: disable=no-member -from transformers import TrainerCallback - class WandbCallback(TrainerCallback): """ Custom callback for logging metrics and evaluation results to Weights & Biases. @@ -51,12 +58,10 @@ def on_log(self, _args, _state, _control, logs=None, **_kwargs): logs[self.MODEL_KEY] = self.model_name logs[self.PHASE_KEY] = self.phase try: - from src.compressed_perception.models.training.utils_methods import GPU_AVAILABLE, get_gpu_memory if GPU_AVAILABLE: logs[self.GPU_MEMORY_KEY] = get_gpu_memory() except ImportError: pass - import wandb wandb.log(logs) def on_evaluate(self, _args, _state, _control, metrics=None, **_kwargs): @@ -64,7 +69,6 @@ def on_evaluate(self, _args, _state, _control, metrics=None, **_kwargs): if self.EVAL_ACCURACY_KEY in metrics: self.best_accuracy = max(self.best_accuracy, metrics[self.EVAL_ACCURACY_KEY]) metrics[self.BEST_ACCURACY_KEY] = self.best_accuracy - import wandb wandb.log(metrics) def get_trainer_callbacks(name): @@ -77,7 +81,7 @@ def get_trainer_callbacks(name): ), WandbCallback(name, "finetune"), ] - + class ISICDataset(Dataset): """ Dataset class for handling ISIC image data with optional transformations. diff --git a/src/compressed_perception/modules/data_preparation/preparation.py b/src/compressed_perception/modules/data_preparation/preparation.py index 7c2615b..1a34862 100644 --- a/src/compressed_perception/modules/data_preparation/preparation.py +++ b/src/compressed_perception/modules/data_preparation/preparation.py @@ -4,13 +4,26 @@ # # SPDX-License-Identifier: MIT +""" +Utility functions for dataset preparation. +""" +# Standard library imports +import io + +# Thrid-party imports import numpy as np from PIL import Image from torch.utils.data import Dataset from torchvision import transforms -import io from datasets import ClassLabel +# Local imports +from src.compressed_perception.modules.data_transformation.image_transformation import ( + JPEGCompressionTransform, + GaussianBlurTransform, + ColorQuantizationTransform + ) + def filter_and_cast_dataset(dataset, filtered_classes, num_classes): """ Filter dataset by class labels and cast label column. @@ -48,7 +61,11 @@ def split_dataset(dataset, test_size=0.2, stratify_by_column="label", seed=42): """ Split dataset into train and validation sets. """ - return dataset.train_test_split(test_size=test_size, stratify_by_column=stratify_by_column, seed=seed) + return dataset.train_test_split( + test_size=test_size, + stratify_by_column=stratify_by_column, + seed=seed + ) def get_default_transforms(resolution, apply_transforms=False): """ @@ -61,11 +78,6 @@ def get_default_transforms(resolution, apply_transforms=False): ] if apply_transforms: - from src.compressed_perception.modules.data_transformation.image_transformation import ( - JPEGCompressionTransform, - GaussianBlurTransform, - ColorQuantizationTransform - ) transform_list.extend([ JPEGCompressionTransform(quality=75), GaussianBlurTransform(p=0.5), From ebe690ec946c3b7d6ec2753ee6c53dfb84571a00 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 18:36:35 -0700 Subject: [PATCH 12/26] update README.md --- README.md | 150 ++++++++++++------ .../models/comparison/compare_baseline.py | 2 +- 2 files changed, 105 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 0e42e05..221694b 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,9 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT + # Finetuning Pretrained Models for Compressed Dermatology Image Analysis @@ -33,54 +34,111 @@ This project explores how compressed and degraded dermatology images (from the I ## Project Structure ``` -CS231N/ -├── configs/ -│ └── example_config.yaml # Configs for job submissions +compressed-perception/ +├── README.md # Project overview & documentation +├── LICENSES/ # Directory containing license files (REUSE compliance) +│ └── MIT.txt # MIT license text │ -├── scripts/ # Lightweight utility or shell scripts -│ ├── download_unpack_isic2019.sh # Downloads and unpacks ISIC data -│ └── submit_from_config.sh # SLURM submission helper +├── pyproject.toml # Python packaging config +├── setup.py # Installation script for the package +├── setup.cfg # Configuration for setup tools │ -├── jobs/ # SLURM-related job definitions -│ └── job_template.slurm +├── requirements.txt # Dependencies file +├── requirements.txt.license # Dependencies file license +├── .yamllint # YAML linter configuration +├── .yamllint.license # YAML linter configuration license │ -├── src/ # Source code, logically grouped -│ ├── __init__.py -│ ├── finetune/ # Fine-tuning workflows -│ │ └── baseline_finetuning.py -│ ├── evaluation/ # Evaluation + plotting -│ │ └── evaluate_isic_results.py -│ └── models/ # Model-related scripts -│ ├── model_comparison.py # Config file with constant strings -│ ├── model_comparison.py -│ └── model_comparison_2.py - +├── .github/ # GitHub specific files +│ └── workflows/ # CI/CD workflow definitions +│ ├── build-and-test.yml +│ └── pull_request.yml +│ +├── .reuse/ # REUSE compliance configuration +│ └── dep5 # Copyright and license information +│ +├── docs/ # Documentation +│ └── pipeline.md # Pipeline documentation +│ +├── scripts/ # Standalone scripts +│ ├── ... +│ └── visualize_isic_results.py # Visualize metrics for model comparison (TODO) +│ +├── configs/ # Configuration files (TODO) +│ ├── datasets/ # Dataset configs +│ │ └── isic2019.yaml # ISIC 2019 dataset config +│ ├── models/ # Model configs +│ │ ├── vit.yaml # ViT model config +│ │ ├── dinov2.yaml # DINOv2 model config +│ │ └── simclr.yaml # SimCLR model config +│ ├── experiments/ # Experiment configs +│ │ ├── baseline.yaml # Baseline experiment +│ │ └── lr_sweep.yaml # Learning rate sweep experiment +│ └── example_config.yaml # Example configuration file │ -├── results/ # Auto-generated results -│ ├── plots/ # Accuracy/f1/AUC plots -│ └── logs/ # Training logs or SLURM outputs │ -├── requirements.txt -├── .gitignore -├── .github -└── README.md +├── tests/ # Test suite (TODO) +│ ├── unit/ # Unit tests +│ │ └── test_transforms.py +│ ├── integration/ # Integration tests +│ │ └── test_pipeline.py +│ └── conftest.py # Test fixtures and configuration +│ +├── src/ +│ └── compressed_perception/ # Main package +│ ├── __init__.py # Package initialization +│ │ +│ ├── models/ # Model implementations +│ │ ├── __init__.py +│ │ ├── architectures/ # Model architecture definitions +│ │ │ ├── __init__.py +│ │ │ ├── vit.py # Vision Transformer adaptations +│ │ │ └── simclr.py # SimCLR adaptations +│ │ │ +│ │ ├── evaluation/ # Model evaluation code +│ │ │ ├── __init__.py +│ │ │ └── metrics.py # Evaluation metrics +│ │ │ +│ │ ├── comparison/ # Model comparison utilities +│ │ │ ├── __init__.py +│ │ │ ├── compare_baseline.py # Baseline comparison +│ │ │ └── compare_lr_sweep.py # Learning rate sweeping +│ │ │ +│ │ ├── training/ # Training infrastructure +│ │ │ ├── __init__.py +│ │ │ ├── trainers.py # Training loops +│ │ │ └── callbacks.py # Training callbacks +│ │ │ +│ │ └── utils/ # Model utilities +│ │ ├── __init__.py +│ │ ├── constants.py # Model constants +│ │ └── helpers.py # Helper functions +│ │ +│ ├── modules/ # Reusable modules +│ │ ├── __init__.py +│ │ ├── transforms/ # Image transformations +│ │ │ ├── __init__.py +│ │ │ ├── degradation.py # Image degradation transforms +│ │ │ └── augmentation.py # Data augmentation transforms +│ │ │ +│ │ └── data_preparation/ # Data preparation utilities +│ │ ├── __init__.py +│ │ └── preparation.py # Dataset preparation +│ │ +│ +│ +├── results/ +│ +├── jobs/ # Cluster job submission files +│ ├── job_template.slurm # SLURM job template +│ ├── run.sh # General run script +│ ├── rurun_compare_baseline.sh # Learning rate experiment script +│ ├── run_compare_lr_sweep.sh # Model comparison script +│ └── configs/ # Job configurations +│ +└── media/ # Media files for documentation + └── CS231N Poster.png # Project poster ``` -## Quick Start - -1. Install requirements: - ```bash - pip install -r requirements.txt - ``` - -2. Run training: - ```bash - python train_models.py - ``` - -3. View results - We use weights and biases for logging, so output plots can be seen there - ## 📦 Dataset - [ISIC 2019 (Hugging Face)](https://huggingface.co/datasets/MKZuziak/ISIC_2019_224) diff --git a/src/compressed_perception/models/comparison/compare_baseline.py b/src/compressed_perception/models/comparison/compare_baseline.py index 634fc0a..f4a2fe8 100644 --- a/src/compressed_perception/models/comparison/compare_baseline.py +++ b/src/compressed_perception/models/comparison/compare_baseline.py @@ -171,7 +171,7 @@ def train_model( weight_decay=0.01, logging_dir=f"./logs/{config['model_name']}", logging_steps=1, - evaluation_strategy="steps", + eval_strategy="steps", eval_steps=config['eval_steps'], save_strategy="steps", save_steps=config['eval_steps'], From 0a2000410fd42b3821d6188457b4b4f6544a63b8 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 18:45:39 -0700 Subject: [PATCH 13/26] correct bash script --- jobs/run_compare_baseline.sh | 2 +- .../models/comparison/compare_baseline.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jobs/run_compare_baseline.sh b/jobs/run_compare_baseline.sh index 8893737..749a0c2 100644 --- a/jobs/run_compare_baseline.sh +++ b/jobs/run_compare_baseline.sh @@ -35,4 +35,4 @@ fi export WANDB_API_KEY="7ab80eeb87ef06298c6bca1258208b1739ad32fe" # Running the script -python src/models/comparison/compare_baseline.py --resolution 224 --batch_size 256 --num_train_images 25000 --num_epochs 10 --eval_steps 10 \ No newline at end of file +python -m src.compressed_perception.models.comparison.compare_baseline --resolution 224 --batch_size 256 --num_train_images 25000 --num_epochs 10 --eval_steps 10 \ No newline at end of file diff --git a/src/compressed_perception/models/comparison/compare_baseline.py b/src/compressed_perception/models/comparison/compare_baseline.py index f4a2fe8..4831da5 100644 --- a/src/compressed_perception/models/comparison/compare_baseline.py +++ b/src/compressed_perception/models/comparison/compare_baseline.py @@ -218,11 +218,11 @@ def main(config=None, dataset=None): # Define all model configurations models = get_model_configs(config["resolution"]) - if dataset is None: # Changed from 'data' to 'dataset' + if dataset is None: raise ValueError("Dataset must be provided via `dataset` argument or CLI.") # Process dataset and create training/validation splits - train_ds, val_ds = process_dataset(dataset, config) # Changed from 'data' to 'dataset' + train_ds, val_ds = process_dataset(dataset, config) # Train and evaluate each model for model_info in models: From f9a82def4035fd8c5dd166f157de6e2887fb6e06 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 11 Jul 2025 18:50:54 -0700 Subject: [PATCH 14/26] remove duplicate from .yamllint file --- .yamllint | 5 ++--- .../models/comparison/compare_lr_sweep.py | 8 ++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/.yamllint b/.yamllint index 040176a..023e567 100644 --- a/.yamllint +++ b/.yamllint @@ -7,8 +7,7 @@ rules: allowed-values: ["false", "true", "on", "off"] document-start: level: warning + present: false line-length: max: 180 - level: warning - document-start: - present: false \ No newline at end of file + level: warning \ No newline at end of file diff --git a/src/compressed_perception/models/comparison/compare_lr_sweep.py b/src/compressed_perception/models/comparison/compare_lr_sweep.py index 366db98..35b7a53 100644 --- a/src/compressed_perception/models/comparison/compare_lr_sweep.py +++ b/src/compressed_perception/models/comparison/compare_lr_sweep.py @@ -156,10 +156,10 @@ def train_for_learning_rate( # pylint: disable=too-many-locals train_ds = get_transformed_datasets(train_dataset, preprocessors, config, typ) val_ds = ISICDataset( # pylint: disable=too-many-function-args - val_dataset, - preprocessors[typ], - config["resolution"], - typ, + dataset=val_dataset, + transform=None, + resolution=config["resolution"], + model_type=typ, ) model = get_model(typ, model_id, config) From 7324dd748a2923dd44915fb989397eb875816356 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 17 Oct 2025 17:20:23 -0700 Subject: [PATCH 15/26] restore history --- LICENSE | 21 + jobs/download_unpack_isic2019.sh | 68 --- jobs/job_template.slurm | 54 --- jobs/run_compare_baseline.sh | 38 -- jobs/run_compare_lr_sweep.sh | 38 -- jobs/submit_from_config.sh | 62 --- scripts/test_transforms.py | 90 ++++ scripts/visualize_isic_results.py | 236 --------- src/__init__.py | 3 + src/cli/train.py | 139 ++++++ .../models/comparison/compare_baseline.py | 339 ------------- .../models/comparison/compare_lr_sweep.py | 456 ------------------ .../models/training/constants.py | 16 - .../models/training/utils_classes.py | 239 --------- .../models/training/utils_methods.py | 180 ------- .../modules/data_preparation/preparation.py | 113 ----- .../modules/data_transformation/__init__.py | 0 .../image_transformation.py | 160 ------ src/config.py | 64 +++ src/data/data_utils.py | 200 ++++++++ src/data/datamodule.py | 82 ++++ src/data/datasets.py | 14 + .../__init__.py => engines/distill_engine.py} | 0 .../finetune_engine.py} | 0 src/engines/linear_probe_engine.py | 224 +++++++++ src/evaluation/metrics.py | 63 +++ src/evaluation/visualization.py | 55 +++ src/evaluation/visualize_results.py | 173 +++++++ .../models/training => losses}/__init__.py | 0 src/losses/classification.py | 35 ++ src/losses/distillation.py | 84 ++++ src/models/factory.py | 191 ++++++++ src/requirements.txt | Bin 0 -> 1716 bytes .../modules => transformation}/__init__.py | 0 src/transformation/transforms.py | 111 +++++ src/utils/callbacks_hf.py | 55 +++ src/utils/constants.py | 19 + src/utils/logging.py | 114 +++++ src/utils/optim.py | 122 +++++ src/utils/training_utils.py | 43 ++ src/utils/utils.py | 80 +++ .../__init__.py => wrappers/distill.py} | 0 src/wrappers/finetune.py | 38 ++ src/wrappers/probe.py | 162 +++++++ 44 files changed, 2182 insertions(+), 1999 deletions(-) create mode 100644 LICENSE delete mode 100644 jobs/download_unpack_isic2019.sh delete mode 100644 jobs/job_template.slurm delete mode 100644 jobs/run_compare_baseline.sh delete mode 100644 jobs/run_compare_lr_sweep.sh delete mode 100644 jobs/submit_from_config.sh create mode 100644 scripts/test_transforms.py delete mode 100644 scripts/visualize_isic_results.py create mode 100644 src/cli/train.py delete mode 100644 src/compressed_perception/models/comparison/compare_baseline.py delete mode 100644 src/compressed_perception/models/comparison/compare_lr_sweep.py delete mode 100644 src/compressed_perception/models/training/constants.py delete mode 100644 src/compressed_perception/models/training/utils_classes.py delete mode 100644 src/compressed_perception/models/training/utils_methods.py delete mode 100644 src/compressed_perception/modules/data_preparation/preparation.py delete mode 100644 src/compressed_perception/modules/data_transformation/__init__.py delete mode 100644 src/compressed_perception/modules/data_transformation/image_transformation.py create mode 100644 src/config.py create mode 100644 src/data/data_utils.py create mode 100644 src/data/datamodule.py create mode 100644 src/data/datasets.py rename src/{compressed_perception/models/__init__.py => engines/distill_engine.py} (100%) rename src/{compressed_perception/models/comparison/__init__.py => engines/finetune_engine.py} (100%) create mode 100644 src/engines/linear_probe_engine.py create mode 100644 src/evaluation/metrics.py create mode 100644 src/evaluation/visualization.py create mode 100644 src/evaluation/visualize_results.py rename src/{compressed_perception/models/training => losses}/__init__.py (100%) create mode 100644 src/losses/classification.py create mode 100644 src/losses/distillation.py create mode 100644 src/models/factory.py create mode 100644 src/requirements.txt rename src/{compressed_perception/modules => transformation}/__init__.py (100%) create mode 100644 src/transformation/transforms.py create mode 100644 src/utils/callbacks_hf.py create mode 100644 src/utils/constants.py create mode 100644 src/utils/logging.py create mode 100644 src/utils/optim.py create mode 100644 src/utils/training_utils.py create mode 100644 src/utils/utils.py rename src/{compressed_perception/modules/data_preparation/__init__.py => wrappers/distill.py} (100%) create mode 100644 src/wrappers/finetune.py create mode 100644 src/wrappers/probe.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..97a7fea --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Stanford Daneshjou Lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/jobs/download_unpack_isic2019.sh b/jobs/download_unpack_isic2019.sh deleted file mode 100644 index 7825179..0000000 --- a/jobs/download_unpack_isic2019.sh +++ /dev/null @@ -1,68 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -#!/bin/bash -set -e - -# -------------------- Config -------------------- -BASE_DIR="/oak/stanford/groups/roxanad/isic2019" -TRAIN_DIR="$BASE_DIR/data/train" -TEST_DIR="$BASE_DIR/data/test" -CHECKSUM_FILE="$BASE_DIR/checksums.sha256" - -mkdir -p "$TRAIN_DIR/images" "$TEST_DIR/images" -cd "$BASE_DIR" - -# -------------------- URLs -------------------- -TRAIN_IMAGES_URL="https://isic-challenge-data.s3.amazonaws.com/2019/ISIC_2019_Training_Input.zip" -TRAIN_LABELS_URL="https://isic-challenge-data.s3.amazonaws.com/2019/ISIC_2019_Training_GroundTruth.csv" -TEST_IMAGES_URL="https://isic-challenge-data.s3.amazonaws.com/2019/ISIC_2019_Test_Input.zip" -TEST_LABELS_URL="https://isic-challenge-data.s3.amazonaws.com/2019/ISIC_2019_Test_GroundTruth.csv" - -# -------------------- Download with Progress + Fail -------------------- -download() { - local url=$1 - local out=$2 - echo "Downloading: $out" - curl -fL# "$url" -o "$out" || { - echo "❌ Failed to download $out" - exit 1 - } -} - -download "$TRAIN_IMAGES_URL" ISIC_2019_Training_Input.zip -download "$TRAIN_LABELS_URL" ISIC_2019_Training_GroundTruth.csv -download "$TEST_IMAGES_URL" ISIC_2019_Test_Input.zip -download "$TEST_LABELS_URL" ISIC_2019_Test_GroundTruth.csv - -# -------------------- Checksum (Optional) -------------------- -echo "Generating SHA256 checksums..." -sha256sum ISIC_2019_Training_Input.zip \ - ISIC_2019_Training_GroundTruth.csv \ - ISIC_2019_Test_Input.zip \ - ISIC_2019_Test_GroundTruth.csv > "$CHECKSUM_FILE" - -echo "Verifying checksums..." -sha256sum -c "$CHECKSUM_FILE" || { - echo "❌ Checksum verification failed." - exit 1 -} - -# -------------------- Unpack and Organize -------------------- -echo "Unpacking training images..." -unzip -q ISIC_2019_Training_Input.zip -d "$TRAIN_DIR/images" -echo "Moving training labels..." -mv ISIC_2019_Training_GroundTruth.csv "$TRAIN_DIR/labels.csv" - -echo "Unpacking test images..." -unzip -q ISIC_2019_Test_Input.zip -d "$TEST_DIR/images" -echo "Moving test labels..." -mv ISIC_2019_Test_GroundTruth.csv "$TEST_DIR/labels.csv" - -# -------------------- Cleanup -------------------- -rm ISIC_2019_Training_Input.zip ISIC_2019_Test_Input.zip - -echo "✅ Dataset ready at: $BASE_DIR" diff --git a/jobs/job_template.slurm b/jobs/job_template.slurm deleted file mode 100644 index b5aae84..0000000 --- a/jobs/job_template.slurm +++ /dev/null @@ -1,54 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -#!/bin/bash -#SBATCH --job-name={{JOB_NAME}} -#SBATCH --partition={{PARTITION}} -#SBATCH --nodelist={{NODELIST}} -#SBATCH --gres=gpu:{{GPUS}} -#SBATCH --cpus-per-task={{CPUS_PER_TASK}} -#SBATCH --mem={{MEMORY}} -#SBATCH --time={{TIME_LIMIT}} -#SBATCH --output={{OUTPUT_FILE}} -#SBATCH --error={{ERROR_FILE}} - -set -e # Exit immediately on error - -# Set up project directories -export CODE_DIR={{CODE_DIR}} -export VENV_DIR=$CODE_DIR/venv -export PROJECT_ROOT={{PROJECT_ROOT}} - -ml load python/3.12.1 -# Create venv if not exists -if [ ! -d "$VENV_DIR" ]; then - echo "Creating virtual environment..." - python3 -m venv "$VENV_DIR" - source "$VENV_DIR/bin/activate" - pip install --upgrade pip - pip install -r "$CODE_DIR/requirements.txt" -else - echo "Using existing virtual environment" - source "$VENV_DIR/bin/activate" -fi - -# Create necessary output directories -mkdir -p $PROJECT_ROOT/{hf_cache/transformers,hf_cache/datasets,results,logs,models,plots} - -# Set environment variables for Hugging Face + outputs -export TRANSFORMERS_CACHE=$PROJECT_ROOT/hf_cache/transformers -export HF_DATASETS_CACHE=$PROJECT_ROOT/hf_cache/datasets -export HF_HOME=$PROJECT_ROOT/hf_cache -export TRAIN_OUTPUT_DIR=$PROJECT_ROOT/results -export LOG_DIR=$PROJECT_ROOT/logs -export MODEL_DIR=$PROJECT_ROOT/models -export PLOT_DIR=$PROJECT_ROOT/plots - -# Run training script -cd "$CODE_DIR" -echo "Starting training at $(date)" -python model_comparison_2.py -echo "Training complete at $(date)" diff --git a/jobs/run_compare_baseline.sh b/jobs/run_compare_baseline.sh deleted file mode 100644 index 749a0c2..0000000 --- a/jobs/run_compare_baseline.sh +++ /dev/null @@ -1,38 +0,0 @@ -# This source file is part of the ARPA-H CARE LLM project -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -#!/bin/bash -#SBATCH --job-name=compare_baseline -#SBATCH --output=logs/compare_baseline_%j.out -#SBATCH --error=logs/compare_baseline_%j.err -#SBATCH --partition=roxanad -#SBATCH --gres=gpu:1 -#SBATCH --time=12:00:00 -#SBATCH --mem=32G -#SBATCH --cpus-per-task=4 -#SBATCH --output=logs/%x_%j.out -#SBATCH --error=logs/%x_%j.err - -# Load Python module -ml load python/3.12.1 - -# Setup virtual environment if it doesn't exist -if [ ! -d ".venv" ]; then - python3.12 -m pip install uv - uv venv - source .venv/bin/activate - - # Install from requirements.txt - uv pip install -r requirements.txt -else - source .venv/bin/activate -fi - -# WANDB Key -export WANDB_API_KEY="7ab80eeb87ef06298c6bca1258208b1739ad32fe" - -# Running the script -python -m src.compressed_perception.models.comparison.compare_baseline --resolution 224 --batch_size 256 --num_train_images 25000 --num_epochs 10 --eval_steps 10 \ No newline at end of file diff --git a/jobs/run_compare_lr_sweep.sh b/jobs/run_compare_lr_sweep.sh deleted file mode 100644 index fd56acd..0000000 --- a/jobs/run_compare_lr_sweep.sh +++ /dev/null @@ -1,38 +0,0 @@ -# This source file is part of the ARPA-H CARE LLM project -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -#!/bin/bash -#SBATCH --job-name=compare_lr_sweep -#SBATCH --output=logs/compare_lr_sweep_%j.out -#SBATCH --error=logs/compare_lr_sweep_%j.err -#SBATCH --partition=roxanad -#SBATCH --gres=gpu:1 -#SBATCH --time=12:00:00 -#SBATCH --mem=32G -#SBATCH --cpus-per-task=4 -#SBATCH --output=logs/%x_%j.out -#SBATCH --error=logs/%x_%j.err - -# Load Python module -ml load python/3.12.1 - -# Setup virtual environment if it doesn't exist -if [ ! -d ".venv" ]; then - python3.12 -m pip install uv - uv venv - source .venv/bin/activate - - # Install from requirements.txt - uv pip install -r requirements.txt -else - source .venv/bin/activate -fi - -# WANDB Key -export WANDB_API_KEY="7ab80eeb87ef06298c6bca1258208b1739ad32fe" - -# Running the script -python src/models/comparison/compare_lr_sweep.py --resolution 224 --batch_size 256 --num_train_images 25000 --num_epochs 6 --eval_steps 10 --learning_rate 1e-4 \ No newline at end of file diff --git a/jobs/submit_from_config.sh b/jobs/submit_from_config.sh deleted file mode 100644 index d60b99d..0000000 --- a/jobs/submit_from_config.sh +++ /dev/null @@ -1,62 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -#!/bin/bash -set -e - -CONFIG_FILE="config.yaml" -TEMPLATE_FILE="job_template.slurm" -JOB_FILE="job_generated.slurm" - -if [ ! -f "$CONFIG_FILE" ] || [ ! -f "$TEMPLATE_FILE" ]; then - echo "Missing config or template" - exit 1 -fi - -# Function to extract config values using Python/YAML -read_config() { - python -c "import yaml; print(yaml.safe_load(open('$CONFIG_FILE')).get('$1', ''))" -} - -# Mapping of placeholders to config keys -declare -A config_map=( - ["JOB_NAME"]="job_name" - ["OUTPUT_FILE"]="output_file" - ["ERROR_FILE"]="error_file" - ["TIME_LIMIT"]="time_limit" - ["GPUS"]="gpus" - ["PARTITION"]="partition" - ["NODELIST"]="nodelist" - ["CPUS_PER_TASK"]="cpus_per_task" - ["MEMORY"]="memory" - ["CODE_DIR"]="code_dir" - ["PROJECT_ROOT"]="project_root" -) - -# Copy template to target job file -cp "$TEMPLATE_FILE" "$JOB_FILE" - -# Substitute placeholders with config values -for placeholder in "${!config_map[@]}"; do - value=$(read_config "${config_map[$placeholder]}") - if [ -z "$value" ]; then - echo "Warning: Config key '${config_map[$placeholder]}' is empty or missing" - fi - - # Escape characters that might break sed (like slashes or ampersands) - value_escaped=$(printf '%s\n' "$value" | sed 's/[&/\]/\\&/g') - - # Detect OS and use correct sed syntax - if [[ "$OSTYPE" == "darwin"* ]]; then - sed -i '' "s|{{${placeholder}}}|${value}|g" "$JOB_FILE" - else - sed -i "s|{{${placeholder}}}|${value}|g" "$JOB_FILE" - fi -done - -# Submit the job -chmod +x "$JOB_FILE" -sbatch "$JOB_FILE" diff --git a/scripts/test_transforms.py b/scripts/test_transforms.py new file mode 100644 index 0000000..933b2a3 --- /dev/null +++ b/scripts/test_transforms.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Minimal check: dataset loads and a degradation transform is applied to 5 images. +- Uses your BaseDataModule and get_degradation_transforms() +- No training, no W&B. +- Saves clean and degraded PNGs. +""" + +from types import SimpleNamespace +from pathlib import Path +import argparse + +from torchvision.utils import save_image +import torchvision.transforms.functional as TF +from PIL import Image + +from src.data.datamodule import BaseDataModule +from src.transformation.transforms import get_degradation_transforms +from src.utils import setup_environment + + +def main(): + ap = argparse.ArgumentParser(description="Quick dataset+transform smoke test (5 images).") + ap.add_argument("--dataset", type=str, required=True, choices=["isic", "tcga", "merlin"]) + ap.add_argument("--data_dir", type=str, required=True) + ap.add_argument("--out_dir", type=str, default="outputs/test_transforms") + ap.add_argument("--resolution", type=int, default=224) + ap.add_argument("--num_workers", type=int, default=2) + args = ap.parse_args() + + setup_environment() + + out_root = Path(args.out_dir) + out_root.mkdir(parents=True, exist_ok=True) + + # ---- minimal cfg your DataModule/get_dataset can read ---- + cfg = SimpleNamespace( + resolution=args.resolution, + batch_size=5, # just enough to grab 5 samples + ) + + # ---- build the DataModule exactly as your code expects ---- + dm = BaseDataModule( + cfg=cfg, + dataset_name=args.dataset, + data_dir=args.data_dir, + num_workers=args.num_workers, + batch_size=cfg.batch_size, + pin_memory=True, + drop_last=False, + ) + dm.setup(stage="fit") + + # ---- pull one small batch (5 images) from val ---- + val_loader = dm.val_dataloader() + batch = next(iter(val_loader)) + + if isinstance(batch, dict): + x = batch.get("pixel_values") or batch.get("images") or batch.get("x") + y = batch.get("labels") or batch.get("y") + else: + x, y = batch + + # keep exactly 5 + x = x[:5] + + for i in range(x.size(0)): + save_image(x[i], out_root / f"clean_{i}.png") + + degradations = get_degradation_transforms() + if not degradations: + print("No degradations returned by get_degradation_transforms(); saved only clean images.") + return + + degr = degradations[0] + tag = degr.__class__.__name__.lower() + print(f"Applying degradation: {tag}") + + for i in range(x.size(0)): + pil_img = TF.to_pil_image(x[i].cpu()) + pil_out = degr(pil_img) + t_out = TF.to_tensor(pil_out) + save_image(t_out, out_root / f"{tag}_{i}.png") + + print(f"✅ Done. Wrote images to: {out_root.resolve()}") + + +if __name__ == "__main__": + main() diff --git a/scripts/visualize_isic_results.py b/scripts/visualize_isic_results.py deleted file mode 100644 index 24c6eeb..0000000 --- a/scripts/visualize_isic_results.py +++ /dev/null @@ -1,236 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -""" -Evaluates results from ISIC 2019 fine-tuning and linear probing experiments. -Conducts paired t-tests to assess statistical significance across resolutions -and model variations. For single runs, skips t-tests and -summarizes perform ance. Generates plots and saves results to JSON. -""" -# pylint: disable=broad-exception-caught - -import json -import itertools -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import seaborn as sns -from scipy.stats import ttest_rel - - -def load_results( - finetune_file="results_metrics_finetune.json", - linear_probe_file="results_metrics_linear_probe.json", -): - """ - Loads fine-tuning and linear probing results from JSON files, supporting - multiple runs or single run per condition. - """ - results = {} - for file, mode in [ - (finetune_file, "finetune"), - (linear_probe_file, "linear_probe"), - ]: - if not Path(file).exists(): - print(f"Warning: {file} not found, skipping {mode} results.") - continue - with open(file, "r", encoding="utf-8") as f: - data = json.load(f) - for model, qualities in data.items(): - for quality, metrics in qualities.items(): - key = (model, int(quality), mode) - eval_metrics = metrics["eval_metrics"] - # Handle single run (dict) or multiple runs (list of dicts) - if isinstance(eval_metrics, dict): - results[key] = [eval_metrics] # Wrap single run in list - elif isinstance(eval_metrics, list): - results[key] = eval_metrics - else: - print(f"Warning: Invalid eval_metrics format for {key}, skipping.") - continue - return results - - -def paired_t_tests(metrics, models, qualities, modes): - """ - Conducts paired t-tests across models and modes if multiple - runs are available. - """ - t_test_results = [] - metric_names = ["accuracy", "f1", "auc"] - - def run_if_valid(k1, k2, label): - if k1 in metrics and k2 in metrics and len(metrics[k1]) == len(metrics[k2]) > 1: - for metric in metric_names: - try: - v1 = [r[metric] for r in metrics[k1]] - v2 = [r[metric] for r in metrics[k2]] - stat, p = ttest_rel(v1, v2) - t_test_results.append( - { - "comparison": label.format(metric=metric), - "statistic": stat, - "p_value": p, - } - ) - except Exception as e: - print(f"Error in t-test for {label.format(metric=metric)}: {e}") - - has_multiple_runs = any(len(r) > 1 for r in metrics.values()) - if not has_multiple_runs: - print("Warning: Only single-run metrics available. Skipping t-tests.") - return t_test_results - - # 1. Within-model: JPEG quality comparisons - for model, mode in itertools.product(models, modes): - for q1, q2 in itertools.combinations(qualities, 2): - run_if_valid( - (model, q1, mode), - (model, q2, mode), - f"{model}_{mode}" + "_{{metric}}_jpeg{q1}_vs_jpeg{q2}", - ) - - # 2. Across models: same quality & mode - for m1, m2 in itertools.combinations(models, 2): - for quality, mode in itertools.product(qualities, modes): - run_if_valid( - (m1, quality, mode), - (m2, quality, mode), - f"{m1}_vs_{m2}_{mode}" + "_{{metric}}_jpeg{quality}", - ) - - # 3. Finetune vs. Linear Probe: same model & quality - for model, quality in itertools.product(models, qualities): - run_if_valid( - (model, quality, "finetune"), - (model, quality, "linear_probe"), - f"{model}_finetune_vs_linear_probe" + "_{{metric}}_jpeg{quality}", - ) - - return t_test_results - - -def summarize_performance(metrics): - """ - Summarizes performance metrics (mean, std) across models, qualities, and modes. - """ - summary = [] - for (model, _, mode), runs in metrics.items(): - if not runs: - continue - - # Compute mean and std for each metric - accuracies = [run.get("accuracy") for run in runs if "accuracy" in run] - f1s = [run.get("f1") for run in runs if "f1" in run] - aucs = [run.get("auc") for run in runs if "auc" in run] - - summary.append( - { - "model": model, - "mode": mode, - "accuracy_mean": np.mean(accuracies), - "accuracy_std": np.std(accuracies, ddof=1) if len(runs) > 1 else 0, - "f1_mean": np.mean(f1s), - "f1_std": np.std(f1s, ddof=1) if len(runs) > 1 else 0, - "auc_mean": np.mean(aucs), - "auc_std": np.std(aucs, ddof=1) if len(runs) > 1 else 0, - } - ) - return pd.DataFrame(summary) - - -def plot_performance(df): - """Generates plots for performance metrics across models.""" - metrics = ["accuracy", "f1", "auc"] - modes = ["finetune", "linear_probe"] - - for metric in metrics: - plt.figure(figsize=(12, 6)) - for mode in modes: - subset = df[df["mode"] == mode] - sns.lineplot( - x="jpeg_quality", - y=f"{metric}_mean", - hue="model", - style="model", - markers=True, - dashes=False, - data=subset, - label=f"{mode}", - ) - plt.title(f"{metric.capitalize()} vs. JPEG Quality") - plt.xlabel("JPEG Quality") - plt.ylabel(f"{metric.capitalize()} Mean") - plt.legend() - plt.savefig(f"{metric}_vs_jpeg_quality.png", dpi=300, bbox_inches="tight") - plt.close() - - # Bar plot for model comparison - for mode in modes: - plt.figure(figsize=(12, 6)) - subset = df[df["mode"] == mode] - subset_melted = subset.melt( - id_vars=["model", "jpeg_quality", "mode"], - value_vars=["accuracy_mean", "f1_mean", "auc_mean"], - var_name="metric", - value_name="value", - ) - sns.barplot( - x="jpeg_quality", - y="value", - hue="model", - style="metric", - data=subset_melted, - ) - plt.title(f"Performance Metrics ({mode.capitalize()})") - plt.xlabel("JPEG Quality") - plt.ylabel("Metric Value") - plt.legend() - plt.savefig(f"metrics_{mode}_bar.png", dpi=300, bbox_inches="tight") - plt.close() - - -def main(): - """Main function to evaluate results and conduct statistical tests.""" - # Load results - metrics = load_results() - - if not metrics: - print("No results found. Please ensure JSON files exist.") - return - - # Conduct paired t-tests (if multiple runs) - # Dynamic model/quality/mode extraction - models = sorted(set(k[0] for k in metrics)) - qualities = sorted(set(k[1] for k in metrics)) - modes = sorted(set(k[2] for k in metrics)) - t_test_results = paired_t_tests(metrics, models, qualities, modes) - - # Summarize performance - performance_df = summarize_performance(metrics) - - # Generate plots - plot_performance(performance_df) - - # Save results - results = { - "t_test_results": t_test_results, - "performance_summary": performance_df.to_dict(orient="records"), - } - with open("evaluation_results.json", "w", encoding="utf-8") as f: - json.dump(results, f, indent=4) - - print("Evaluation complete. Results saved to 'evaluation_results.json'.") - print( - "Plots saved: accuracy_vs_jpeg_quality.png, f1_vs_jpeg_quality.png, " - "auc_vs_jpeg_quality.png, metrics_finetune_bar.png, metrics_linear_probe_bar.png" - ) - - -if __name__ == "__main__": - main() diff --git a/src/__init__.py b/src/__init__.py index e69de29..1a66f1d 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -0,0 +1,3 @@ +""" +Root package for the project. +""" diff --git a/src/cli/train.py b/src/cli/train.py new file mode 100644 index 0000000..5fe1460 --- /dev/null +++ b/src/cli/train.py @@ -0,0 +1,139 @@ +# src/cli/train.py +import os +import sys +import time +import json +import random +import logging +from pathlib import Path +from typing import Dict, Any + +import torch +import numpy as np +import hydra +from omegaconf import DictConfig, OmegaConf + +# ---- Optional: import wrappers (these implement run(cfg) and return a dict of metrics) +from src.wrappers import probe as probe_wrapper +from src.wrappers import finetune as finetune_wrapper +from src.wrappers import distill as distill_wrapper + +log = logging.getLogger("train") + + +def _is_rank_zero() -> bool: + # Works for both torchrun (DDP) and single process + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + return local_rank == 0 + + +def _select_device(cfg: DictConfig) -> torch.device: + # Honor cfg.train.device if present; otherwise auto + if "device" in cfg.train and cfg.train.device: + dev = cfg.train.device + return torch.device(dev) + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _seed_everything(seed: int, deterministic: bool = False): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: + torch.backends.cudnn.benchmark = True + + +def _save_resolved_config(cfg: DictConfig, run_dir: Path): + if not _is_rank_zero(): + return + run_dir.mkdir(parents=True, exist_ok=True) + with open(run_dir / "resolved_config.yaml", "w") as f: + OmegaConf.save(config=cfg, f=f.name) + + +def _print_run_header(cfg: DictConfig, run_dir: Path, device: torch.device): + if not _is_rank_zero(): + return + banner = ( + f"\n=== TRAIN START ===\n" + f"mode : {cfg.train.mode}\n" + f"dataset : {cfg.dataset.name}\n" + f"model : {getattr(cfg.model, 'name', 'N/A')}\n" + f"device : {device}\n" + f"seed : {cfg.seed}\n" + f"run_dir : {str(run_dir)}\n" + f"===================\n" + ) + print(banner, flush=True) + + +def _dispatch_wrapper(cfg: DictConfig) -> Dict[str, Any]: + mode = cfg.train.mode.lower() + if mode == "probe": + return probe_wrapper.run(cfg) + elif mode == "finetune": + return finetune_wrapper.run(cfg) + elif mode == "distill": + return distill_wrapper.run(cfg) + else: + raise ValueError( + f"Unknown train.mode='{cfg.train.mode}'. " + f"Expected one of: probe | finetune | distill" + ) + + +@hydra.main(config_path="../../configs", config_name="defaults", version_base=None) +def main(cfg: DictConfig): + # ---- Resolve and freeze config for safety + OmegaConf.set_struct(cfg, False) # allow wrappers to attach runtime fields if needed + + # ---- Set up run directory (Hydra changes CWD to a unique run dir automatically) + run_dir = Path(os.getcwd()) + + # ---- Device & seeding + device = _select_device(cfg) + _seed_everything(seed=int(cfg.seed), deterministic=getattr(cfg.train, "deterministic", False)) + + # ---- Optional: attach runtime info for wrappers/engines + cfg.runtime = { + "device": str(device), + "start_time": time.strftime("%Y-%m-%d %H:%M:%S"), + "run_dir": str(run_dir), + "rank_zero": _is_rank_zero(), + "world_size": int(os.environ.get("WORLD_SIZE", "1")), + } + + # ---- Save resolved config + _save_resolved_config(cfg, run_dir) + _print_run_header(cfg, run_dir, device) + + # ---- Kick off the selected training paradigm via wrapper + try: + metrics = _dispatch_wrapper(cfg) + except KeyboardInterrupt: + if _is_rank_zero(): + print("\n⚠️ Training interrupted by user.", flush=True) + raise + except Exception as e: + # Surface a readable error at rank zero; still propagate for proper exit codes + if _is_rank_zero(): + print(f"\n❌ Training failed: {e}\n", flush=True) + raise + + # ---- Persist final metrics + if _is_rank_zero(): + metrics = metrics or {} + with open(run_dir / "final_metrics.json", "w") as f: + json.dump(metrics, f, indent=2) + print("✅ Train done. Final metrics written to final_metrics.json\n", flush=True) + + +if __name__ == "__main__": + # Support `python -m src.cli.train train=probe dataset=isic2019` + sys.exit(main()) diff --git a/src/compressed_perception/models/comparison/compare_baseline.py b/src/compressed_perception/models/comparison/compare_baseline.py deleted file mode 100644 index 4831da5..0000000 --- a/src/compressed_perception/models/comparison/compare_baseline.py +++ /dev/null @@ -1,339 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -# pylint: disable=invalid-name -""" -This script is a baseline for comparing different image classification models -at three different image compression levels, in comparison to the original. -It supports fine-tuning, linear probing, and optional degradation transforms. - -Purpose: General pipeline for comparing image classification models (ViT, DINOv2, SimCLR) with -optional image degradation transforms. -Features: -- Supports fine-tuning and linear probing. -- Balances dataset across classes. -- Applies optional transforms (JPEG, blur, quantization). -- Uses a fixed model list (currently ViT). -- No command-line argument parsing; hyperparameters are set in the main() function. -- Designed for baseline model comparison across compression levels. -""" - -# Environment Setup -import os -import argparse -from PIL import Image -from transformers import ( - AutoImageProcessor, - AutoModelForImageClassification, - ViTFeatureExtractor, - ViTForImageClassification, - Trainer, - TrainingArguments, -) -from datasets import load_dataset -import wandb -import timm - - -from src.compressed_perception.models.training.constants import ( - SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES -) -from src.compressed_perception.models.training.utils_classes import SimCLRForClassification -from src.compressed_perception.models.training.utils_methods import ( - get_gpu_memory, GPU_AVAILABLE, freeze_backbone -) -from src.compressed_perception.modules.data_preparation.preparation import ( - filter_and_cast_dataset, - balance_dataset, - split_dataset, - get_default_transforms, - prepare_datasets, -) - -# Compatibility for LANCZOS resampling -try: - LANCZOS = Image.Resampling.LANCZOS -except AttributeError: - LANCZOS = Image.LANCZOS # pylint: disable=no-member - -# GPU Memory Monitoring (optional) -try: - import pynvml - pynvml.nvmlInit() - GPU_AVAILABLE = True -except ImportError: - GPU_AVAILABLE = False - print("pynvml not installed, GPU memory monitoring disabled.") - - -def initialize_model_and_preprocessor(model_info, resolution): - """ - Initialize model and preprocessor. - """ - _, model_id, typ, _ = ( - model_info["name"], - model_info["model_id"], - model_info["type"], - model_info["config"], - ) - - if typ == "vit": - preprocessor = ViTFeatureExtractor.from_pretrained( - model_id, - size=resolution, - do_resize=True, - resample=LANCZOS, - do_normalize=True, - image_mean=[0.485, 0.456, 0.406], - image_std=[0.229, 0.224, 0.225], - ) - model = ViTForImageClassification.from_pretrained( - model_id, - num_labels=NUM_FILTERED_CLASSES, - ignore_mismatched_sizes=True, - image_size=resolution, - ) - elif typ == "dinov2": - preprocessor = AutoImageProcessor.from_pretrained( - model_id, - size=resolution, - do_resize=True, - resample=LANCZOS, - do_normalize=True, - image_mean=[0.485, 0.456, 0.406], - image_std=[0.229, 0.224, 0.225], - ) - model = AutoModelForImageClassification.from_pretrained( - model_id, - num_labels=NUM_FILTERED_CLASSES, - ignore_mismatched_sizes=True, - image_size=resolution, - ) - elif typ == SSL_MODEL: - backbone = timm.create_model( - SIMCLR_BACKBONE, - pretrained=True, - num_classes=0, - ) - model = SimCLRForClassification(backbone, NUM_FILTERED_CLASSES) - freeze_backbone(model, SSL_MODEL) - preprocessor = None - else: - raise ValueError(f"Unsupported model type: {typ}") - - return model, preprocessor - - -def train_model( - model, - train_ds, - val_ds, - config -): - """ - Train the model using Hugging Face Trainer. - - Args: - model: The model to train. - train_ds: Training dataset. - val_ds: Validation dataset. - config (dict): Configuration dictionary with keys: - - model_name - - model_type - - resolution - - batch_size - - num_epochs - - learning_rate - - eval_steps - - wandb_config - - Returns: - dict: Evaluation results. - """ - wandb.init( - project="Model Comparison Baseline", - name=f"{config['model_name']}_{config['resolution']}_{config['num_epochs']}_epochs", - config=config['wandb_config'], - tags=["baseline", config['model_name'], f"res_{config['resolution']}"], - reinit=True, - ) - - train_args = TrainingArguments( - output_dir=f"./outputs/{config['model_name']}", - num_train_epochs=config['num_epochs'], - per_device_train_batch_size=config['batch_size'], - per_device_eval_batch_size=config['batch_size'], - learning_rate=config['learning_rate'], - lr_scheduler_type="cosine", - weight_decay=0.01, - logging_dir=f"./logs/{config['model_name']}", - logging_steps=1, - eval_strategy="steps", - eval_steps=config['eval_steps'], - save_strategy="steps", - save_steps=config['eval_steps'], - load_best_model_at_end=False, - metric_for_best_model="accuracy", - save_total_limit=1, - report_to=["wandb"], - ) - - trainer = Trainer( - model=model, - args=train_args, - train_dataset=train_ds, - eval_dataset=val_ds, - ) - - trainer.train() - eval_results = trainer.evaluate() - - wandb.finish() - return eval_results - - - -def main(config=None, dataset=None): - """ - Main pipeline for model comparison. - """ - if config is None: - config = { - "num_train_images": 100, - "proportion_per_transform": 0.2, - "resolution": 224, - "batch_size": 256, - "num_epochs": 3, - "eval_steps": 10, - "learning_rate": 1e-4, - "gpu_available": GPU_AVAILABLE, - } - - # Add weight_decay to config instead of creating a separate wandb_config - config["weight_decay"] = 0.01 - - # Define all model configurations - models = get_model_configs(config["resolution"]) - - if dataset is None: - raise ValueError("Dataset must be provided via `dataset` argument or CLI.") - - # Process dataset and create training/validation splits - train_ds, val_ds = process_dataset(dataset, config) - - # Train and evaluate each model - for model_info in models: - train_and_evaluate_model(model_info, train_ds, val_ds, config) - - -def get_model_configs(resolution): - """Return the model configurations for the baseline comparison.""" - MODEL_CONFIG_KEYS = { - "name": "name", - "model_id": "model_id", - "type": "type", - "config": "config", - } - - return [ - { - MODEL_CONFIG_KEYS["name"]: "vit", - MODEL_CONFIG_KEYS["model_id"]: "google/vit-base-patch16-224", - MODEL_CONFIG_KEYS["type"]: "vit", - MODEL_CONFIG_KEYS["config"]: { - "image_size": resolution, - "num_labels": NUM_FILTERED_CLASSES, - "ignore_mismatched_sizes": True - } - }, - ] - - -def process_dataset(dataset, config): - """Process the dataset and create training/validation splits.""" - dataset = filter_and_cast_dataset(dataset, FILTERED_CLASSES, NUM_FILTERED_CLASSES) - dataset = balance_dataset(dataset, config["num_train_images"], FILTERED_CLASSES) - splits = split_dataset(dataset, test_size=0.2, stratify_by_column="label", seed=42) - - # Get transforms - transform = get_default_transforms(config["resolution"], apply_transforms=True) - - # Prepare PyTorch datasets - train_ds, _ = prepare_datasets(splits["train"], transform, split_ratio=1.0) - val_ds, _ = prepare_datasets(splits["test"], transform, split_ratio=1.0) - - return train_ds, val_ds - - -def train_and_evaluate_model(model_info, train_ds, val_ds, config): - """Train and evaluate a single model, then log results.""" - model, _preprocessor = initialize_model_and_preprocessor(model_info, config["resolution"]) - - train_config = { - "model_name": model_info["name"], - "model_type": model_info["type"], - "resolution": config["resolution"], - "batch_size": config["batch_size"], - "num_epochs": config["num_epochs"], - "learning_rate": config["learning_rate"], - "eval_steps": config["eval_steps"], - "wandb_config": config, # Pass the entire config for wandb - } - - results = train_model(model, train_ds, val_ds, train_config) - print(f"Results for {model_info['name']}: {results}") - - # Define metric keys and log results to wandb - METRIC_KEYS = { - "learning_rate": "learning_rate", - "model_name": "model_name", - "model_type": "model_type", - "peak_memory_mb": "peak_memory_mb", - "flops_giga": "flops_giga", - "train_time_seconds": "train_time_seconds", - "eval_time_seconds": "eval_time_seconds", - "eval_metrics": "eval_metrics", - } - - metrics = { - METRIC_KEYS["learning_rate"]: config["learning_rate"], - METRIC_KEYS["model_name"]: model_info["name"], - METRIC_KEYS["model_type"]: model_info["type"], - METRIC_KEYS["peak_memory_mb"]: get_gpu_memory(), - METRIC_KEYS["flops_giga"]: None, - METRIC_KEYS["train_time_seconds"]: None, - METRIC_KEYS["eval_time_seconds"]: None, - METRIC_KEYS["eval_metrics"]: results, - } - wandb.log({"metrics": metrics}) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Baseline model comparison for image classification." - ) - parser.add_argument( - "--path_to_dataset", - type=str, - default=None, - help="Path to local dataset directory. If not provided, loads from Hugging Face.", - ) - args = parser.parse_args() - - loaded_dataset = None - if args.path_to_dataset: - loaded_dataset = load_dataset("imagefolder", data_dir=args.path_to_dataset, split="train") - else: - try: - loaded_dataset = load_dataset( - "MKZuziak/ISIC_2019_224", - cache_dir=os.environ["HF_DATASETS_CACHE"], - split="train", - ) - except Exception as e: - raise ValueError("No dataset provided and Hugging Face dataset failed to load.") from e - - main(dataset=loaded_dataset) diff --git a/src/compressed_perception/models/comparison/compare_lr_sweep.py b/src/compressed_perception/models/comparison/compare_lr_sweep.py deleted file mode 100644 index 35b7a53..0000000 --- a/src/compressed_perception/models/comparison/compare_lr_sweep.py +++ /dev/null @@ -1,456 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -# pylint: skip-file - -""" -This script is a baseline for comparing different image classification models -at three different image compression levels, in comparison to the original. -It has a set number of augmentation transforms and does NOT combine them. -This does NOT experiment on JPEG compression levels - - -Purpose: Specialized for learning rate experiments with image classification models. -Features: -- Focuses on learning rate sweeps for a single model (currently DINOv2). -- Uses command-line argument parsing for hyperparameters. -- Applies a set number of augmentation transforms (does not combine them). -- Does not experiment with JPEG compression levels. -- Logs results for each learning rate and saves them to a JSON file. -- More flexible for hyperparameter tuning and ablation studies. -""" - -# Environment Setup -import os - -# Standard Library -import json -import time -import argparse -import shutil - -# Scientific & Visualization Libraries -import numpy as np -from PIL import Image - -# PyTorch & Torchvision -import torch -from torch.utils.data import Subset, ConcatDataset -from torchvision import transforms - -# Hugging Face Transformers & Datasets -from transformers import ( - AutoImageProcessor, - AutoModelForImageClassification, - Trainer, - TrainingArguments, - ViTFeatureExtractor, - ViTForImageClassification, -) -from datasets import load_dataset - -# Weights & Biases -import wandb - -# Model Profiling & Vision Backbones -import timm - - -# Local Application Imports -from src.compressed_perception.models.training.constants import ( - HF_MODELS, SSL_MODEL, SIMCLR_BACKBONE, FILTERED_CLASSES, NUM_FILTERED_CLASSES -) -from src.compressed_perception.modules.data_transformation.image_transformation import ( - JPEGCompressionTransform, - GaussianBlurTransform, - ColorQuantizationTransform, -) -from src.compressed_perception.models.training.utils_classes import ( - ISICDataset, - SimCLRForClassification, - WandbCallback, - LossLoggerCallback, - get_trainer_callbacks -) -from src.compressed_perception.models.training.utils_methods import ( - env_path, - compute_metrics, - cleanup_model_dirs, - get_gpu_memory, - freeze_backbone, - GPU_AVAILABLE, - get_flops, - check_disk_space, - save_model_and_preprocessor, -) - -from src.compressed_perception.modules import ( - filter_and_cast_dataset, balance_dataset, split_dataset, get_default_transforms, prepare_datasets -) -# Set memory optimization -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - -# Cache paths -os.environ["TRANSFORMERS_CACHE"] = os.getenv( - "TRANSFORMERS_CACHE", "~/.cache/huggingface/transformers" -) -os.environ["HF_DATASETS_CACHE"] = os.getenv( - "HF_DATASETS_CACHE", "~/.cache/huggingface/datasets" -) -os.environ["HF_HOME"] = os.getenv("HF_HOME", "~/.cache/huggingface") - - -def create_preprocessors(model_config, config): - """ - Create preprocessors for each model type. - """ - preprocessors = {} - for model_info in [model_config]: - _name, model_id, typ, _config = ( - model_info["name"], - model_info["model_id"], - model_info["type"], - model_info["config"], - ) - if typ == "vit": - preprocessors[typ] = ViTFeatureExtractor.from_pretrained( - model_id, - size=config["resolution"], - do_resize=True, - resample=Image.LANCZOS, # pylint: disable=no-member - do_normalize=True, - image_mean=[0.485, 0.456, 0.406], - image_std=[0.229, 0.224, 0.225] - ) - elif typ == "dinov2": - preprocessors[typ] = AutoImageProcessor.from_pretrained( - model_id, - size=config["resolution"], - do_resize=True, - resample=Image.LANCZOS, # pylint: disable=no-member - do_normalize=True, - image_mean=[0.485, 0.456, 0.406], - image_std=[0.229, 0.224, 0.225] - ) - else: - preprocessors[typ] = None - - return preprocessors - -def train_for_learning_rate( # pylint: disable=too-many-locals - learning_rate, model_config, train_dataset, val_dataset, config -): - """ - Train and evaluate the model for a given learning rate. - """ - preprocessors = config["preprocessors"] - name, model_id, typ, _config = ( - model_config["name"], - model_config["model_id"], - model_config["type"], - model_config["config"], - ) - - train_ds = get_transformed_datasets(train_dataset, preprocessors, config, typ) - val_ds = ISICDataset( # pylint: disable=too-many-function-args - dataset=val_dataset, - transform=None, - resolution=config["resolution"], - model_type=typ, - ) - - model = get_model(typ, model_id, config) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - - flops = get_flops(model, config["resolution"]) - check_disk_space(min_gb=1) - cleanup_model_dirs(name, learning_rate) - - wandb.init( - entity="ericcui-use-stanford-university", - project="CS231N Test", - name=f"{name}_{config['resolution']}_lr_{learning_rate}", - config={ - "model_name": name, - "resolution": config["resolution"], - "batch_size": config["batch_size"], - "num_epochs": config["num_epochs"], - "eval_steps": config["eval_steps"], - "learning_rate": learning_rate, - "weight_decay": 0.01, - "gpu_available": GPU_AVAILABLE, - }, - tags=[ - "learning_rate_experiment", - f"lr_{learning_rate}", - f"resolution_{config['resolution']}"], - ) - - train_args = get_training_args(name, learning_rate, config) - callbacks = get_trainer_callbacks(name) - - trainer = Trainer( - model=model, - args=train_args, - train_dataset=train_ds, - eval_dataset=val_ds, - compute_metrics=lambda pred: compute_metrics(pred, name), - callbacks=callbacks, - ) - - start_time = time.time() - peak_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - - if typ in HF_MODELS: - wandb.watch(model, log="all", log_freq=100) - elif typ == SSL_MODEL: - wandb.watch(model.backbone, log="all", log_freq=100) - - trainer.train() - - current_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - peak_memory = max(peak_memory, current_memory) - - eval_start_time = time.time() - eval_results = trainer.evaluate() - eval_time = time.time() - eval_start_time - train_time = time.time() - start_time - eval_time - - metrics = { - "learning_rate": learning_rate, - "model_name": name, - "model_type": typ, - "peak_memory_mb": peak_memory, - "flops_giga": flops, - "train_time_seconds": train_time, - "eval_time_seconds": eval_time, - "eval_metrics": eval_results, - } - wandb.log(metrics) - - model_dir = save_model_and_preprocessor(model, preprocessors, typ, name, learning_rate) - log_wandb_artifact(model_dir, name, learning_rate) - - print(f"[Finetune] Learning Rate {learning_rate}: {metrics}") - - wandb.finish() - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return metrics - -def get_model(typ, model_id, config): - """ - Returns the initialized model based on type. - """ - if typ == "vit": - return ViTForImageClassification.from_pretrained( - model_id, - num_labels=NUM_FILTERED_CLASSES, - ignore_mismatched_sizes=True, - image_size=config["resolution"], - ) - if typ == "dinov2": - return AutoModelForImageClassification.from_pretrained( - model_id, - num_labels=NUM_FILTERED_CLASSES, - ignore_mismatched_sizes=True, - image_size=config["resolution"] - ) - if typ == SSL_MODEL: - backbone = timm.create_model( # pylint: disable=too-many-function-args - SIMCLR_BACKBONE, - pretrained=True, - num_classes=0 - ) - model = SimCLRForClassification(backbone, NUM_FILTERED_CLASSES) - freeze_backbone(model, SSL_MODEL) - return model - raise ValueError(f"Unknown model type: {typ}") - -def get_transformed_datasets(train_dataset, preprocessors, config, typ): - """ - Returns a concatenated dataset with optional degradation transforms applied. - """ - num_images = len(train_dataset) - images_per_transform = int(num_images * config["proportion_per_transform"]) - indices = np.random.permutation(num_images) - - transforms_list = [ - JPEGCompressionTransform(), - GaussianBlurTransform(), - ColorQuantizationTransform(), - ] - - def make_subset(indices_subset, transform=None): - subset = Subset(train_dataset, indices_subset) - transform_compose = transforms.Compose([transform]) if transform else None - return ISICDataset( # pylint: disable=too-many-function-args - subset, - preprocessors[typ], - config["resolution"], - transform_compose, - typ - ) - - datasets = [] - used_indices = set() - - for i, transform in enumerate(transforms_list): - idx = indices[i * images_per_transform : (i + 1) * images_per_transform] - used_indices.update(idx) - datasets.append(make_subset(idx, transform)) - - remaining = np.setdiff1d(indices, list(used_indices)) - if len(remaining) > 0: - datasets.append(make_subset(remaining)) - - return ConcatDataset(datasets) - -def log_wandb_artifact(model_dir, name, learning_rate): - """ - Logs the saved model directory as a wandb artifact. - """ - artifact = wandb.Artifact( - name=f"{name}_lr_{learning_rate}_model", - type="model", - description=f"Trained {name} model with {learning_rate} learning rate" - ) - artifact.add_dir(model_dir) - wandb.log_artifact(artifact) - -def get_training_args(name, learning_rate, config): - """ - Returns a TrainingArguments object for Hugging Face Trainer. - """ - return TrainingArguments( - output_dir=os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}_lr_{learning_rate}"), - num_train_epochs=config["num_epochs"], - per_device_train_batch_size=config["batch_size"], - per_device_eval_batch_size=config["batch_size"], - learning_rate=learning_rate, - lr_scheduler_type="cosine", - weight_decay=0.01, - logging_dir=os.path.join(env_path("LOG_DIR", "."), f"{name}_lr_{learning_rate}"), - logging_steps=1, - evaluation_strategy="steps", - eval_steps=config["eval_steps"], - save_strategy="steps", - save_steps=config["eval_steps"], - load_best_model_at_end=False, - metric_for_best_model="accuracy", - save_total_limit=1, - save_safetensors=False, - hub_model_id=None, - hub_strategy="end", - push_to_hub=False, - save_only_model=True, - ) - -def main(config=None): - """ - Main function for running learning rate sweep experiments on image classification models. - - Args: - config (dict): Configuration dictionary. - """ - if config is None: - config = { - "num_train_images": 25000, - "proportion_per_transform": 0.2, - "resolution": 224, - "batch_size": 256, - "num_epochs": 3, - "eval_steps": 10, - "learning_rate": 1e-4, - } - - model_config = { - "name": "dinov2", - "model_id": "facebook/dinov2-base", - "type": "dinov2", - "config": { - "image_size": config["resolution"], - "num_labels": NUM_FILTERED_CLASSES, - "ignore_mismatched_sizes": True - } - } - - learning_rates = [config["learning_rate"]] - results = {} - - dataset = load_dataset( - "MKZuziak/ISIC_2019_224", - cache_dir=os.environ["HF_DATASETS_CACHE"], - split="train", - ) - - filtered_dataset = filter_and_cast_dataset(dataset, FILTERED_CLASSES, NUM_FILTERED_CLASSES) - balanced_dataset = balance_dataset(filtered_dataset, config["num_train_images"], FILTERED_CLASSES) - splits = split_dataset(balanced_dataset) - transform = get_default_transforms(resolution=224, apply_transforms=True) - train_dataset, val_dataset = prepare_datasets(splits["train"], transform) - - preprocessors = create_preprocessors(model_config, config) - config["preprocessors"] = preprocessors - - for lr in learning_rates: - results[str(lr)] = train_for_learning_rate( - lr, model_config, train_dataset, val_dataset, config - ) - - # Save results - with open( - os.path.join( - env_path("TRAIN_OUTPUT_DIR", "."), "results_metrics_lr_experiment.json" - ), - "w", - encoding="utf-8" - ) as f: - json.dump(results, f, indent=4) - - -if __name__ == "__main__": - # Entry point for running the learning rate sweep experiment from the command line. - # Parses command-line arguments and calls main(). - parser = argparse.ArgumentParser( - description="Learning rate experiment for image classification." - ) - parser.add_argument( - '--resolution', type=int, default=224, - help='Input image resolution (default: 224)' - ) - parser.add_argument( - '--batch_size', type=int, default=128, - help='Batch size for training and evaluation (default: 128)' - ) - parser.add_argument( - '--num_train_images', type=int, default=500, - help='Number of training images to use per class (default: 500)' - ) - parser.add_argument( - '--num_epochs', type=int, default=3, - help='Number of training epochs (default: 3)' - ) - parser.add_argument( - '--eval_steps', type=int, default=100, - help='Number of steps between evaluations (default: 100)' - ) - parser.add_argument( - '--learning_rate', type=float, default=1e-4, - help='Learning rate (default: 1e-4)' - ) - args = parser.parse_args() - main({ - "resolution": args.resolution, - "batch_size": args.batch_size, - "num_train_images": args.num_train_images, - "num_epochs": args.num_epochs, - "eval_steps": args.eval_steps, - "learning_rate": args.learning_rate, - }) diff --git a/src/compressed_perception/models/training/constants.py b/src/compressed_perception/models/training/constants.py deleted file mode 100644 index e745410..0000000 --- a/src/compressed_perception/models/training/constants.py +++ /dev/null @@ -1,16 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -"""Configuration constants for model training and evaluation.""" - -HF_MODELS = ["vit", "dinov2"] -SSL_MODEL = "simclr" -SIMCLR_BACKBONE = "resnet50" -NUM_CLASSES = 8 - -# Filter constants -FILTERED_CLASSES = ["0", "1"] # Classes to use after filtering -NUM_FILTERED_CLASSES = len(FILTERED_CLASSES) # Number of classes after filtering diff --git a/src/compressed_perception/models/training/utils_classes.py b/src/compressed_perception/models/training/utils_classes.py deleted file mode 100644 index 1c0ade4..0000000 --- a/src/compressed_perception/models/training/utils_classes.py +++ /dev/null @@ -1,239 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -""" -Utility classes for model training and data handling. -""" - -import os -import json -import numpy as np -import torch -from torch import nn -from torch.utils.data import Dataset -from PIL import Image -from torchvision import transforms -from transformers import TrainerCallback -import wandb - -# Local imports -from src.compressed_perception.models.training.constants import ( - HF_MODELS, NUM_FILTERED_CLASSES, SSL_MODEL -) -from src.compressed_perception.models.training.utils_methods import ( - GPU_AVAILABLE, get_gpu_memory - ) -from src.compressed_perception.modules.data_transformation.image_transformation import ( - JPEGCompressionTransform -) - -# Compatibility for LANCZOS resampling -try: - LANCZOS = Image.Resampling.LANCZOS -except AttributeError: - LANCZOS = Image.LANCZOS # pylint: disable=no-member - - -class WandbCallback(TrainerCallback): - """ - Custom callback for logging metrics and evaluation results to Weights & Biases. - Tracks best accuracy and GPU memory usage if available. - """ - MODEL_KEY = "model" - PHASE_KEY = "phase" - BEST_ACCURACY_KEY = "best_accuracy" - EVAL_ACCURACY_KEY = "eval_accuracy" - GPU_MEMORY_KEY = "gpu_memory_mb" - - def __init__(self, model_name, phase): - self.model_name = model_name - self.phase = phase - self.best_accuracy = 0.0 - - def on_log(self, _args, _state, _control, logs=None, **_kwargs): - if logs is not None: - logs[self.MODEL_KEY] = self.model_name - logs[self.PHASE_KEY] = self.phase - try: - if GPU_AVAILABLE: - logs[self.GPU_MEMORY_KEY] = get_gpu_memory() - except ImportError: - pass - wandb.log(logs) - - def on_evaluate(self, _args, _state, _control, metrics=None, **_kwargs): - if metrics is not None: - if self.EVAL_ACCURACY_KEY in metrics: - self.best_accuracy = max(self.best_accuracy, metrics[self.EVAL_ACCURACY_KEY]) - metrics[self.BEST_ACCURACY_KEY] = self.best_accuracy - wandb.log(metrics) - -def get_trainer_callbacks(name): - """Get callbacks for the Trainer.""" - return [ - LossLoggerCallback( - log_dir=os.environ.get("LOG_DIR", "./logs"), - phase="finetune", - model_name=name, - ), - WandbCallback(name, "finetune"), - ] - -class ISICDataset(Dataset): - """ - Dataset class for handling ISIC image data with optional transformations. - """ - def __init__(self, dataset, config=None): - """ - Args: - dataset: The dataset to load. - config (dict, optional): Configuration dictionary with keys: - - preprocessor - - resolution - - transform - - model_type - - jpeg_quality - """ - self.dataset = dataset - config = config or {} - self.preprocessor = config.get("preprocessor", None) - self.resolution = config.get("resolution", 224) - self.transform = config.get("transform", None) - self.model_type = config.get("model_type", "vit") - self.jpeg_quality = config.get("jpeg_quality", None) - - # Base preprocessing pipeline for resizing and tensor conversion - self.base_preprocessor = transforms.Compose([ - transforms.Resize((self.resolution, self.resolution), LANCZOS), - transforms.ToTensor(), - ]) - - # Preprocessor for SSL models - if self.model_type == SSL_MODEL: - self.preprocessor = transforms.Compose([ - transforms.Resize((self.resolution, self.resolution), LANCZOS), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]) - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - # Convert numpy.int64 to Python int if necessary - if isinstance(idx, (np.integer, np.int64)): - idx = int(idx) - - # Handle both direct dataset access and Subset access - if hasattr(self.dataset, 'dataset'): - # This is a Subset - subset_idx = int(self.dataset.indices[idx]) - item = self.dataset.dataset[subset_idx] - else: - # This is a direct dataset - item = self.dataset[idx] - - image = item["image"] - label = item["label"] - - # Always resize to target resolution first - image = image.resize((self.resolution, self.resolution), LANCZOS) - - # Apply additional transformations if provided - if self.transform: - image = self.transform(image) - - # Apply JPEG compression if specified - if self.jpeg_quality is not None: - image = JPEGCompressionTransform(self.jpeg_quality)(image) - - # Preprocessing for Hugging Face models - if self.model_type in HF_MODELS: - # Ensure the preprocessor doesn't resize again - if hasattr(self.preprocessor, 'size'): - self.preprocessor.size = self.resolution - encoding = self.preprocessor(images=image, return_tensors="pt") - pixel_values = encoding["pixel_values"].squeeze(0) - elif self.model_type == SSL_MODEL: - pixel_values = self.preprocessor(image) - else: - raise ValueError(f"Unsupported model_type: {self.model_type}") - - label = torch.tensor(label, dtype=torch.long) - return {"pixel_values": pixel_values, "labels": label} - - -class SimCLRForClassification(nn.Module): # pylint: disable=too-few-public-methods - """ - SimCLR-based classification model. - """ - def __init__(self, backbone, num_classes=NUM_FILTERED_CLASSES): - """ - SimCLR-based classification model. - - Args: - backbone: The backbone model (e.g., ResNet). - num_classes: Number of output classes. - """ - super().__init__() - self.backbone = backbone - self.classifier = nn.Linear(2048, num_classes) - - def forward(self, pixel_values, labels=None): - """ - Forward pass for the model. - - Args: - pixel_values: Input image tensors. - labels: Ground truth labels (optional). - - Returns: - dict: Dictionary containing logits and loss (if labels are provided). - """ - features = self.backbone(pixel_values) - logits = self.classifier(features) - loss = None - if labels is not None: - loss = nn.CrossEntropyLoss()(logits, labels) - return ( - {"logits": logits, "loss": loss} if loss is not None else {"logits": logits} - ) - - -class LossLoggerCallback(TrainerCallback): # pylint: disable=too-few-public-methods - """ - Logs each training step's loss and other metrics to a structured JSON Lines file. - """ - - def __init__(self, log_dir: str, phase: str, model_name: str): - """ - Initialize the callback. - - Args: - log_dir: Directory to save the log file. - phase: Training phase (e.g., "finetune"). - model_name: Name of the model. - """ - os.makedirs(log_dir, exist_ok=True) - self.log_file = os.path.join( - log_dir, f"{model_name}_{phase}_log.jsonl" - ) - - def on_log(self, _args, state, _control, logs=None, **_kwargs): - """ - Log metrics to a JSON Lines file. - - Args: - args: Training arguments. - state: Trainer state. - control: Trainer control. - logs: Metrics to log. - """ - if logs is None: - return - with open(self.log_file, "a", encoding="utf-8") as f: - json.dump({"step": state.global_step, **logs}, f) - f.write("\n") diff --git a/src/compressed_perception/models/training/utils_methods.py b/src/compressed_perception/models/training/utils_methods.py deleted file mode 100644 index c716ea4..0000000 --- a/src/compressed_perception/models/training/utils_methods.py +++ /dev/null @@ -1,180 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -"""Utility methods for model training and evaluation. -This module provides functions for computing evaluation metrics, managing GPU memory, -freezing model backbones, and handling environment paths.""" - -import os -import shutil -import json -import numpy as np -from thop import profile -import torch -import matplotlib.pyplot as plt -import seaborn as sns -from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix -import pynvml - -# Local imports -from src.compressed_perception.models.training.constants import HF_MODELS - -# Constants -GPU_AVAILABLE = torch.cuda.is_available() - - -def env_path(key, default): - """Get environment variable or default value.""" - return os.environ.get(key, default) - - -def compute_metrics(eval_pred, model_name): - """ - Compute evaluation metrics from model predictions. - """ - logits, labels = eval_pred - predictions = np.argmax(logits, axis=-1) - acc = accuracy_score(labels, predictions) - f1 = f1_score(labels, predictions, average="weighted") - - # For binary classification, use the probability of the positive class - probs = torch.softmax(torch.tensor(logits), dim=1).numpy() - # Use the probability of class 1 (positive class) for ROC AUC - auc = roc_auc_score(labels, probs[:, 1]) - - plot_dir = os.path.join( - env_path("PLOT_DIR", "."), model_name - ) - os.makedirs(plot_dir, exist_ok=True) - - conf_mat = confusion_matrix(labels, predictions) - plt.figure(figsize=(10, 10)) - sns.heatmap(conf_mat, annot=True, cmap="Blues") - plt.xlabel("Predicted labels") - plt.ylabel("True labels") - plt.title(f"{model_name}_conf_mat") - plt.savefig(os.path.join(plot_dir, "conf_mat.png"), dpi=300, bbox_inches="tight") - plt.close() - - unique, counts = np.unique(predictions, return_counts=True) - class_breakdown = {str(k): int(v) for k, v in zip(unique, counts)} - with open(os.path.join(plot_dir, "class_breakdown.json"), "w", encoding="utf-8") as f: - json.dump(class_breakdown, f) - - return {"accuracy": acc, "f1": f1, "auc": auc} - -def cleanup_model_dirs(name, learning_rate): - """ - Removes and recreates model/log directories for the current run. - """ - model_dirs = [ - os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}_lr_{learning_rate}"), - os.path.join(env_path("MODEL_DIR", "."), f"{name}_lr_{learning_rate}"), - os.path.join(env_path("LOG_DIR", "."), f"{name}_lr_{learning_rate}"), - ] - for dir_path in model_dirs: - if os.path.exists(dir_path): - print(f"Cleaning up directory: {dir_path}") - shutil.rmtree(dir_path) - os.makedirs(dir_path, exist_ok=True) - -def get_gpu_memory(device_id=0): - """ - Get the used GPU memory in MB for a specific device. - """ - if not GPU_AVAILABLE: - return -1 - try: - handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) - mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) - return mem_info.used / 1024**2 - except pynvml.NVMLError: - return -1 - except Exception: # pylint: disable=broad-exception-caught - return -1 - -def freeze_backbone(model, model_type): - """ - Freeze the backbone of the model based on its type. - """ - if model_type in HF_MODELS: - for name, param in model.named_parameters(): - if "classifier" not in name: - param.requires_grad = False - elif model_type == "simclr": - for param in model.backbone.parameters(): - param.requires_grad = False - for param in model.classifier.parameters(): - param.requires_grad = True - else: - raise ValueError(f"Unsupported model_type: {model_type}") - -def get_flops(model, resolution): - """ - Profile FLOPs for the given model and resolution. - - Args: - model: The model to profile. - resolution (int): Input image resolution. - - Returns: - float: FLOPs in giga units, or -1 if profiling fails. - """ - try: - dummy_input = torch.randn(1, 3, resolution, resolution).to(next(model.parameters()).device) - flops, _, _ = profile(model, inputs=(dummy_input,)) - return flops / 1e9 - except Exception as e: # pylint: disable=broad-exception-caught - print(f"FLOP profiling failed: {e}") - return -1 - -def check_disk_space(min_gb=1): - """ - Checks if there is at least min_gb GB of free disk space. - - Args: - min_gb (int): Minimum required free disk space in GB. - - Raises: - RuntimeError: If not enough disk space is available. - """ - total, used, free = shutil.disk_usage("/") - print( - f"Disk space: Total={total // (2**30)} GB, " - f"Used={used // (2**30)} GB, " - f"Free={free // (2**30)} GB" - ) - if free < min_gb * (2**30): - raise RuntimeError(f"Not enough disk space. Please free up at least {min_gb}GB.") - -def save_model_and_preprocessor(model, preprocessors, typ, name, learning_rate): - """ - Saves the trained model and preprocessor to disk. - - Args: - model: The trained model. - preprocessors (dict): Preprocessors for each model type. - typ (str): Model type. - name (str): Model name. - learning_rate (float): Learning rate used for training. - - Returns: - str: Path to the saved model directory. - """ - model_dir = os.path.join(env_path("MODEL_DIR", "."), f"{name}_lr_{learning_rate}") - os.makedirs(model_dir, exist_ok=True) - if typ in HF_MODELS: - model.save_pretrained(model_dir) - preprocessors[typ].save_pretrained(model_dir) - elif typ == "simclr": - torch.save(model.state_dict(), os.path.join(model_dir, "pytorch_model.bin")) - with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as f: - json.dump({ - "model_type": "simclr", - "backbone": "resnet50", - "num_classes": model.classifier.out_features, - }, f) - return model_dir diff --git a/src/compressed_perception/modules/data_preparation/preparation.py b/src/compressed_perception/modules/data_preparation/preparation.py deleted file mode 100644 index 1a34862..0000000 --- a/src/compressed_perception/modules/data_preparation/preparation.py +++ /dev/null @@ -1,113 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -""" -Utility functions for dataset preparation. -""" -# Standard library imports -import io - -# Thrid-party imports -import numpy as np -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms -from datasets import ClassLabel - -# Local imports -from src.compressed_perception.modules.data_transformation.image_transformation import ( - JPEGCompressionTransform, - GaussianBlurTransform, - ColorQuantizationTransform - ) - -def filter_and_cast_dataset(dataset, filtered_classes, num_classes): - """ - Filter dataset by class labels and cast label column. - """ - filtered_indices = [ - i for i, label in enumerate(dataset["label"]) - if str(label) in filtered_classes - ] - dataset = dataset.select(filtered_indices) - dataset = dataset.cast_column("label", ClassLabel(num_classes=num_classes)) - return dataset - -def balance_dataset(dataset, num_train_images, filtered_classes): - """ - Balance the dataset by sampling an equal number of images per class. - """ - class_counts = {label: 0 for label in filtered_classes} - for label in dataset["label"]: - class_counts[str(label)] += 1 - - min_class_size = min(class_counts.values()) - images_per_class = min(num_train_images // len(filtered_classes), min_class_size) - - np.random.seed(42) - balanced_indices = [] - for label in filtered_classes: - class_indices = [i for i, l in enumerate(dataset["label"]) if str(l) == label] - sampled_indices = np.random.choice(class_indices, images_per_class, replace=False) - balanced_indices.extend(sampled_indices) - - np.random.shuffle(balanced_indices) - return dataset.select(balanced_indices) - -def split_dataset(dataset, test_size=0.2, stratify_by_column="label", seed=42): - """ - Split dataset into train and validation sets. - """ - return dataset.train_test_split( - test_size=test_size, - stratify_by_column=stratify_by_column, - seed=seed - ) - -def get_default_transforms(resolution, apply_transforms=False): - """ - Get torchvision transform pipeline. - """ - transform_list = [ - transforms.Resize((resolution, resolution)), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] - - if apply_transforms: - transform_list.extend([ - JPEGCompressionTransform(quality=75), - GaussianBlurTransform(p=0.5), - ColorQuantizationTransform(p=0.5), - ]) - - return transforms.Compose(transform_list) - -class TorchDataset(Dataset): - """ - PyTorch Dataset wrapper for Hugging Face datasets. - """ - def __init__(self, hf_dataset, transform): - self.hf_dataset = hf_dataset - self.transform = transform - - def __len__(self): - return len(self.hf_dataset) - - def __getitem__(self, idx): - item = self.hf_dataset[idx] - image = Image.open(io.BytesIO(item["image"])).convert("RGB") - image = self.transform(image) if self.transform else image - return {"pixel_values": image, "labels": int(item["label"])} - -def prepare_datasets(dataset, transform, split_ratio=0.8): - """ - Prepare PyTorch-compatible train and val datasets. - """ - train_size = int(split_ratio * len(dataset)) - train_dataset = dataset.select(range(train_size)) - val_dataset = dataset.select(range(train_size, len(dataset))) - return TorchDataset(train_dataset, transform), TorchDataset(val_dataset, transform) diff --git a/src/compressed_perception/modules/data_transformation/__init__.py b/src/compressed_perception/modules/data_transformation/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/compressed_perception/modules/data_transformation/image_transformation.py b/src/compressed_perception/modules/data_transformation/image_transformation.py deleted file mode 100644 index fcd3b88..0000000 --- a/src/compressed_perception/modules/data_transformation/image_transformation.py +++ /dev/null @@ -1,160 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -""" -Image transformation utilities for data augmentation and degradation. -""" - -import io -import random -from PIL import Image -from torchvision import transforms - -# Compatibility for LANCZOS resampling -try: - LANCZOS = Image.Resampling.LANCZOS -except AttributeError: - LANCZOS = Image.LANCZOS # pylint: disable=no-member - - -class JPEGCompressionTransform: - """ - Apply JPEG compression to an image to simulate lossy compression artifacts. - """ - def __init__(self, quality=75): - """ - Apply JPEG compression to an image. - - Args: - quality (int): Compression quality (1-100, higher is better quality). - """ - self.quality = quality - - def __call__(self, img): - """ - Apply JPEG compression to the input image. - - Args: - img (PIL.Image or Tensor): Input image. - - Returns: - PIL.Image: Compressed image. - """ - if not isinstance(img, Image.Image): - img = transforms.ToPILImage()(img) - - # Store original size - original_size = img.size - - # Apply JPEG compression - buffer = io.BytesIO() - img.save(buffer, format="JPEG", quality=self.quality) - buffer.seek(0) - img = Image.open(buffer) - - # Ensure size is maintained - if img.size != original_size: - img = img.resize(original_size, LANCZOS) - - return img - - def get_quality(self): - """ - Return the JPEG compression quality setting. - """ - return self.quality - - -class GaussianBlurTransform: - """Apply Gaussian blur to an image with a given probability. - """ - def __init__(self, p=1): - """ - Apply Gaussian blur to an image with a given probability. - - Args: - p (float): Probability of applying the blur (0 to 1). - """ - self.p = p - - def __call__(self, img): - """ - Apply Gaussian blur to the input image. - - Args: - img (PIL.Image or Tensor): Input image. - - Returns: - PIL.Image: Blurred image. - """ - if not isinstance(img, Image.Image): - img = transforms.ToPILImage()(img) - - # Store original size - original_size = img.size - - # Apply Gaussian blur with probability p - if random.random() < self.p: - kernel_size = random.choice([3, 5, 7]) - sigma = random.uniform(0.1, 2.0) - img = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)(img) - - # Ensure size is maintained - if img.size != original_size: - img = img.resize(original_size, LANCZOS) - - return img - - def get_probability(self): - """ - Return the probability of applying Gaussian blur. - """ - return self.p - - -class ColorQuantizationTransform: - """ - Apply color quantization to an image with a given probability. - """ - def __init__(self, p=1): - """ - Args: - p (float): Probability of applying the quantization (0 to 1). - """ - self.p = p - - def __call__(self, img): - """ - Apply color quantization to the input image. - - Args: - img (PIL.Image or Tensor): Input image. - - Returns: - PIL.Image: Quantized image. - """ - if not isinstance(img, Image.Image): - img = transforms.ToPILImage()(img) - - # Store original size - original_size = img.size - - # Apply color quantization with probability p - if random.random() < self.p: - num_colors = random.randint(16, 64) - img = img.quantize(colors=num_colors, method=Image.Quantize.MAXCOVERAGE).convert("RGB") - - # Ensure size is maintained - if img.size != original_size: - img = img.resize(original_size, LANCZOS) - - return img - - def get_probability(self): - """ - Return the probability of applying color quantization. - """ - return self.p diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..88abf6f --- /dev/null +++ b/src/config.py @@ -0,0 +1,64 @@ +"""Configuration and constants.""" +from dataclasses import dataclass +from typing import List, Dict, Any + +# Model constants +HF_MODELS = ["vit", "dinov2"] + +# Dataset constants +NUM_CLASSES = 8 +FILTERED_CLASSES = ["0", "1"] +NUM_FILTERED_CLASSES = len(FILTERED_CLASSES) + +# Image constants +DEFAULT_IMAGE_SIZE = 224 +IMAGE_NORMALIZATION = { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], +} + +@dataclass +class TrainingConfig: + """Training configuration.""" + num_train_images: int = 100 + proportion_per_transform: float = 0.2 + resolution: int = 224 + batch_size: int = 256 + num_epochs: int = 3 + eval_steps: int = 10 + learning_rate: float = 1e-4 + weight_decay: float = 0.01 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for logging.""" + return self.__dict__ + + def to_wandb_config(self) -> Dict[str, Any]: + """Create wandb configuration.""" + import torch + return { + **self.to_dict(), + "gpu_available": torch.cuda.is_available(), + } + +# Model configurations +MODEL_REGISTRY = [ + { + "name": "vit", + "model_id": "google/vit-base-patch16-224", + "type": "vit", + "config": { + "num_labels": NUM_FILTERED_CLASSES, + "ignore_mismatched_sizes": True + } + }, + { + "name": "dinov2", + "model_id": "facebook/dinov2-base", + "type": "dinov2", + "config": { + "num_labels": NUM_FILTERED_CLASSES, + "ignore_mismatched_sizes": True + } + }, +] \ No newline at end of file diff --git a/src/data/data_utils.py b/src/data/data_utils.py new file mode 100644 index 0000000..351e018 --- /dev/null +++ b/src/data/data_utils.py @@ -0,0 +1,200 @@ +"""Dataset implementations and data utilities.""" +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset, ConcatDataset, Subset +from torchvision import transforms +from typing import Optional, List, Dict, Any, Union + +from src.config import HF_MODELS, DEFAULT_IMAGE_SIZE +from src.transformation.transforms import ResolutionReductionTransform + +class ISICDataset(Dataset): + """ISIC dataset with support for multiple model types and transformations.""" + + def __init__( + self, + dataset: Union[Dataset, Subset], + preprocessor: Optional[Any] = None, + resolution: int = DEFAULT_IMAGE_SIZE, + transform: Optional[transforms.Compose] = None, + model_type: str = "vit", + jpeg_quality: Optional[int] = None, + ): + self.dataset = dataset + self.preprocessor = preprocessor + self.resolution = resolution + self.transform = transform + self.model_type = model_type + self.jpeg_quality = jpeg_quality + + # Create base preprocessing + self.base_preprocessor = transforms.Compose([ + transforms.Resize((resolution, resolution), Image.LANCZOS), + transforms.ToTensor(), + ]) + + self.model_preprocessor = None + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + # Convert numpy types to Python int + if isinstance(idx, (np.integer, np.int64)): + idx = int(idx) + + # Handle both direct dataset and Subset access + if hasattr(self.dataset, 'dataset'): + # This is a Subset + subset_idx = int(self.dataset.indices[idx]) + item = self.dataset.dataset[subset_idx] + else: + # Direct dataset access + item = self.dataset[idx] + + image = item["image"] + label = item["label"] + + # Resize to target resolution + image = image.resize((self.resolution, self.resolution), Image.LANCZOS) + + # Apply optional transformations + if self.transform: + image = self.transform(image) + + if self.jpeg_quality is not None: + image = JPEGCompressionTransform(self.jpeg_quality)(image) + + # Apply model-specific preprocessing + if self.model_type in HF_MODELS: + # For HuggingFace models + if hasattr(self.preprocessor, 'size'): + self.preprocessor.size = self.resolution + encoding = self.preprocessor(images=image, return_tensors="pt") + pixel_values = encoding["pixel_values"].squeeze(0) + else: + raise ValueError(f"Unsupported model_type: {self.model_type}") + + label = torch.tensor(label, dtype=torch.long) + return {"pixel_values": pixel_values, "labels": label} + +def create_transformed_datasets( + train_dataset: Dataset, + val_dataset: Dataset, + transforms_list: List, + proportion_per_transform: float, + preprocessor: Optional[Any], + resolution: int, + model_type: str +) -> tuple[Dataset, Dataset]: + """ + Create train and validation datasets with transformations. + + Args: + train_dataset: Training dataset + val_dataset: Validation dataset + transforms_list: List of transforms to apply + proportion_per_transform: Proportion of data for each transform + preprocessor: Model preprocessor + resolution: Image resolution + model_type: Type of model + + Returns: + Tuple of (train_dataset, val_dataset) + """ + num_images = len(train_dataset) + images_per_transform = int(num_images * proportion_per_transform) + + # Shuffle indices + indices = np.arange(num_images) + np.random.shuffle(indices) + + transformed_datasets = [] + used_indices = [] + + # Apply each transform to a subset + for i, transform in enumerate(transforms_list): + start_idx = i * images_per_transform + end_idx = start_idx + images_per_transform + subset_indices = indices[start_idx:end_idx] + used_indices.extend(subset_indices) + + subset = Subset(train_dataset, subset_indices) + transform_compose = transforms.Compose([transform]) + + transformed_ds = ISICDataset( + subset, + preprocessor, + resolution, + transform_compose, + model_type + ) + transformed_datasets.append(transformed_ds) + + # Add remaining samples without transformation + remaining_indices = np.setdiff1d(indices, used_indices) + if len(remaining_indices) > 0: + remaining_subset = Subset(train_dataset, remaining_indices) + untransformed_ds = ISICDataset( + remaining_subset, + preprocessor, + resolution, + None, + model_type + ) + transformed_datasets.append(untransformed_ds) + + # Combine all training datasets + train_ds = ConcatDataset(transformed_datasets) + + # Create validation dataset (no transformations) + val_ds = ISICDataset( + val_dataset, + preprocessor, + resolution, + model_type=model_type, + ) + + return train_ds, val_ds + +def balance_dataset(dataset: Dataset, filtered_classes: List[str], num_train_images: int, seed: int = 42): + """ + Balance dataset by sampling equal numbers from each class. + + Args: + dataset: Input dataset + filtered_classes: Classes to keep + num_train_images: Total number of training images + seed: Random seed + + Returns: + Balanced dataset + """ + # Get class counts + class_counts = {label: 0 for label in filtered_classes} + class_indices = {label: [] for label in filtered_classes} + + for i, item in enumerate(dataset): + label_str = str(item["label"]) + if label_str in filtered_classes: + class_counts[label_str] += 1 + class_indices[label_str].append(i) + + print(f"Class counts before balancing: {class_counts}") + + # Calculate samples per class + min_class_size = min(class_counts.values()) + images_per_class = min(num_train_images // len(filtered_classes), min_class_size) + + # Sample from each class + np.random.seed(seed) + balanced_indices = [] + + for label in filtered_classes: + indices = class_indices[label] + sampled = np.random.choice(indices, images_per_class, replace=False) + balanced_indices.extend(sampled) + + np.random.shuffle(balanced_indices) + return dataset.select(balanced_indices) \ No newline at end of file diff --git a/src/data/datamodule.py b/src/data/datamodule.py new file mode 100644 index 0000000..b2e9a02 --- /dev/null +++ b/src/data/datamodule.py @@ -0,0 +1,82 @@ +# src/data/datamodule.py +# -*- coding: utf-8 -*- +from torch.utils.data import DataLoader, random_split +from typing import Optional +from .datasets import get_dataset + + +class BaseDataModule: + """ + Dataset-agnostic data module. + + Responsibilities: + - Instantiate train/val/test datasets using `get_dataset()`. + - Build and cache dataloaders. + - Apply consistent transforms and batching options across datasets. + """ + + def __init__( + self, + cfg, + dataset_name: str, + data_dir: str, + num_workers: int = 8, + batch_size: int = 32, + pin_memory: bool = True, + drop_last: bool = False, + ): + self.cfg = cfg + self.dataset_name = dataset_name + self.data_dir = data_dir + self.num_workers = num_workers + self.batch_size = batch_size + self.pin_memory = pin_memory + self.drop_last = drop_last + + self.train_set = None + self.val_set = None + self.test_set = None + + # ------------------------------------------------------------------ + def setup(self, stage: Optional[str] = None): + """ + Called once to initialize datasets. + stage: 'fit' | 'validate' | 'test' | None + """ + ds_full = get_dataset(self.dataset_name, self.data_dir, split="train", cfg=self.cfg) + n_total = len(ds_full) + n_val = int(0.1 * n_total) + n_train = n_total - n_val + self.train_set, self.val_set = random_split(ds_full, [n_train, n_val]) + + # test set + self.test_set = get_dataset(self.dataset_name, self.data_dir, split="test", cfg=self.cfg) + + # ------------------------------------------------------------------ + def train_dataloader(self): + return DataLoader( + self.train_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + drop_last=self.drop_last, + ) + + def val_dataloader(self): + return DataLoader( + self.val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + + def test_dataloader(self): + return DataLoader( + self.test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) diff --git a/src/data/datasets.py b/src/data/datasets.py new file mode 100644 index 0000000..8568949 --- /dev/null +++ b/src/data/datasets.py @@ -0,0 +1,14 @@ +# src/data/datasets.py +from src.data.isic_dataset import ISICDataset +from src.data.tcga_dataset import TCGADataset +from src.data.merlin_dataset import MerlinDataset + +def get_dataset(dataset_name, data_dir, split, cfg): + if dataset_name.lower() == "isic": + return ISICDataset(data_dir=data_dir, split=split, cfg=cfg) + elif dataset_name.lower() == "tcga": + return TCGADataset(data_dir=data_dir, split=split, cfg=cfg) + elif dataset_name.lower() == "merlin": + return MerlinDataset(data_dir=data_dir, split=split, cfg=cfg) + else: + raise ValueError(f"Unknown dataset: {dataset_name}") diff --git a/src/compressed_perception/models/__init__.py b/src/engines/distill_engine.py similarity index 100% rename from src/compressed_perception/models/__init__.py rename to src/engines/distill_engine.py diff --git a/src/compressed_perception/models/comparison/__init__.py b/src/engines/finetune_engine.py similarity index 100% rename from src/compressed_perception/models/comparison/__init__.py rename to src/engines/finetune_engine.py diff --git a/src/engines/linear_probe_engine.py b/src/engines/linear_probe_engine.py new file mode 100644 index 0000000..9567b6f --- /dev/null +++ b/src/engines/linear_probe_engine.py @@ -0,0 +1,224 @@ +# src/engines/linear_probe_engine.py +# -*- coding: utf-8 -*- +"""Linear probing engine for training classification heads on frozen backbones.""" +from __future__ import annotations +from typing import Dict, Any, Tuple, Optional +import math +import torch +from torch import nn +from torch.utils.data import DataLoader + +# pylint: disable=import-error +from src.utils.logging import get_logger, MetricAverager, WandbLogger +from src.utils.optim import step_scheduler + +log = get_logger(__name__) + + +def _unpack_loaders(loaders: Any) -> Tuple[DataLoader, Optional[DataLoader], Optional[DataLoader]]: + """ + Accepts either: + - {"train": ..., "val": ..., "test": ...} + - (train_loader, val_loader) + Returns (train, val, test) + """ + if isinstance(loaders, dict): + return loaders.get("train"), loaders.get("val"), loaders.get("test") + if isinstance(loaders, (tuple, list)) and len(loaders) >= 2: + return loaders[0], loaders[1], loaders[2] if len(loaders) > 2 else None + raise ValueError( + "`loaders` must be a dict with keys train/val[/test] or a (train,val) tuple." + ) + + +@torch.no_grad() +def _evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Dict[str, float]: + model.eval() + correct = 0 + total = 0 + running_loss = 0.0 + + # if the model exposes a criterion attribute, prefer the external loss + # passed in train loop anyway. + ce = nn.CrossEntropyLoss() + + for batch in loader: + # support (x,y) or dict + if isinstance(batch, dict): + x = (batch.get("pixel_values") or + batch.get("images") or + batch.get("x")) + y = batch.get("labels") or batch.get("y") + else: + x, y = batch + + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + + logits = model(x) + if isinstance(logits, (tuple, list)): + logits = logits[0] + loss = ce(logits, y) + running_loss += float(loss.item()) * x.size(0) + + preds = logits.argmax(dim=-1) + correct += int((preds == y).sum().item()) + total += int(y.numel()) + + acc = correct / max(total, 1) + avg_loss = running_loss / max(total, 1) + return {"val_loss": avg_loss, "val_acc": acc} + + +def train_probe( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements + model: nn.Module, + loaders: Any, + loss_fn, # callable: (logits, targets) -> loss tensor + optimizer: torch.optim.Optimizer, + scheduler: Any = None, # either a scheduler or (scheduler, meta) + device: torch.device = torch.device("cpu"), + epochs: int = 10, + grad_clip: Optional[float] = None, + mixed_precision: bool = True, + log_interval: int = 50, + wandb_logger: Optional[WandbLogger] = None, + metric_key: str = "val_acc", +) -> Dict[str, Any]: + """ + Linear probing engine. + Expects the backbone to be frozen already; only the head should have requires_grad=True. + """ + train_loader, val_loader, _ = _unpack_loaders(loaders) + assert train_loader is not None, "train_loader is required" + assert val_loader is not None, "val_loader is required" + + scaler = torch.amp.GradScaler(enabled=(mixed_precision and device.type == "cuda")) + + # unwrap (scheduler, meta) if it came from our utils + sched, sched_meta = None, {"by": "epoch"} + if scheduler is not None: + if isinstance(scheduler, tuple) and len(scheduler) == 2: + sched, sched_meta = scheduler[0], scheduler[1] + else: + sched = scheduler + + # sanity: log trainable parameter ratio + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + log.info(f"[probe] trainable params: {trainable_params:,} / {total_params:,} " + f"({100.0 * trainable_params / max(total_params,1):.2f}%)") + + history = {"train_loss": [], "val_loss": [], "val_acc": []} + best_metric = -math.inf + best_state: Optional[Dict[str, torch.Tensor]] = None + + model.to(device) + + for epoch in range(1, epochs + 1): + model.train() + ma = MetricAverager() + + for it, batch in enumerate(train_loader, start=1): + if isinstance(batch, dict): + x = (batch.get("pixel_values") or + batch.get("images") or + batch.get("x")) + y = batch.get("labels") or batch.get("y") + else: + x, y = batch + + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + + optimizer.zero_grad(set_to_none=True) + + if scaler.is_enabled(): + with torch.amp.autocast(device_type=device.type): + logits = model(x) + if isinstance(logits, (tuple, list)): + logits = logits[0] + loss = loss_fn(logits, y) + scaler.scale(loss).backward() + if grad_clip is not None and grad_clip > 0: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_( + filter(lambda p: p.requires_grad, model.parameters()), + max_norm=grad_clip + ) + scaler.step(optimizer) + scaler.update() + else: + logits = model(x) + if isinstance(logits, (tuple, list)): + logits = logits[0] + loss = loss_fn(logits, y) + loss.backward() + if grad_clip is not None and grad_clip > 0: + torch.nn.utils.clip_grad_norm_( + filter(lambda p: p.requires_grad, model.parameters()), + max_norm=grad_clip + ) + optimizer.step() + + # running metrics + with torch.no_grad(): + preds = logits.argmax(dim=-1) + acc = (preds == y).float().mean().item() + ma.update(loss=float(loss.item()), acc=acc) + + if it % log_interval == 0: + avgs = ma.averages() + log.info(f"[probe] epoch {epoch:03d} iter {it:05d} " + f"| loss {avgs['loss']:.4f} | acc {avgs['acc']:.4f}") + if wandb_logger: + wandb_logger.log({ + "train/loss": avgs["loss"], + "train/acc": avgs["acc"], + "epoch": epoch, + "iter": it, + }) + + # end epoch → validation + val_metrics = _evaluate(model, val_loader, device=device) + history["train_loss"].append(ma.averages()["loss"]) + history["val_loss"].append(val_metrics["val_loss"]) + history["val_acc"].append(val_metrics["val_acc"]) + + log.info(f"[probe] epoch {epoch:03d} | val_loss {val_metrics['val_loss']:.4f} " + f"| val_acc {val_metrics['val_acc']:.4f}") + + if wandb_logger: + wandb_logger.log({ + "val/loss": val_metrics["val_loss"], + "val/acc": val_metrics["val_acc"], + "epoch": epoch, + }) + + # scheduler step + if sched is not None: + # if ReduceLROnPlateau, step on val metric; else per-epoch + if hasattr(sched, "step") and sched_meta.get("by") == "val_metric": + # for plateau: lower is better typically; if you want higher-better, pass negative + step_scheduler(sched, sched_meta, epoch=epoch, val_metric=val_metrics["val_loss"]) + else: + step_scheduler(sched, sched_meta, epoch=epoch) + + # track best + current = val_metrics.get(metric_key, val_metrics["val_acc"]) + if current > best_metric: + best_metric = current + best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} + + # restore best weights (optional but standard) + if best_state is not None: + model.load_state_dict(best_state) + + return { + "history": history, + "best_metric": float(best_metric), + "best_metric_name": metric_key, + "final_val_acc": float(history["val_acc"][-1]) if history["val_acc"] else None, + "final_val_loss": float(history["val_loss"][-1]) if history["val_loss"] else None, + "trainable_params": trainable_params, + "total_params": total_params, + } diff --git a/src/evaluation/metrics.py b/src/evaluation/metrics.py new file mode 100644 index 0000000..459c275 --- /dev/null +++ b/src/evaluation/metrics.py @@ -0,0 +1,63 @@ +# src/metrics/metrics.py +# -*- coding: utf-8 -*- +from __future__ import annotations +from typing import Dict, Tuple, Optional + +import numpy as np +import torch +from sklearn.metrics import accuracy_score, f1_score, roc_auc_score + +from src.evaluation.visualization import save_confusion_matrix, save_class_distribution + +def _softmax_np(logits: np.ndarray) -> np.ndarray: + logits_t = torch.tensor(logits) + probs = torch.softmax(logits_t, dim=1).cpu().numpy() + return probs + +def compute_metrics( + eval_pred: Tuple[np.ndarray, np.ndarray], + model_name: Optional[str] = None, + average: str = "weighted", + save_viz: bool = True, +) -> Dict[str, float]: + """ + Compute accuracy, F1, and AUC (binary or multiclass-ovr). + - eval_pred: (logits [N,C], labels [N]) + """ + logits, labels = eval_pred + labels = np.asarray(labels) + preds = np.argmax(logits, axis=-1) + probs = _softmax_np(logits) + + # core metrics + acc = accuracy_score(labels, preds) + f1 = f1_score(labels, preds, average=average) + + # AUC: binary if C==2 else OvR; guard for degenerate cases + try: + if probs.shape[1] == 2: + auc = roc_auc_score(labels, probs[:, 1]) + else: + auc = roc_auc_score(labels, probs, multi_class="ovr", average=average) + except Exception: + auc = float("nan") + + # optional viz + if save_viz and model_name: + try: + save_confusion_matrix(labels, preds, model_name=model_name, normalize=False) + save_class_distribution(preds, model_name=model_name) + except Exception as e: + print(f"[compute_metrics] Visualization failed: {e}") + + return {"accuracy": float(acc), "f1": float(f1), "auc": float(auc)} + +def create_compute_metrics_fn( + model_name: Optional[str] = None, + average: str = "weighted", + save_viz: bool = True, +): + """Closure for 🤗 Trainer: returns a callable(eval_pred)->metrics dict.""" + def _fn(eval_pred): + return compute_metrics(eval_pred, model_name=model_name, average=average, save_viz=save_viz) + return _fn diff --git a/src/evaluation/visualization.py b/src/evaluation/visualization.py new file mode 100644 index 0000000..913f99d --- /dev/null +++ b/src/evaluation/visualization.py @@ -0,0 +1,55 @@ +# src/metrics/visualization.py +# -*- coding: utf-8 -*- +from __future__ import annotations +import os +from typing import Iterable + +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.metrics import confusion_matrix + +from src.utils.training_utils import env_path + +def _plot_dir(model_name: str) -> str: + out = os.path.join(env_path("PLOT_DIR", "./plots"), model_name) + os.makedirs(out, exist_ok=True) + return out + +def save_confusion_matrix( + labels: Iterable[int], + predictions: Iterable[int], + model_name: str, + normalize: bool = False, + filename: str = "conf_mat.png", +): + """Save a confusion matrix image (optionally normalized).""" + labels = np.asarray(labels) + preds = np.asarray(predictions) + cm = confusion_matrix(labels, preds) + if normalize: + cm = cm.astype("float") / (cm.sum(axis=1, keepdims=True) + 1e-12) + + plt.figure(figsize=(8, 8)) + sns.heatmap(cm, annot=True, fmt=".2f" if normalize else "d", cmap="Blues") + plt.xlabel("Predicted") + plt.ylabel("True") + plt.title(f"{model_name} — Confusion Matrix" + (" (norm.)" if normalize else "")) + out = os.path.join(_plot_dir(model_name), filename) + plt.savefig(out, dpi=300, bbox_inches="tight") + plt.close() + +def save_class_distribution( + predictions: Iterable[int], + model_name: str, + filename: str = "class_breakdown.json", +): + """Save class histogram of predictions as JSON.""" + preds = np.asarray(predictions) + unique, counts = np.unique(preds, return_counts=True) + data = {str(int(k)): int(v) for k, v in zip(unique, counts)} + + out = os.path.join(_plot_dir(model_name), filename) + with open(out, "w") as f: + import json + json.dump(data, f, indent=2) diff --git a/src/evaluation/visualize_results.py b/src/evaluation/visualize_results.py new file mode 100644 index 0000000..3e3fb53 --- /dev/null +++ b/src/evaluation/visualize_results.py @@ -0,0 +1,173 @@ +# src/visualize_results.py +"""Visualize robustness results across models and degradations.""" +import json +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import numpy as np +from pathlib import Path + +def create_robustness_heatmap(results_path: str): + """Create heatmap showing model performance across degradations.""" + + # Load results + with open(results_path, 'r') as f: + results = json.load(f) + + # Prepare data for heatmap + models = [] + degradations = [] + accuracies = [] + + for training_mode in ['finetune', 'linear_probe']: + for model_name, model_results in results[training_mode].items(): + if 'error' not in model_results: + eval_results = model_results['eval_results_by_degradation'] + + for degradation, metrics in eval_results.items(): + models.append(f"{model_name}_{training_mode[:2]}") + degradations.append(degradation) + accuracies.append(metrics['accuracy']) + + # Create DataFrame + df = pd.DataFrame({ + 'Model': models, + 'Degradation': degradations, + 'Accuracy': accuracies + }) + + # Pivot for heatmap + pivot_df = df.pivot(index='Model', columns='Degradation', values='Accuracy') + + # Reorder columns logically + column_order = ['clean', + 'jpeg_90', 'jpeg_50', 'jpeg_20', + 'blur_1.0', 'blur_3.0', 'blur_5.0', + 'color_64', 'color_16', 'color_4'] + pivot_df = pivot_df[column_order] + + # Create figure + plt.figure(figsize=(14, 8)) + + # Create heatmap + sns.heatmap(pivot_df, + annot=True, + fmt='.3f', + cmap='RdYlGn', + vmin=0.5, + vmax=1.0, + cbar_kws={'label': 'Accuracy'}) + + plt.title('Model Robustness Across Different Degradations') + plt.xlabel('Degradation Type') + plt.ylabel('Model (ft=finetune, lp=linear probe)') + plt.tight_layout() + + # Save figure + output_path = Path(results_path).parent / 'robustness_heatmap.png' + plt.savefig(output_path, dpi=300, bbox_inches='tight') + plt.show() + + return pivot_df + +def create_degradation_curves(results_path: str): + """Create line plots showing accuracy degradation.""" + + with open(results_path, 'r') as f: + results = json.load(f) + + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + + # JPEG degradation + jpeg_qualities = [90, 50, 20] + ax = axes[0] + + for training_mode in ['finetune', 'linear_probe']: + for model_name, model_results in results[training_mode].items(): + if 'error' not in model_results: + eval_results = model_results['eval_results_by_degradation'] + + clean_acc = eval_results['clean']['accuracy'] + jpeg_accs = [eval_results[f'jpeg_{q}']['accuracy'] for q in jpeg_qualities] + + label = f"{model_name} ({training_mode})" + ax.plot([100] + jpeg_qualities, [clean_acc] + jpeg_accs, + marker='o', label=label) + + ax.set_xlabel('JPEG Quality') + ax.set_ylabel('Accuracy') + ax.set_title('JPEG Compression Robustness') + ax.legend() + ax.grid(True, alpha=0.3) + ax.invert_xaxis() + + # Blur degradation + blur_radii = [0, 1.0, 3.0, 5.0] + ax = axes[1] + + for training_mode in ['finetune', 'linear_probe']: + for model_name, model_results in results[training_mode].items(): + if 'error' not in model_results: + eval_results = model_results['eval_results_by_degradation'] + + blur_accs = [] + blur_accs.append(eval_results['clean']['accuracy']) + for r in blur_radii[1:]: + blur_accs.append(eval_results[f'blur_{r:.1f}']['accuracy']) + + label = f"{model_name} ({training_mode})" + ax.plot(blur_radii, blur_accs, marker='o', label=label) + + ax.set_xlabel('Blur Radius') + ax.set_ylabel('Accuracy') + ax.set_title('Gaussian Blur Robustness') + ax.legend() + ax.grid(True, alpha=0.3) + + # Color quantization + color_levels = [256, 64, 16, 4] + ax = axes[2] + + for training_mode in ['finetune', 'linear_probe']: + for model_name, model_results in results[training_mode].items(): + if 'error' not in model_results: + eval_results = model_results['eval_results_by_degradation'] + + color_accs = [] + color_accs.append(eval_results['clean']['accuracy']) + for c in color_levels[1:]: + color_accs.append(eval_results[f'color_{c}']['accuracy']) + + label = f"{model_name} ({training_mode})" + ax.plot(color_levels, color_accs, marker='o', label=label) + + ax.set_xlabel('Number of Colors') + ax.set_ylabel('Accuracy') + ax.set_title('Color Quantization Robustness') + ax.legend() + ax.grid(True, alpha=0.3) + ax.set_xscale('log', base=2) + ax.invert_xaxis() + + plt.suptitle('Model Robustness to Different Degradation Types') + plt.tight_layout() + + # Save figure + output_path = Path(results_path).parent / 'degradation_curves.png' + plt.savefig(output_path, dpi=300, bbox_inches='tight') + plt.show() + +# Usage +if __name__ == "__main__": + results_path = "results/results_comprehensive_lr0.0001_bs128_ep3.json" + + # Create visualizations + pivot_df = create_robustness_heatmap(results_path) + create_degradation_curves(results_path) + + # Print summary statistics + print("\nModel Rankings by Robustness:") + print("-" * 40) + robustness_scores = pivot_df.mean(axis=1).sort_values(ascending=False) + for model, score in robustness_scores.items(): + print(f"{model}: {score:.3f}") \ No newline at end of file diff --git a/src/compressed_perception/models/training/__init__.py b/src/losses/__init__.py similarity index 100% rename from src/compressed_perception/models/training/__init__.py rename to src/losses/__init__.py diff --git a/src/losses/classification.py b/src/losses/classification.py new file mode 100644 index 0000000..4168dee --- /dev/null +++ b/src/losses/classification.py @@ -0,0 +1,35 @@ +# src/losses/classification.py +# -*- coding: utf-8 -*- +from typing import Optional +import torch.nn.functional as F +from torch import Tensor + + +def cross_entropy_loss( + label_smoothing: float = 0.0, + class_weight: Optional[Tensor] = None, + ignore_index: int = -100, + reduction: str = "mean", +): + """ + Standard cross-entropy with optional label smoothing and class weights. + + Args: + label_smoothing: in [0,1). 0 = vanilla CE. + class_weight: shape [C] tensor of per-class weights (on same device). + ignore_index: targets with this index are ignored. + reduction: 'none' | 'mean' | 'sum' + + Returns: + Callable loss(logits [B,C], targets [B]) + """ + def _loss(logits: Tensor, targets: Tensor) -> Tensor: + return F.cross_entropy( + logits, + targets, + weight=class_weight, + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing, + ) + return _loss diff --git a/src/losses/distillation.py b/src/losses/distillation.py new file mode 100644 index 0000000..27e985e --- /dev/null +++ b/src/losses/distillation.py @@ -0,0 +1,84 @@ +# src/losses/distillation.py +# -*- coding: utf-8 -*- +from typing import Dict +import torch.nn.functional as F +from torch import Tensor + + +def cosine_loss(reduction: str = "mean"): + """ + Cosine embedding loss for feature alignment. + Expects student and teacher embeddings with shape [B, D]. + """ + def _loss(s_embed: Tensor, t_embed: Tensor) -> Tensor: + s_norm = F.normalize(s_embed, dim=-1) + t_norm = F.normalize(t_embed, dim=-1) + sim = (s_norm * t_norm).sum(dim=-1) + loss = 1.0 - sim # minimize (1 - cos) + if reduction == "mean": + return loss.mean() + elif reduction == "sum": + return loss.sum() + return loss + return _loss + + +def kl_divergence_loss(temperature: float = 1.0, reduction: str = "batchmean"): + """ + KL divergence on logits with temperature scaling (Hinton et al., 2015). + """ + T = float(temperature) + T2 = T * T + + def _loss(s_logits: Tensor, t_logits: Tensor) -> Tensor: + log_p_s = F.log_softmax(s_logits / T, dim=-1) + p_t = F.softmax(t_logits / T, dim=-1) + kl = F.kl_div(log_p_s, p_t, reduction=reduction) + return kl * T2 + return _loss + + +def hybrid_distillation_loss( + alpha: float = 0.5, # cosine(embeds) + beta: float = 0.5, # KL(logits) + gamma: float = 0.0, # optional CE to ground truth + temperature: float = 2.0, + ce_loss_fn=None, +): + """ + Combined distillation objective: + L = alpha * (1 - cos(s_embed, t_embed)) + + beta * KL(softmax(s/T), softmax(t/T)) * T^2 + + gamma * CE(s_logits, y) + + Returns: + Callable loss_dict(outputs) where outputs = { + "s_logits", "t_logits", "s_embed", "t_embed", "targets" + } + """ + cos_fn = cosine_loss(reduction="mean") + kl_fn = kl_divergence_loss(temperature=temperature, reduction="batchmean") + + def _loss(outputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + s_logits = outputs["s_logits"] + t_logits = outputs["t_logits"] + s_embed = outputs["s_embed"] + t_embed = outputs["t_embed"] + + l_cos = cos_fn(s_embed, t_embed) if alpha > 0 else s_logits.new_zeros(()) + l_kl = kl_fn(s_logits, t_logits) if beta > 0 else s_logits.new_zeros(()) + + if gamma > 0 and ce_loss_fn is not None: + targets = outputs["targets"] + l_ce = ce_loss_fn(s_logits, targets) + else: + l_ce = s_logits.new_zeros(()) + + total = alpha * l_cos + beta * l_kl + gamma * l_ce + return { + "loss_total": total, + "loss_cos": l_cos.detach(), + "loss_kl": l_kl.detach(), + "loss_ce": l_ce.detach(), + } + return _loss diff --git a/src/models/factory.py b/src/models/factory.py new file mode 100644 index 0000000..b7210a0 --- /dev/null +++ b/src/models/factory.py @@ -0,0 +1,191 @@ +# src/models/factory.py +# -*- coding: utf-8 -*- +"""Unified model factory: HF vision models (ViT/DINOv2), optional timm, + plus matching preprocessors and helpers (freeze_backbone, save_model).""" + +from typing import Dict, Any + +import os +import torch.nn as nn +from PIL import Image + +# --- Hugging Face --- +from transformers import ( + ViTForImageClassification, + AutoModelForImageClassification, + ViTFeatureExtractor, + AutoImageProcessor, +) + +# --- Optional timm --- +try: + import timm # type: ignore + _TIMM_AVAILABLE = True +except Exception: + _TIMM_AVAILABLE = False + +# --- Project constants (small change from your code: avoid importing configs directly) --- +try: + from src.utils.constants import HF_MODELS # e.g., {"vit", "dinov2"} +except Exception: + # Fallback if constants not present yet + HF_MODELS = {"vit", "dinov2"} + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def create_model(model_info: Dict[str, Any], resolution: int = 224): + """ + Factory function to create models based on type. + + Args: + model_info: {"type": "vit"|"dinov2"|("timm"), "model_id": str, "config": {...}} + resolution: input image resolution + + Returns: + nn.Module + """ + model_type = model_info["type"] + model_id = model_info["model_id"] + config = model_info.get("config", {}) + + # --- HuggingFace ViT --- + if model_type == "vit": + return ViTForImageClassification.from_pretrained( + model_id, + num_labels=config["num_labels"], + ignore_mismatched_sizes=config.get("ignore_mismatched_sizes", True), + image_size=resolution, + ) + + # --- HuggingFace DINOv2 (AutoModel) --- + elif model_type == "dinov2": + return AutoModelForImageClassification.from_pretrained( + model_id, + num_labels=config["num_labels"], + ignore_mismatched_sizes=config.get("ignore_mismatched_sizes", True), + image_size=resolution, + ) + + # --- Optional timm branch (kept minimal, no breaking changes) --- + elif model_type == "timm": + if not _TIMM_AVAILABLE: + raise RuntimeError("timm is not installed but model_type='timm' was requested.") + num_classes = int(config.get("num_labels", 1000)) + pretrained = bool(config.get("pretrained", True)) + # If you need image_size-specific config, many timm models accept it via `img_size` + return timm.create_model(model_id, pretrained=pretrained, num_classes=num_classes) + + else: + raise ValueError(f"Unknown model type: {model_type}") + + +def create_preprocessor(model_info: Dict[str, Any], resolution: int = 224): + """ + Create appropriate preprocessor for model type. + + Args: + model_info: Dictionary with model configuration + resolution: Input image resolution + + Returns: + HF preprocessor (FeatureExtractor/ImageProcessor) or None (for timm) + """ + model_type = model_info["type"] + model_id = model_info["model_id"] + + if model_type == "vit": + return ViTFeatureExtractor.from_pretrained( + model_id, + size=resolution, + do_resize=True, + resample=Image.LANCZOS, + do_normalize=True, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + + elif model_type == "dinov2": + return AutoImageProcessor.from_pretrained( + model_id, + size=resolution, + do_resize=True, + resample=Image.LANCZOS, + do_normalize=True, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + + elif model_type == "timm": + # timm uses torchvision transforms; return None and build transforms in your datamodule + return None + + else: + raise ValueError(f"Unknown model type: {model_type}") + + +def freeze_backbone(model: nn.Module, model_type: str): + """ + Freeze backbone parameters for transfer learning. + For HF classifiers, keep 'classifier' or 'head' trainable; freeze the rest. + + Args: + model: nn.Module + model_type: 'vit' | 'dinov2' | ('timm' if you wire it similarly) + """ + if model_type in HF_MODELS: + for name, param in model.named_parameters(): + # Leave classifier/head trainable, freeze others + if ("classifier" in name) or ("head" in name): + param.requires_grad = True + else: + param.requires_grad = False + elif model_type == "timm": + # Optional: implement project-specific rules (e.g., freeze all except last classifier) + for name, param in model.named_parameters(): + if ("classifier" in name) or ("fc" in name) or ("head" in name): + param.requires_grad = True + else: + param.requires_grad = False + else: + raise ValueError(f"Unsupported model_type: {model_type}") + + +def save_model(model: nn.Module, model_info: Dict[str, Any], save_dir: str, preprocessor=None): + """ + Save model based on its type. + + Args: + model: Model to save + model_info: Model configuration + save_dir: Directory to save to + preprocessor: Optional HF preprocessor to save + """ + os.makedirs(save_dir, exist_ok=True) + model_type = model_info["type"] + + if model_type in HF_MODELS: + model.save_pretrained(save_dir) + if preprocessor is not None: + preprocessor.save_pretrained(save_dir) + elif model_type == "timm": + # Torch-style checkpoint for timm models + import torch + ckpt_path = os.path.join(save_dir, "pytorch_model.bin") + torch.save(model.state_dict(), ckpt_path) + # Minimal config export + with open(os.path.join(save_dir, "config.json"), "w") as f: + import json + json.dump( + { + "model_type": "timm", + "model_id": model_info["model_id"], + "num_labels": model_info.get("config", {}).get("num_labels", None), + }, + f, + indent=2, + ) + else: + raise ValueError(f"Unsupported model_type: {model_type}") diff --git a/src/requirements.txt b/src/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2d330480365fea94d5cdc9f0c7ab6c125710cd23 GIT binary patch literal 1716 zcmZvcTaKGR5JmetQkLS6C|Sttz`};H@nB4V89x@ELrr>z0^UP55V4Z$$J<%!dCey<>0FE8her`_yl&EktOgV{vEp5h zRTJa&iyNBh)%uzvLtiTlPW$wONBYDpGlG_mbuzv2>0iSj2J@B>OP~*TZn`MviTBqGMVhvySSDksu{@*XlsOOEy zGp`fWp=Q~`e(!QW-!q1NR}1Z>KKK60z5=h$9IVeMzxHnp$^_FMj%CJsc2T*QniGFA zukf2W@V7;+3O1rf=d|`&)##aP=QyTgAkrLXZ=#F(f`^_i*Kl6YJG^tA1Sem;#fhDr zHkrxWSyUnQ3RMTP zoe?#ER^@TxS9w<_;+f-f#PcaqrQgIyuVSYOZtYU_3P#>M&aHKXuXN6HT``wdQv6Q8 zvsctBm@Vfq5PHyGdCD}Dn{2|Oj?p(~C+48k6}OO@mRl&!PMmlDt;Un*nY43TFrp`M zuQ+4gBk~B#maOO12lz56cc?dzlQ-Qb#`6+Y&1hG{CJ>BsQdDfdN7%LKSDcZFwf9oF zwIcTs^5fn2eaGK@leK_*h%*u`k6VMGxM6YyPVn8b6 Image.Image: + """Apply resolution reduction.""" + if self.reduction_factor is None: + # Random reduction factor between 0.2 and 0.8 + reduction_factor = np.random.uniform(0.2, 0.8) + else: + reduction_factor = self.reduction_factor + + # Clamp reduction factor to valid range + reduction_factor = max(0.1, min(1.0, reduction_factor)) + + # Calculate new size + original_width, original_height = img.size + new_width = int(original_width * reduction_factor) + new_height = int(original_height * reduction_factor) + + # Ensure minimum size of 1x1 + new_width = max(1, new_width) + new_height = max(1, new_height) + + # Downsample and then upsample back to original size + downsampled = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + return downsampled.resize((original_width, original_height), Image.Resampling.LANCZOS) + +class JPEGCompressionTransform: # pylint: disable=too-few-public-methods + """Apply JPEG compression to images.""" + + def __init__(self, quality: Optional[int] = None): + """ + Args: + quality: JPEG quality (1-100). If None, random quality is used. + """ + self.quality = quality + + def __call__(self, img: Image.Image) -> Image.Image: + """Apply JPEG compression.""" + if self.quality is None: + quality = np.random.randint(10, 100) + else: + quality = self.quality + + # Save to bytes with JPEG compression + buffer = io.BytesIO() + img.save(buffer, format='JPEG', quality=quality) + buffer.seek(0) + + # Load back + return Image.open(buffer) + +class GaussianBlurTransform: # pylint: disable=too-few-public-methods + """Apply Gaussian blur to images.""" + + def __init__(self, radius: Optional[float] = None): + """ + Args: + radius: Blur radius. If None, random radius is used. + """ + self.radius = radius + + def __call__(self, img: Image.Image) -> Image.Image: + """Apply Gaussian blur.""" + if self.radius is None: + radius = np.random.uniform(0.5, 5.0) + else: + radius = self.radius + + return img.filter(ImageFilter.GaussianBlur(radius=radius)) + +class ColorQuantizationTransform: + """Reduce color palette of images.""" + + def __init__(self, n_colors: Optional[int] = None): + """ + Args: + n_colors: Number of colors. If None, random value is used. + """ + self.n_colors = n_colors + + def __call__(self, img: Image.Image) -> Image.Image: + """Apply color quantization.""" + if self.n_colors is None: + n_colors = np.random.randint(4, 128) + else: + n_colors = self.n_colors + + return img.quantize(colors=n_colors, method=Image.Quantize.MEDIANCUT).convert("RGB") + + +def get_degradation_transforms(): + """Get default list of degradation transforms.""" + return [ + ResolutionReductionTransform(), + ] diff --git a/src/utils/callbacks_hf.py b/src/utils/callbacks_hf.py new file mode 100644 index 0000000..630399a --- /dev/null +++ b/src/utils/callbacks_hf.py @@ -0,0 +1,55 @@ +# src/utils/callbacks_hf.py +# -*- coding: utf-8 -*- +from __future__ import annotations +import os +import json +from typing import Optional, Dict, Any +from transformers import TrainerCallback # type: ignore + +from src.utils.training_utils import get_gpu_memory + +# Optional W&B; safe if not installed +try: + import wandb # type: ignore + _WANDB = True +except Exception: + _WANDB = False + +class LossLoggerCallback(TrainerCallback): + """Append step logs to a JSONL file (robust, HF-friendly).""" + def __init__(self, log_dir: str, phase: str, model_name: str): + os.makedirs(log_dir, exist_ok=True) + self.log_file = os.path.join(log_dir, f"{model_name}_{phase}_log.jsonl") + + def on_log(self, args, state, control, logs=None, **kwargs): + if not logs: + return + payload: Dict[str, Any] = {"step": state.global_step, **logs} + with open(self.log_file, "a") as f: + f.write(json.dumps(payload) + "\n") + +class WandbCallback(TrainerCallback): + """Minimal W&B logger that adds model/phase and GPU memory (if available).""" + def __init__(self, model_name: str, phase: str): + self.model_name = model_name + self.phase = phase + self.best_accuracy = 0.0 + + def on_log(self, args, state, control, logs=None, **kwargs): + if not _WANDB or not logs: + return + logs = dict(logs) + logs["model"] = self.model_name + logs["phase"] = self.phase + mem = get_gpu_memory() + if mem > 0: + logs["gpu_memory_mb"] = mem + wandb.log(logs) + + def on_evaluate(self, args, state, control, metrics=None, **kwargs): + if not _WANDB or not metrics: + return + if "eval_accuracy" in metrics: + self.best_accuracy = max(self.best_accuracy, metrics["eval_accuracy"]) + metrics = dict(metrics, best_accuracy=self.best_accuracy) + wandb.log(metrics) diff --git a/src/utils/constants.py b/src/utils/constants.py new file mode 100644 index 0000000..9f97406 --- /dev/null +++ b/src/utils/constants.py @@ -0,0 +1,19 @@ +# src/utils/constants.py +# -*- coding: utf-8 -*- +"""Global constants and lightweight enums used across the training pipeline.""" + +# --------------------------------------------------------------------------- +# Model Families +# --------------------------------------------------------------------------- + +# Hugging Face vision models supported by the unified factory +HF_MODELS = {"vit", "dinov2"} + +# --------------------------------------------------------------------------- +# Dataset defaults +# --------------------------------------------------------------------------- + +NUM_CLASSES = 1000 # update dynamically per dataset if needed +NUM_FILTERED_CLASSES = 8 # for ISIC filtered subset example + + diff --git a/src/utils/logging.py b/src/utils/logging.py new file mode 100644 index 0000000..5a5a69c --- /dev/null +++ b/src/utils/logging.py @@ -0,0 +1,114 @@ +# src/utils/logging.py +# -*- coding: utf-8 -*- +from __future__ import annotations +import logging +import json +import os +from pathlib import Path +from typing import Dict, Any, Optional + +# Optional: Weights & Biases +_WANDB_AVAILABLE = False +try: + import wandb # type: ignore + _WANDB_AVAILABLE = True +except Exception: + _WANDB_AVAILABLE = False + + +def setup_logging(level: int = logging.INFO) -> None: + """Configure root logger format/level.""" + fmt = "[%(asctime)s] %(levelname)s - %(name)s: %(message)s" + datefmt = "%H:%M:%S" + logging.basicConfig(level=level, format=fmt, datefmt=datefmt) + + +def get_logger(name: str) -> logging.Logger: + """Get a module-specific logger.""" + return logging.getLogger(name) + + +class MetricAverager: + """ + Track running averages of named scalars. + Usage: + ma = MetricAverager() + ma.update(loss=0.1, acc=0.9) + avgs = ma.averages() # {"loss": ..., "acc": ...} + """ + def __init__(self) -> None: + self.totals: Dict[str, float] = {} + self.counts: Dict[str, int] = {} + + def update(self, **kwargs: float) -> None: + for k, v in kwargs.items(): + self.totals[k] = self.totals.get(k, 0.0) + float(v) + self.counts[k] = self.counts.get(k, 0) + 1 + + def averages(self) -> Dict[str, float]: + return {k: (self.totals[k] / max(self.counts[k], 1)) for k in self.totals} + + def reset(self) -> None: + self.totals.clear() + self.counts.clear() + + +class WandbLogger: + """ + Thin W&B wrapper that is safe when wandb is not installed. + """ + def __init__( + self, + project: str, + run_name: Optional[str] = None, + config: Any = None, + enabled: Optional[bool] = None, + entity: Optional[str] = None, + tags: Optional[list[str]] = None, + ) -> None: + self.enabled = (_WANDB_AVAILABLE if enabled is None else enabled) + self.run = None + if self.enabled: + self.run = wandb.init( + project=project, + name=run_name, + config=_maybe_serialize_config(config), + entity=entity, + tags=tags or [], + reinit=True, + ) + + def log(self, metrics: Dict[str, Any], step: Optional[int] = None, commit: bool = True) -> None: + if self.enabled and self.run is not None: + wandb.log(metrics, step=step, commit=commit) + + def watch_model(self, model, log: str = "gradients", log_freq: int = 100) -> None: + if self.enabled and self.run is not None: + wandb.watch(model, log=log, log_freq=log_freq) + + def save_artifact(self, path: str, name: Optional[str] = None, type_: str = "file") -> None: + if self.enabled and self.run is not None and os.path.exists(path): + art = wandb.Artifact(name or Path(path).name, type=type_) + art.add_file(path) + self.run.log_artifact(art) + + def finish(self) -> None: + if self.enabled and self.run is not None: + self.run.finish() + + +def _maybe_serialize_config(cfg: Any) -> Dict[str, Any]: + """Best-effort: convert config objects to plain dicts.""" + try: + from omegaconf import OmegaConf # type: ignore + if isinstance(cfg, dict): + return cfg + if OmegaConf.is_config(cfg): + return OmegaConf.to_container(cfg, resolve=True) # type: ignore + except Exception: + pass + try: + json.dumps(cfg) # type: ignore + return cfg # type: ignore + except Exception: + return {} diff --git a/src/utils/optim.py b/src/utils/optim.py new file mode 100644 index 0000000..7190423 --- /dev/null +++ b/src/utils/optim.py @@ -0,0 +1,122 @@ +# src/utils/optim.py +# -*- coding: utf-8 -*- +from __future__ import annotations +from typing import Iterable, Tuple, Dict, Any +import math +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler, LambdaLR, StepLR, ReduceLROnPlateau + + +def _param_groups_decay(model_or_params: Iterable, weight_decay: float) -> list: + """ + Split parameters into (decay / no_decay) groups. + - no_decay: bias, LayerNorm/BatchNorm weights. + """ + if isinstance(model_or_params, torch.nn.Module): + params = list(model_or_params.named_parameters()) + else: + # assume an iterable of parameters + params = [(f"p{i}", p) for i, p in enumerate(model_or_params)] + + decay, no_decay = [], [] + for name, p in params: + if not p.requires_grad: + continue + if p.ndim == 1 or name.endswith(".bias") or "norm" in name.lower(): + no_decay.append(p) + else: + decay.append(p) + + return [ + {"params": decay, "weight_decay": weight_decay}, + {"params": no_decay, "weight_decay": 0.0}, + ] + + +def _build_optimizer(cfg, params) -> Optimizer: + opt_cfg = getattr(cfg, "train").get("optimizer", {}) if hasattr(cfg, "train") else {} + name = str(opt_cfg.get("name", "adamw")).lower() + lr = float(opt_cfg.get("lr", 1e-4)) + wd = float(opt_cfg.get("weight_decay", 0.05)) + + if isinstance(params, torch.nn.Module) or (hasattr(params, "__iter__") and hasattr(next(iter(params)), "ndim")): + param_groups = _param_groups_decay(params, wd) + else: + # already groups + param_groups = params + + if name == "adamw": + betas = tuple(opt_cfg.get("betas", (0.9, 0.999))) + eps = float(opt_cfg.get("eps", 1e-8)) + return torch.optim.AdamW(param_groups, lr=lr, betas=betas, eps=eps) + elif name == "sgd": + momentum = float(opt_cfg.get("momentum", 0.9)) + nesterov = bool(opt_cfg.get("nesterov", True)) + return torch.optim.SGD(param_groups, lr=lr, momentum=momentum, nesterov=nesterov) + else: + raise ValueError(f"Unsupported optimizer: {name}") + + +def _warmup_cosine_lambda_fn(epochs: int, warmup_epochs: int, min_lr_ratio: float): + def lr_lambda(current_epoch: int): + if current_epoch < warmup_epochs: + return float(current_epoch + 1) / float(max(1, warmup_epochs)) + # cosine from 1.0 -> min_lr_ratio + progress = (current_epoch - warmup_epochs) / float(max(1, epochs - warmup_epochs)) + cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine + return lr_lambda + + +def _build_scheduler(cfg, optimizer: Optimizer) -> Tuple[_LRScheduler | None, Dict[str, Any]]: + sch_cfg = getattr(cfg, "train").get("scheduler", {}) if hasattr(cfg, "train") else {} + name = str(sch_cfg.get("name", "cosine")).lower() + epochs = int(sch_cfg.get("epochs", getattr(cfg.train, "epochs", 50))) + warmup_epochs = int(sch_cfg.get("warmup_epochs", 0)) + + if name in ("cosine", "cosineanneal", "cosine_anneal"): + min_lr = float(sch_cfg.get("min_lr", 1e-6)) + base_lr = float(getattr(cfg.train.get("optimizer", {}), "lr", 1e-4) if hasattr(cfg, "train") else 1e-4) + min_lr_ratio = max(min_lr / max(base_lr, 1e-12), 0.0) + lr_lambda = _warmup_cosine_lambda_fn(epochs, warmup_epochs, min_lr_ratio) + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler, {"by": "epoch"} + elif name == "step": + step_size = int(sch_cfg.get("step_size", 30)) + gamma = float(sch_cfg.get("gamma", 0.1)) + scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma) + return scheduler, {"by": "epoch"} + elif name in ("plateau", "reduceonplateau"): + patience = int(sch_cfg.get("patience", 5)) + factor = float(sch_cfg.get("factor", 0.5)) + scheduler = ReduceLROnPlateau(optimizer, patience=patience, factor=factor) + return scheduler, {"by": "val_metric"} + elif name in ("none", "off"): + return None, {"by": "none"} + else: + raise ValueError(f"Unsupported scheduler: {name}") + + +def step_scheduler(scheduler, meta: Dict[str, Any], epoch: int, val_metric: float | None = None): + """ + Step the scheduler depending on configuration (by epoch or by val metric). + meta["by"] is returned from _build_scheduler. + """ + if scheduler is None: + return + if meta.get("by") == "epoch": + scheduler.step() + elif meta.get("by") == "val_metric": + # lower is better by default; pass -val_metric if you want higher-better + scheduler.step(val_metric) + + +def make_optimizer_and_scheduler(cfg, params) -> Tuple[Optimizer, Tuple[_LRScheduler | None, Dict[str, Any]]]: + """ + Factory: returns (optimizer, (scheduler, meta)). + Use step_scheduler(scheduler, meta, epoch, val_metric) to step it. + """ + optimizer = _build_optimizer(cfg, params) + scheduler, meta = _build_scheduler(cfg, optimizer) + return optimizer, (scheduler, meta) diff --git a/src/utils/training_utils.py b/src/utils/training_utils.py new file mode 100644 index 0000000..9459c6c --- /dev/null +++ b/src/utils/training_utils.py @@ -0,0 +1,43 @@ +# src/utils/training_utils.py +# -*- coding: utf-8 -*- +from __future__ import annotations +import os +import json +from typing import Optional + +import torch + +def env_path(key: str, default: str = ".") -> str: + """Read an env var with a default; expand ~ and vars.""" + return os.path.expanduser(os.path.expandvars(os.getenv(key, default))) + +def get_gpu_memory() -> int: + """Return total used GPU memory (MB) across visible GPUs; 0 if unavailable.""" + try: + import pynvml # type: ignore + pynvml.nvmlInit() + n = pynvml.nvmlDeviceGetCount() + used = 0 + for i in range(n): + h = pynvml.nvmlDeviceGetHandleByIndex(i) + mem = pynvml.nvmlDeviceGetMemoryInfo(h) + used += mem.used + pynvml.nvmlShutdown() + return int(used / (1024 * 1024)) + except Exception: + return 0 + +def profile_model(model: torch.nn.Module, resolution: int) -> float: + """ + Estimate model FLOPs in GFLOPs using thop; returns -1 on failure. + """ + try: + from thop import profile # type: ignore + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dummy = torch.randn(1, 3, resolution, resolution, device=device) + model = model.to(device) + flops, _ = profile(model, inputs=(dummy,)) + return float(flops) / 1e9 + except Exception as e: + print(f"[profile_model] FLOP profiling failed: {e}") + return -1.0 diff --git a/src/utils/utils.py b/src/utils/utils.py new file mode 100644 index 0000000..daa6c7b --- /dev/null +++ b/src/utils/utils.py @@ -0,0 +1,80 @@ +"""General utilities for environment, GPU, and I/O operations.""" +import os +import json +import shutil +from typing import Dict, Any, Optional +import torch + +try: + import pynvml + pynvml.nvmlInit() + PYNVML_AVAILABLE = True +except ImportError: + PYNVML_AVAILABLE = False + print("pynvml not installed, GPU memory monitoring disabled.") + +def env_path(key: str, default: str) -> str: + """Get environment variable or default value.""" + return os.environ.get(key, default) + +def setup_environment(): + """Setup cache paths and environment variables.""" + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + os.environ["HF_HOME"] = os.getenv( + "HF_HOME", "~/.cache/huggingface/transformers" + ) + os.environ["HF_DATASETS_CACHE"] = os.getenv( + "HF_DATASETS_CACHE", "~/.cache/huggingface/datasets" + ) + os.environ["HF_HOME"] = os.getenv("HF_HOME", "~/.cache/huggingface") + +def get_gpu_memory(device_id: int = 0) -> float: + """ + Get GPU memory usage in MB. + + Returns: + Memory usage in MB, or -1 if unavailable. + """ + if not torch.cuda.is_available() or not PYNVML_AVAILABLE: + return -1 + + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + return mem_info.used / 1024**2 + except Exception as e: + print(f"Error getting GPU memory: {e}") + return -1 + +def check_disk_space(required_gb: float = 1.0) -> bool: + """ + Check if sufficient disk space is available. + + Args: + required_gb: Required space in GB + + Returns: + True if sufficient space available + """ + total, used, free = shutil.disk_usage("/") + free_gb = free / (2**30) + + print(f"Disk space: {free_gb:.2f} GB free") + + if free_gb < required_gb: + raise RuntimeError( + f"Insufficient disk space. Need {required_gb} GB, have {free_gb:.2f} GB" + ) + return True + +def save_results(results: Dict[str, Any], filepath: str): + """Save results to JSON file.""" + os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True) + with open(filepath, 'w') as f: + json.dump(results, f, indent=4) + print(f"Results saved to: {filepath}") + +def load_json(filepath: str) -> Dict[str, Any]: + """Load JSON file.""" + with open(filepath, 'r') as f: + return json.load(f) diff --git a/src/compressed_perception/modules/data_preparation/__init__.py b/src/wrappers/distill.py similarity index 100% rename from src/compressed_perception/modules/data_preparation/__init__.py rename to src/wrappers/distill.py diff --git a/src/wrappers/finetune.py b/src/wrappers/finetune.py new file mode 100644 index 0000000..f76c6fa --- /dev/null +++ b/src/wrappers/finetune.py @@ -0,0 +1,38 @@ +from typing import Dict, Any +import torch + +from src.engines.finetune_engine import train_finetune +from src.data.datamodule import build_datamodule +from src.models.factory import build_classifier +from src.losses.classification import cross_entropy_loss +from src.utils.optim import make_optimizer_and_scheduler +from src.utils.logging import get_logger # TODO + +log = get_logger(__name__) + +def run(cfg) -> Dict[str, Any]: + device = torch.device(cfg.runtime["device"]) + dm = build_datamodule(cfg) + model = build_classifier(cfg).to(device) # backbone + classification head + + # Unfreeze all params for full finetuning + if hasattr(model, "unfreeze_all"): + model.unfreeze_all() + + loss_fn = cross_entropy_loss() + + optimizer, scheduler = make_optimizer_and_scheduler(cfg, model.parameters()) + + log.info("Starting end-to-end finetuning...") + metrics = train_finetune( + model=model, + loaders=dm.loaders(), + loss_fn=loss_fn, + optimizer=optimizer, + scheduler=scheduler, + device=device, + epochs=cfg.train.epochs, + grad_clip=getattr(cfg.train, "grad_clip_norm", None), + mixed_precision=getattr(cfg.train, "amp", True), + ) + return metrics diff --git a/src/wrappers/probe.py b/src/wrappers/probe.py new file mode 100644 index 0000000..f1c7c12 --- /dev/null +++ b/src/wrappers/probe.py @@ -0,0 +1,162 @@ +# src/wrappers/probe.py +# -*- coding: utf-8 -*- +"""Linear probing wrapper for training classification heads on frozen backbones.""" +from __future__ import annotations +from typing import Any, Dict + +import os +import torch +from torch.utils.data import DataLoader + +# pylint: disable=import-error +from src.utils.logging import setup_logging, get_logger, WandbLogger +from src.utils.optim import make_optimizer_and_scheduler +from src.losses.classification import cross_entropy_loss +from src.models.factory import ( + create_model, create_preprocessor, freeze_backbone, save_model +) +from src.data.datamodule import BaseDataModule +from src.engines.linear_probe_engine import train_probe +from src.utils.training_utils import profile_model + +log = get_logger(__name__) + + +class ProbeWrapper: # pylint: disable=too-many-instance-attributes,too-few-public-methods + """ + Orchestrates linear probing: + - builds model (+preprocessor if you need it for your datasets), + - freezes backbone, + - prepares dataloaders via DataModule, + - creates optimizer/scheduler/loss, + - calls the probe engine, + - saves best model. + """ + + def __init__(self, cfg: Any): + """ + Expected cfg fields (suggested): + cfg.model: {type, model_id, config{num_labels,...}} + cfg.data: {dataset_name, data_dir, image_size, batch_size, num_workers} + cfg.train: {epochs, optimizer{...}, scheduler{...}, grad_clip, mixed_precision} + cfg.loss: {label_smoothing, ignore_index, reduction} + cfg.logging: {project, entity, run_name, wandb_enabled} + cfg.runtime: {run_dir} + """ + self.cfg = cfg + setup_logging() + + # Build model + self.model_info = cfg.model + self.model = create_model(self.model_info, resolution=cfg.data.image_size) + + # Optionally create preprocessor (useful if datamodule needs it) + try: + self.preprocessor = create_preprocessor( + self.model_info, resolution=cfg.data.image_size + ) + except (ImportError, AttributeError, KeyError) as e: + log.debug(f"Preprocessor creation failed: {e}") + self.preprocessor = None + + # Freeze backbone for linear probing + freeze_backbone(self.model, self.model_info["type"]) + + # Data + self.dm = BaseDataModule( + cfg=cfg, + dataset_name=cfg.data.dataset_name, + data_dir=cfg.data.data_dir, + batch_size=cfg.data.batch_size, + num_workers=cfg.data.num_workers, + pin_memory=True, + ) + self.dm.setup("fit") + + # Optimizer & scheduler + self.optimizer, (self.scheduler, self.sched_meta) = ( + make_optimizer_and_scheduler(cfg, self.model.parameters()) + ) + + # Loss + self.loss_fn = cross_entropy_loss( + label_smoothing=float(getattr(cfg.loss, "label_smoothing", 0.0)), + class_weight=None, + ignore_index=int(getattr(cfg.loss, "ignore_index", -100)), + reduction=str(getattr(cfg.loss, "reduction", "mean")), + ) + + # W&B + self.wandb = WandbLogger( + project=getattr(cfg.logging, "project", "resolution-aware-probe"), + run_name=getattr(cfg.logging, "run_name", None), + config=cfg, + enabled=bool(getattr(cfg.logging, "wandb_enabled", True)), + entity=getattr(cfg.logging, "entity", None), + tags=getattr(cfg.logging, "tags", ["probe"]), + ) + + # Device + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model.to(self.device) + + def _make_loaders(self) -> Dict[str, DataLoader]: + """Create data loaders for training and validation.""" + return { + "train": self.dm.train_dataloader(), + "val": self.dm.val_dataloader(), + # "test": self.dm.test_dataloader(), # optional + } + + def train(self) -> Dict[str, Any]: + """Run linear probing training.""" + log.info("Starting linear probe...") + + # Optional: profile FLOPs once + gflops = profile_model(self.model, self.cfg.data.image_size) + if self.wandb: + self.wandb.log({"model/gflops": gflops}) + + results = train_probe( + model=self.model, + loaders=self._make_loaders(), + loss_fn=self.loss_fn, + optimizer=self.optimizer, + scheduler=(self.scheduler, self.sched_meta), + device=self.device, + epochs=int(self.cfg.train.epochs), + grad_clip=getattr(self.cfg.train, "grad_clip", None), + mixed_precision=bool( + getattr(self.cfg.train, "mixed_precision", True) + ), + log_interval=int(getattr(self.cfg.train, "log_interval", 50)), + wandb_logger=self.wandb, + metric_key=str(getattr(self.cfg.train, "metric_key", "val_acc")), + ) + + # Save best model (HF format if applicable) + run_dir = getattr(self.cfg.runtime, "run_dir", "./runs/probe") + os.makedirs(run_dir, exist_ok=True) + try: + save_model( + self.model, + self.model_info, + save_dir=run_dir, + preprocessor=self.preprocessor + ) + except (OSError, ValueError, AttributeError) as e: + log.warning(f"Failed to save model in HF format; error: {e}") + + # Finish W&B + if self.wandb: + self.wandb.log({"best/metric": results.get("best_metric", None)}) + self.wandb.finish() + + log.info("Linear probe finished.") + return results + + +def run(cfg: Any) -> Dict[str, Any]: + """Convenience function to run from a CLI entrypoint.""" + wrapper = ProbeWrapper(cfg) + return wrapper.train() From ed69142d4691e94b06bb0e3c34adf7fd45ecc167 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 17 Oct 2025 17:28:17 -0700 Subject: [PATCH 16/26] resolve conflicts --- .gitignore | 5 +---- README.md | 1 - configs/example_config.yaml.license | 5 ----- 3 files changed, 1 insertion(+), 10 deletions(-) delete mode 100644 configs/example_config.yaml.license diff --git a/.gitignore b/.gitignore index 378c36b..2afc05e 100644 --- a/.gitignore +++ b/.gitignore @@ -212,7 +212,4 @@ marimo/_static/ marimo/_lsp/ __marimo__/ -**/.DS_Store - - -.DS_Store \ No newline at end of file +**/.DS_Store \ No newline at end of file diff --git a/README.md b/README.md index 70f31e9..b7278ae 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,6 @@ SPDX-License-Identifier: MIT This project explores how compressed and degraded dermatology images (from the ISIC 2019 dataset) affect classification performance using pretrained vision models. It compares fine-tuning vs. linear probing across multiple JPEG quality levels. -![System architecture diagram](<./media/CS231N Poster.png>) ## Project Goals diff --git a/configs/example_config.yaml.license b/configs/example_config.yaml.license deleted file mode 100644 index 3cc951b..0000000 --- a/configs/example_config.yaml.license +++ /dev/null @@ -1,5 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT \ No newline at end of file From ac08385413e5888c299e80a2da83afb7c5bdd6cf Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 17 Oct 2025 18:57:37 -0700 Subject: [PATCH 17/26] demo for transformation implementation for ISIC --- .DS_Store | Bin 6148 -> 6148 bytes requirements_demo.txt | 20 ++ scripts/demo_transformation.py | 99 ++++++ scripts/test_transforms.py | 90 ----- src/data/data_utils.py | 301 ++++++++++------- src/data/datasets.py | 14 - src/data/isic_loader.py | 42 +++ src/engines/distill_engine.py | 0 src/engines/finetune_engine.py | 0 src/engines/linear_probe_engine.py | 224 ------------- src/engines/train.py | 516 +++++++++++++++++++++++++++++ src/wrappers/distill.py | 0 12 files changed, 863 insertions(+), 443 deletions(-) create mode 100644 requirements_demo.txt create mode 100644 scripts/demo_transformation.py delete mode 100644 scripts/test_transforms.py delete mode 100644 src/data/datasets.py create mode 100644 src/data/isic_loader.py delete mode 100644 src/engines/distill_engine.py delete mode 100644 src/engines/finetune_engine.py delete mode 100644 src/engines/linear_probe_engine.py create mode 100644 src/engines/train.py delete mode 100644 src/wrappers/distill.py diff --git a/.DS_Store b/.DS_Store index a8899eb143a1ad8a3d2bec2155111732ede3569b..d1b88db71b01a9339795b5bbeb2700f40e616132 100644 GIT binary patch delta 48 zcmZoMXfc@JFUrcmz`)4BAi%)j#}Lfm>Y0=2.0.0 +torchvision>=0.15.0 + +# HuggingFace datasets for loading ISIC data +datasets>=2.14.0 + +# Image processing +Pillow>=9.0.0 + +# Data manipulation +numpy>=1.24.0 + +# Optional: For progress bars and better user experience +tqdm>=4.65.0 + +# Optional: For HuggingFace transformers (if using real preprocessors) +transformers>=4.30.0 diff --git a/scripts/demo_transformation.py b/scripts/demo_transformation.py new file mode 100644 index 0000000..080dbd7 --- /dev/null +++ b/scripts/demo_transformation.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# pylint: disable=broad-exception-caught + +""" +Demo script to apply ResolutionReductionTransform to ISIC dataset images and save results. +""" + +import sys +from pathlib import Path +from PIL import Image + +# Add src to Python path for imports +current_dir = Path(__file__).parent +project_root = current_dir.parent +sys.path.insert(0, str(project_root)) + +from datasets import load_dataset +from src.data.isic_loader import ISICBaseDataset +from src.transformation.transforms import ResolutionReductionTransform + +def main(): + print("🔄 Loading ISIC dataset from HuggingFace...") + try: + hf_dataset = load_dataset("MKZuziak/ISIC_2019_224", split="train") + isic_dataset = ISICBaseDataset(hf_dataset) + print(f"✅ Loaded {len(isic_dataset)} samples") + except Exception as e: + print(f"❌ Failed to load dataset: {e}") + print("Make sure you have 'datasets' installed: pip install datasets") + return + + # Create output directory + output_dir = Path("outputs") + output_dir.mkdir(exist_ok=True) + print(f"📁 Output directory: {output_dir.absolute()}") + + print("\nApplying ResolutionReductionTransform to first 5 images...") + + # Create the transform + resolution_transform = ResolutionReductionTransform() # Random reduction factor + + # Process first 5 images + num_images = min(5, len(isic_dataset)) + + for i in range(num_images): + print(f"\nProcessing image {i+1}/{num_images}...") + + try: + # Get the original image + sample = isic_dataset[i] + original_image = sample["image"] + label = sample["label"] + + print(f"Original size: {original_image.size}, Label: {label}") + + # Save original image + original_path = output_dir / f"original_{i}.png" + original_image.save(original_path) + print(f" 💾 Saved original: {original_path}") + + # Apply resolution reduction transform + transformed_image = resolution_transform(original_image) + print(f" 🔄 Transformed size: {transformed_image.size}") + + # Save transformed image + transformed_path = output_dir / f"resolution_reduced_{i}.png" + transformed_image.save(transformed_path) + print(f" 💾 Saved transformed: {transformed_path}") + + except Exception as e: + print(f" ❌ Error processing image {i}: {e}") + continue + + print(f"\n✅ All images saved to: {output_dir.absolute()}") + + # Show what reduction factors were used (they're random) + print("\nTesting with fixed reduction factors...") + for factor in [0.25, 0.5, 0.75]: + print(f"\nTesting reduction factor: {factor}") + try: + fixed_transform = ResolutionReductionTransform(reduction_factor=factor) + + # Use first image for this demo + sample = isic_dataset[0] + original_image = sample["image"] + + transformed = fixed_transform(original_image) + output_path = output_dir / f"fixed_reduction_{factor}_{0}.png" + transformed.save(output_path) + print(f" 💾 Saved: {output_path}") + + except Exception as e: + print(f" ❌ Error with factor {factor}: {e}") + + print(f"\n🎉 Demo completed! Check {output_dir.absolute()} for results.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/test_transforms.py b/scripts/test_transforms.py deleted file mode 100644 index 933b2a3..0000000 --- a/scripts/test_transforms.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Minimal check: dataset loads and a degradation transform is applied to 5 images. -- Uses your BaseDataModule and get_degradation_transforms() -- No training, no W&B. -- Saves clean and degraded PNGs. -""" - -from types import SimpleNamespace -from pathlib import Path -import argparse - -from torchvision.utils import save_image -import torchvision.transforms.functional as TF -from PIL import Image - -from src.data.datamodule import BaseDataModule -from src.transformation.transforms import get_degradation_transforms -from src.utils import setup_environment - - -def main(): - ap = argparse.ArgumentParser(description="Quick dataset+transform smoke test (5 images).") - ap.add_argument("--dataset", type=str, required=True, choices=["isic", "tcga", "merlin"]) - ap.add_argument("--data_dir", type=str, required=True) - ap.add_argument("--out_dir", type=str, default="outputs/test_transforms") - ap.add_argument("--resolution", type=int, default=224) - ap.add_argument("--num_workers", type=int, default=2) - args = ap.parse_args() - - setup_environment() - - out_root = Path(args.out_dir) - out_root.mkdir(parents=True, exist_ok=True) - - # ---- minimal cfg your DataModule/get_dataset can read ---- - cfg = SimpleNamespace( - resolution=args.resolution, - batch_size=5, # just enough to grab 5 samples - ) - - # ---- build the DataModule exactly as your code expects ---- - dm = BaseDataModule( - cfg=cfg, - dataset_name=args.dataset, - data_dir=args.data_dir, - num_workers=args.num_workers, - batch_size=cfg.batch_size, - pin_memory=True, - drop_last=False, - ) - dm.setup(stage="fit") - - # ---- pull one small batch (5 images) from val ---- - val_loader = dm.val_dataloader() - batch = next(iter(val_loader)) - - if isinstance(batch, dict): - x = batch.get("pixel_values") or batch.get("images") or batch.get("x") - y = batch.get("labels") or batch.get("y") - else: - x, y = batch - - # keep exactly 5 - x = x[:5] - - for i in range(x.size(0)): - save_image(x[i], out_root / f"clean_{i}.png") - - degradations = get_degradation_transforms() - if not degradations: - print("No degradations returned by get_degradation_transforms(); saved only clean images.") - return - - degr = degradations[0] - tag = degr.__class__.__name__.lower() - print(f"Applying degradation: {tag}") - - for i in range(x.size(0)): - pil_img = TF.to_pil_image(x[i].cpu()) - pil_out = degr(pil_img) - t_out = TF.to_tensor(pil_out) - save_image(t_out, out_root / f"{tag}_{i}.png") - - print(f"✅ Done. Wrote images to: {out_root.resolve()}") - - -if __name__ == "__main__": - main() diff --git a/src/data/data_utils.py b/src/data/data_utils.py index 351e018..e7d0b75 100644 --- a/src/data/data_utils.py +++ b/src/data/data_utils.py @@ -3,43 +3,25 @@ import torch from PIL import Image from torch.utils.data import Dataset, ConcatDataset, Subset -from torchvision import transforms from typing import Optional, List, Dict, Any, Union from src.config import HF_MODELS, DEFAULT_IMAGE_SIZE -from src.transformation.transforms import ResolutionReductionTransform -class ISICDataset(Dataset): - """ISIC dataset with support for multiple model types and transformations.""" - - def __init__( - self, - dataset: Union[Dataset, Subset], - preprocessor: Optional[Any] = None, - resolution: int = DEFAULT_IMAGE_SIZE, - transform: Optional[transforms.Compose] = None, - model_type: str = "vit", - jpeg_quality: Optional[int] = None, - ): - self.dataset = dataset - self.preprocessor = preprocessor - self.resolution = resolution - self.transform = transform - self.model_type = model_type - self.jpeg_quality = jpeg_quality +# ============================================================================ +# DATA LOADING +# ============================================================================ - # Create base preprocessing - self.base_preprocessor = transforms.Compose([ - transforms.Resize((resolution, resolution), Image.LANCZOS), - transforms.ToTensor(), - ]) +class DatasetWrapper(Dataset): + """Simple wrapper that handles dataset/subset access patterns.""" - self.model_preprocessor = None + def __init__(self, dataset: Union[Dataset, Subset]): + self.dataset = dataset def __len__(self) -> int: return len(self.dataset) - def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + def get_raw_item(self, idx: int) -> Dict[str, Any]: + """Get raw item from dataset, handling both Dataset and Subset.""" # Convert numpy types to Python int if isinstance(idx, (np.integer, np.int64)): idx = int(idx) @@ -48,100 +30,175 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: if hasattr(self.dataset, 'dataset'): # This is a Subset subset_idx = int(self.dataset.indices[idx]) - item = self.dataset.dataset[subset_idx] + return self.dataset.dataset[subset_idx] else: # Direct dataset access - item = self.dataset[idx] + return self.dataset[idx] - image = item["image"] - label = item["label"] - # Resize to target resolution - image = image.resize((self.resolution, self.resolution), Image.LANCZOS) +# ============================================================================ +# IMAGE PREPROCESSING +# ============================================================================ - # Apply optional transformations - if self.transform: - image = self.transform(image) +class ImageProcessor: + """Handles image preprocessing operations.""" - if self.jpeg_quality is not None: - image = JPEGCompressionTransform(self.jpeg_quality)(image) + def __init__(self, resolution: int = DEFAULT_IMAGE_SIZE): + self.resolution = resolution - # Apply model-specific preprocessing - if self.model_type in HF_MODELS: - # For HuggingFace models - if hasattr(self.preprocessor, 'size'): - self.preprocessor.size = self.resolution - encoding = self.preprocessor(images=image, return_tensors="pt") - pixel_values = encoding["pixel_values"].squeeze(0) - else: + def resize_image(self, image: Image.Image) -> Image.Image: + """Resize image to target resolution.""" + return image.resize( + (self.resolution, self.resolution), + Image.Resampling.LANCZOS + ) + + def apply_transforms(self, image: Image.Image, transform: Optional[Any]) -> Image.Image: + """Apply optional transformations to image.""" + if transform: + return transform(image) + return image + + +# ============================================================================ +# MODEL-SPECIFIC PREPROCESSING +# ============================================================================ + +class ModelPreprocessor: + """Handles model-specific preprocessing.""" + + def __init__(self, preprocessor: Optional[Any] = None, model_type: str = "vit"): + self.preprocessor = preprocessor + self.model_type = model_type + + if self.model_type not in HF_MODELS: raise ValueError(f"Unsupported model_type: {self.model_type}") + def preprocess(self, image: Image.Image, resolution: int) -> torch.Tensor: + """Apply model-specific preprocessing.""" + # For HuggingFace models + if hasattr(self.preprocessor, 'size'): + self.preprocessor.size = resolution + + encoding = self.preprocessor(images=image, return_tensors="pt") + return encoding["pixel_values"].squeeze(0) + + +# ============================================================================ +# COMBINED DATASET +# ============================================================================ + +class ISICDataset(Dataset): + """ISIC dataset that combines data loading, image processing, and model preprocessing.""" + + def __init__( + self, + dataset: Union[Dataset, Subset], + preprocessor: Optional[Any] = None, + resolution: int = DEFAULT_IMAGE_SIZE, + transform: Optional[Any] = None, + model_type: str = "vit", + ): + # Data loading + self.data_wrapper = DatasetWrapper(dataset) + + # Image processing + self.image_processor = ImageProcessor(resolution) + + # Model preprocessing + self.model_preprocessor = ModelPreprocessor(preprocessor, model_type) + + # Transformation + self.transform = transform + + def __len__(self) -> int: + return len(self.data_wrapper) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + # 1. Load raw data + item = self.data_wrapper.get_raw_item(idx) + image = item["image"] + label = item["label"] + + # 2. Process image + image = self.image_processor.resize_image(image) + image = self.image_processor.apply_transforms(image, self.transform) + + # 3. Apply model-specific preprocessing + pixel_values = self.model_preprocessor.preprocess( + image, self.image_processor.resolution + ) + + # 4. Prepare output label = torch.tensor(label, dtype=torch.long) return {"pixel_values": pixel_values, "labels": label} -def create_transformed_datasets( - train_dataset: Dataset, - val_dataset: Dataset, - transforms_list: List, - proportion_per_transform: float, - preprocessor: Optional[Any], - resolution: int, - model_type: str -) -> tuple[Dataset, Dataset]: - """ - Create train and validation datasets with transformations. - - Args: - train_dataset: Training dataset - val_dataset: Validation dataset - transforms_list: List of transforms to apply - proportion_per_transform: Proportion of data for each transform - preprocessor: Model preprocessor - resolution: Image resolution - model_type: Type of model - - Returns: - Tuple of (train_dataset, val_dataset) - """ - num_images = len(train_dataset) +# ============================================================================ +# DATASET CREATION UTILITIES +# ============================================================================ + +def split_dataset_for_transforms( + dataset: Dataset, + transforms_list: List[Any], + proportion_per_transform: float +) -> List[Subset]: + """Split dataset into subsets for applying different transforms.""" + num_images = len(dataset) images_per_transform = int(num_images * proportion_per_transform) # Shuffle indices indices = np.arange(num_images) np.random.shuffle(indices) - transformed_datasets = [] + subsets = [] used_indices = [] - # Apply each transform to a subset - for i, transform in enumerate(transforms_list): + # Create subset for each transform + for i, _ in enumerate(transforms_list): start_idx = i * images_per_transform end_idx = start_idx + images_per_transform subset_indices = indices[start_idx:end_idx] used_indices.extend(subset_indices) + subsets.append(Subset(dataset, subset_indices)) + + # Add remaining samples + remaining_indices = np.setdiff1d(indices, used_indices) + if len(remaining_indices) > 0: + subsets.append(Subset(dataset, remaining_indices)) + + return subsets + + +def create_transformed_datasets( + train_dataset: Dataset, + val_dataset: Dataset, + transforms_list: List[Any], + proportion_per_transform: float, + preprocessor: Optional[Any], + resolution: int, + model_type: str +) -> tuple[Dataset, Dataset]: + """Create train and validation datasets with transformations.""" + + # Split training data into subsets + train_subsets = split_dataset_for_transforms( + train_dataset, transforms_list, proportion_per_transform + ) - subset = Subset(train_dataset, subset_indices) - transform_compose = transforms.Compose([transform]) + # Create datasets with transforms + transformed_datasets = [] + # Apply each transform to corresponding subset + for i, (subset, transform) in enumerate(zip(train_subsets[:-1], transforms_list)): transformed_ds = ISICDataset( - subset, - preprocessor, - resolution, - transform_compose, - model_type + subset, preprocessor, resolution, transform, model_type ) transformed_datasets.append(transformed_ds) - # Add remaining samples without transformation - remaining_indices = np.setdiff1d(indices, used_indices) - if len(remaining_indices) > 0: - remaining_subset = Subset(train_dataset, remaining_indices) + # Add untransformed subset (if any remaining) + if len(train_subsets) > len(transforms_list): untransformed_ds = ISICDataset( - remaining_subset, - preprocessor, - resolution, - None, - model_type + train_subsets[-1], preprocessor, resolution, None, model_type ) transformed_datasets.append(untransformed_ds) @@ -149,29 +206,16 @@ def create_transformed_datasets( train_ds = ConcatDataset(transformed_datasets) # Create validation dataset (no transformations) - val_ds = ISICDataset( - val_dataset, - preprocessor, - resolution, - model_type=model_type, - ) + val_ds = ISICDataset(val_dataset, preprocessor, resolution, None, model_type) return train_ds, val_ds -def balance_dataset(dataset: Dataset, filtered_classes: List[str], num_train_images: int, seed: int = 42): - """ - Balance dataset by sampling equal numbers from each class. - - Args: - dataset: Input dataset - filtered_classes: Classes to keep - num_train_images: Total number of training images - seed: Random seed +# ============================================================================ +# DATASET BALANCING +# ============================================================================ - Returns: - Balanced dataset - """ - # Get class counts +def get_class_distribution(dataset: Dataset, filtered_classes: List[str]) -> Dict[str, List[int]]: + """Get class distribution and indices.""" class_counts = {label: 0 for label in filtered_classes} class_indices = {label: [] for label in filtered_classes} @@ -182,19 +226,46 @@ def balance_dataset(dataset: Dataset, filtered_classes: List[str], num_train_ima class_indices[label_str].append(i) print(f"Class counts before balancing: {class_counts}") + return class_indices + + +def sample_balanced_indices( + class_indices: Dict[str, List[int]], + num_train_images: int, + seed: int = 42 +) -> List[int]: + """Sample balanced indices from each class.""" + np.random.seed(seed) # Calculate samples per class - min_class_size = min(class_counts.values()) - images_per_class = min(num_train_images // len(filtered_classes), min_class_size) + num_classes = len(class_indices) + min_class_size = min(len(indices) for indices in class_indices.values()) + images_per_class = min(num_train_images // num_classes, min_class_size) + + print(f"Sampling {images_per_class} images per class") # Sample from each class - np.random.seed(seed) balanced_indices = [] - - for label in filtered_classes: - indices = class_indices[label] + for label, indices in class_indices.items(): sampled = np.random.choice(indices, images_per_class, replace=False) balanced_indices.extend(sampled) np.random.shuffle(balanced_indices) + return balanced_indices + + +def balance_dataset( + dataset: Dataset, + filtered_classes: List[str], + num_train_images: int, + seed: int = 42 +) -> Dataset: + """Balance dataset by sampling equal numbers from each class.""" + # Get class distribution + class_indices = get_class_distribution(dataset, filtered_classes) + + # Sample balanced indices + balanced_indices = sample_balanced_indices(class_indices, num_train_images, seed) + + # Return balanced dataset return dataset.select(balanced_indices) \ No newline at end of file diff --git a/src/data/datasets.py b/src/data/datasets.py deleted file mode 100644 index 8568949..0000000 --- a/src/data/datasets.py +++ /dev/null @@ -1,14 +0,0 @@ -# src/data/datasets.py -from src.data.isic_dataset import ISICDataset -from src.data.tcga_dataset import TCGADataset -from src.data.merlin_dataset import MerlinDataset - -def get_dataset(dataset_name, data_dir, split, cfg): - if dataset_name.lower() == "isic": - return ISICDataset(data_dir=data_dir, split=split, cfg=cfg) - elif dataset_name.lower() == "tcga": - return TCGADataset(data_dir=data_dir, split=split, cfg=cfg) - elif dataset_name.lower() == "merlin": - return MerlinDataset(data_dir=data_dir, split=split, cfg=cfg) - else: - raise ValueError(f"Unknown dataset: {dataset_name}") diff --git a/src/data/isic_loader.py b/src/data/isic_loader.py new file mode 100644 index 0000000..ee87d09 --- /dev/null +++ b/src/data/isic_loader.py @@ -0,0 +1,42 @@ +# src/data/isic_loader.py +from typing import Any, Dict, Union +from torch.utils.data import Dataset, Subset + +class ISICBaseDataset(Dataset): + """ + Minimal, transformation-free wrapper for ISIC (or ISIC-like) datasets. + + Expects the backing dataset (or Subset) to yield items with: + item["image"] : PIL.Image (or array/tensor if your source uses that) + item["label"] : int-like + + Returns each sample unchanged: + {"image": , "label": } + """ + + def __init__(self, dataset: Union[Dataset, Subset]): + self.dataset = dataset + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + # Handle Subset wrapping transparently + base = self.dataset + if hasattr(base, "dataset") and hasattr(base, "indices"): + # Subset case + item = base.dataset[int(base.indices[idx])] + else: + item = base[idx] + + # Do NOT touch the image (no resize, no cast, no transforms) + image = item["image"] + label = item["label"] + + # Make sure label is int-like, but don't coerce image + try: + label = int(label) + except Exception: + raise TypeError("Label must be convertible to int.") + + return {"image": image, "label": label} diff --git a/src/engines/distill_engine.py b/src/engines/distill_engine.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/engines/finetune_engine.py b/src/engines/finetune_engine.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/engines/linear_probe_engine.py b/src/engines/linear_probe_engine.py deleted file mode 100644 index 9567b6f..0000000 --- a/src/engines/linear_probe_engine.py +++ /dev/null @@ -1,224 +0,0 @@ -# src/engines/linear_probe_engine.py -# -*- coding: utf-8 -*- -"""Linear probing engine for training classification heads on frozen backbones.""" -from __future__ import annotations -from typing import Dict, Any, Tuple, Optional -import math -import torch -from torch import nn -from torch.utils.data import DataLoader - -# pylint: disable=import-error -from src.utils.logging import get_logger, MetricAverager, WandbLogger -from src.utils.optim import step_scheduler - -log = get_logger(__name__) - - -def _unpack_loaders(loaders: Any) -> Tuple[DataLoader, Optional[DataLoader], Optional[DataLoader]]: - """ - Accepts either: - - {"train": ..., "val": ..., "test": ...} - - (train_loader, val_loader) - Returns (train, val, test) - """ - if isinstance(loaders, dict): - return loaders.get("train"), loaders.get("val"), loaders.get("test") - if isinstance(loaders, (tuple, list)) and len(loaders) >= 2: - return loaders[0], loaders[1], loaders[2] if len(loaders) > 2 else None - raise ValueError( - "`loaders` must be a dict with keys train/val[/test] or a (train,val) tuple." - ) - - -@torch.no_grad() -def _evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Dict[str, float]: - model.eval() - correct = 0 - total = 0 - running_loss = 0.0 - - # if the model exposes a criterion attribute, prefer the external loss - # passed in train loop anyway. - ce = nn.CrossEntropyLoss() - - for batch in loader: - # support (x,y) or dict - if isinstance(batch, dict): - x = (batch.get("pixel_values") or - batch.get("images") or - batch.get("x")) - y = batch.get("labels") or batch.get("y") - else: - x, y = batch - - x = x.to(device, non_blocking=True) - y = y.to(device, non_blocking=True) - - logits = model(x) - if isinstance(logits, (tuple, list)): - logits = logits[0] - loss = ce(logits, y) - running_loss += float(loss.item()) * x.size(0) - - preds = logits.argmax(dim=-1) - correct += int((preds == y).sum().item()) - total += int(y.numel()) - - acc = correct / max(total, 1) - avg_loss = running_loss / max(total, 1) - return {"val_loss": avg_loss, "val_acc": acc} - - -def train_probe( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements - model: nn.Module, - loaders: Any, - loss_fn, # callable: (logits, targets) -> loss tensor - optimizer: torch.optim.Optimizer, - scheduler: Any = None, # either a scheduler or (scheduler, meta) - device: torch.device = torch.device("cpu"), - epochs: int = 10, - grad_clip: Optional[float] = None, - mixed_precision: bool = True, - log_interval: int = 50, - wandb_logger: Optional[WandbLogger] = None, - metric_key: str = "val_acc", -) -> Dict[str, Any]: - """ - Linear probing engine. - Expects the backbone to be frozen already; only the head should have requires_grad=True. - """ - train_loader, val_loader, _ = _unpack_loaders(loaders) - assert train_loader is not None, "train_loader is required" - assert val_loader is not None, "val_loader is required" - - scaler = torch.amp.GradScaler(enabled=(mixed_precision and device.type == "cuda")) - - # unwrap (scheduler, meta) if it came from our utils - sched, sched_meta = None, {"by": "epoch"} - if scheduler is not None: - if isinstance(scheduler, tuple) and len(scheduler) == 2: - sched, sched_meta = scheduler[0], scheduler[1] - else: - sched = scheduler - - # sanity: log trainable parameter ratio - total_params = sum(p.numel() for p in model.parameters()) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - log.info(f"[probe] trainable params: {trainable_params:,} / {total_params:,} " - f"({100.0 * trainable_params / max(total_params,1):.2f}%)") - - history = {"train_loss": [], "val_loss": [], "val_acc": []} - best_metric = -math.inf - best_state: Optional[Dict[str, torch.Tensor]] = None - - model.to(device) - - for epoch in range(1, epochs + 1): - model.train() - ma = MetricAverager() - - for it, batch in enumerate(train_loader, start=1): - if isinstance(batch, dict): - x = (batch.get("pixel_values") or - batch.get("images") or - batch.get("x")) - y = batch.get("labels") or batch.get("y") - else: - x, y = batch - - x = x.to(device, non_blocking=True) - y = y.to(device, non_blocking=True) - - optimizer.zero_grad(set_to_none=True) - - if scaler.is_enabled(): - with torch.amp.autocast(device_type=device.type): - logits = model(x) - if isinstance(logits, (tuple, list)): - logits = logits[0] - loss = loss_fn(logits, y) - scaler.scale(loss).backward() - if grad_clip is not None and grad_clip > 0: - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_( - filter(lambda p: p.requires_grad, model.parameters()), - max_norm=grad_clip - ) - scaler.step(optimizer) - scaler.update() - else: - logits = model(x) - if isinstance(logits, (tuple, list)): - logits = logits[0] - loss = loss_fn(logits, y) - loss.backward() - if grad_clip is not None and grad_clip > 0: - torch.nn.utils.clip_grad_norm_( - filter(lambda p: p.requires_grad, model.parameters()), - max_norm=grad_clip - ) - optimizer.step() - - # running metrics - with torch.no_grad(): - preds = logits.argmax(dim=-1) - acc = (preds == y).float().mean().item() - ma.update(loss=float(loss.item()), acc=acc) - - if it % log_interval == 0: - avgs = ma.averages() - log.info(f"[probe] epoch {epoch:03d} iter {it:05d} " - f"| loss {avgs['loss']:.4f} | acc {avgs['acc']:.4f}") - if wandb_logger: - wandb_logger.log({ - "train/loss": avgs["loss"], - "train/acc": avgs["acc"], - "epoch": epoch, - "iter": it, - }) - - # end epoch → validation - val_metrics = _evaluate(model, val_loader, device=device) - history["train_loss"].append(ma.averages()["loss"]) - history["val_loss"].append(val_metrics["val_loss"]) - history["val_acc"].append(val_metrics["val_acc"]) - - log.info(f"[probe] epoch {epoch:03d} | val_loss {val_metrics['val_loss']:.4f} " - f"| val_acc {val_metrics['val_acc']:.4f}") - - if wandb_logger: - wandb_logger.log({ - "val/loss": val_metrics["val_loss"], - "val/acc": val_metrics["val_acc"], - "epoch": epoch, - }) - - # scheduler step - if sched is not None: - # if ReduceLROnPlateau, step on val metric; else per-epoch - if hasattr(sched, "step") and sched_meta.get("by") == "val_metric": - # for plateau: lower is better typically; if you want higher-better, pass negative - step_scheduler(sched, sched_meta, epoch=epoch, val_metric=val_metrics["val_loss"]) - else: - step_scheduler(sched, sched_meta, epoch=epoch) - - # track best - current = val_metrics.get(metric_key, val_metrics["val_acc"]) - if current > best_metric: - best_metric = current - best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} - - # restore best weights (optional but standard) - if best_state is not None: - model.load_state_dict(best_state) - - return { - "history": history, - "best_metric": float(best_metric), - "best_metric_name": metric_key, - "final_val_acc": float(history["val_acc"][-1]) if history["val_acc"] else None, - "final_val_loss": float(history["val_loss"][-1]) if history["val_loss"] else None, - "trainable_params": trainable_params, - "total_params": total_params, - } diff --git a/src/engines/train.py b/src/engines/train.py new file mode 100644 index 0000000..594f720 --- /dev/null +++ b/src/engines/train.py @@ -0,0 +1,516 @@ +"""Main training script.""" +import os +import time +import argparse +import numpy as np +import torch +import wandb +from datasets import load_dataset, ClassLabel +from transformers import Trainer, TrainingArguments +from typing import Dict, Any + +from src.config import ( + TrainingConfig, MODEL_REGISTRY, FILTERED_CLASSES, + NUM_FILTERED_CLASSES, HF_MODELS +) +from src.utils import ( + setup_environment, env_path, get_gpu_memory, + check_disk_space, save_results +) +from src.models import ( + create_model, create_preprocessor, freeze_backbone, save_model +) +from src.data_utils import ( + ISICDataset, create_transformed_datasets, balance_dataset +) +from src.transformation.transforms import ( + get_degradation_transforms, ResolutionReductionTransform +) +from src.utils.training_utils import ( + LossLoggerCallback, + WandbCallback, + profile_model +) + +from src.evaluation.metrics import ( + create_compute_metrics_fn, +) + +def create_multi_validation_datasets( + val_dataset, + preprocessor, + resolution: int, + model_type: str +) -> Dict[str, Any]: + """ + Create validation datasets with different degradation levels. + + Returns: + Dictionary mapping degradation name to dataset + """ + val_datasets = {} + + # Clean (no degradation) + val_datasets['clean'] = ISICDataset( + val_dataset, + preprocessor, + resolution, + transform=None, + model_type=model_type + ) + + # JPEG compression at different quality levels + for quality in [90, 50, 20]: + val_datasets[f'jpeg_{quality}'] = ISICDataset( + val_dataset, + preprocessor, + resolution, + transform=JPEGCompressionTransform(quality=quality), + model_type=model_type + ) + + # Gaussian blur at different radii + for radius in [1.0, 3.0, 5.0]: + val_datasets[f'blur_{radius:.1f}'] = ISICDataset( + val_dataset, + preprocessor, + resolution, + transform=GaussianBlurTransform(radius=radius), + model_type=model_type + ) + + # Color quantization at different levels + for n_colors in [64, 16, 4]: + val_datasets[f'color_{n_colors}'] = ISICDataset( + val_dataset, + preprocessor, + resolution, + transform=ColorQuantizationTransform(n_colors=n_colors), + model_type=model_type + ) + + return val_datasets + +def evaluate_all_datasets(trainer, val_datasets: Dict[str, Any], model_name: str) -> Dict[str, Any]: + """ + Evaluate model on all validation datasets. + + Args: + trainer: HuggingFace Trainer object + val_datasets: Dictionary of validation datasets + model_name: Name of the model for logging + + Returns: + Dictionary of results for each dataset + """ + all_results = {} + + for val_name, val_dataset in val_datasets.items(): + print(f"Evaluating on {val_name}...") + + # Evaluate on this dataset + eval_results = trainer.evaluate( + eval_dataset=val_dataset, + metric_key_prefix=f"eval_{val_name}" + ) + + # Extract key metrics + accuracy = eval_results.get(f"eval_{val_name}_accuracy", 0) + f1 = eval_results.get(f"eval_{val_name}_f1", 0) + auc = eval_results.get(f"eval_{val_name}_auc", 0) + + # Store results + all_results[val_name] = { + "accuracy": accuracy, + "f1": f1, + "auc": auc, + "loss": eval_results.get(f"eval_{val_name}_loss", 0) + } + + # Log to wandb + wandb.log({ + f"{val_name}/accuracy": accuracy, + f"{val_name}/f1": f1, + f"{val_name}/auc": auc, + "model": model_name + }) + + print(f" {val_name}: Acc={accuracy:.3f}, F1={f1:.3f}, AUC={auc:.3f}") + + return all_results + +def train_model( + model_info: dict, + train_dataset, + val_dataset, + config: TrainingConfig, + degradation_transforms: list, + training_mode: str = "finetune" # "finetune" or "linear_probe" +) -> dict: + """ + Train a single model with specified training mode. + + Args: + model_info: Model configuration + train_dataset: Training dataset + val_dataset: Validation dataset + config: Training configuration + degradation_transforms: List of data augmentations + training_mode: "finetune" or "linear_probe" + + Returns: + Dictionary of training results + """ + name = model_info["name"] + model_type = model_info["type"] + + print(f"\n{'='*50}") + print(f"Training {name} ({model_type}) - Mode: {training_mode}") + print(f"{'='*50}") + + # Initialize wandb + wandb.init( + entity="sonnet-xu-stanford-university", + project="CS231N Test", + name=f"{name}_{config.resolution}_{config.num_epochs}_epochs_{training_mode}", + config={ + **config.to_wandb_config(), + "model_config": model_info["config"], + "training_mode": training_mode + }, + tags=["baseline", "model-comparison", training_mode, name, f"res_{config.resolution}"], + reinit=True + ) + + # Create model and preprocessor + model = create_model(model_info, config.resolution) + preprocessor = create_preprocessor(model_info, config.resolution) + + # Freeze backbone for linear probing + if training_mode == "linear_probe": + print(f"Freezing backbone for linear probing...") + freeze_backbone(model, model_type) + + # Count trainable parameters + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + total_params = sum(p.numel() for p in model.parameters()) + print(f"Trainable params: {trainable_params:,} / {total_params:,} " + f"({100 * trainable_params / total_params:.2f}%)") + + # Move model to device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + + # Create training dataset with transformations + train_ds, _ = create_transformed_datasets( + train_dataset, + val_dataset, # Not used but required by function signature + degradation_transforms, + config.proportion_per_transform, + preprocessor, + config.resolution, + model_type + ) + + # Create multiple validation datasets + val_datasets = create_multi_validation_datasets( + val_dataset, + preprocessor, + config.resolution, + model_type + ) + + # Profile model + flops = profile_model(model, config.resolution) + + # Setup training arguments + output_dir = os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}_{training_mode}") + log_dir = env_path("LOG_DIR", "./logs") + + # Adjust learning rate for linear probing (typically higher) + learning_rate = config.learning_rate + if training_mode == "linear_probe": + learning_rate = config.learning_rate * 10 # Often need higher LR for linear probe + + training_args = TrainingArguments( + output_dir=output_dir, + num_train_epochs=config.num_epochs, + per_device_train_batch_size=config.batch_size, + per_device_eval_batch_size=config.batch_size, + learning_rate=learning_rate, + lr_scheduler_type="cosine", + weight_decay=config.weight_decay, + logging_dir=os.path.join(log_dir, f"{name}_{training_mode}"), + logging_steps=1, + evaluation_strategy="steps", + eval_steps=config.eval_steps, + save_strategy="steps", + save_steps=config.eval_steps, + load_best_model_at_end=True, # Load best model for final evaluation + metric_for_best_model="eval_clean_accuracy", # Use clean accuracy for model selection + greater_is_better=True, + save_total_limit=1, + save_safetensors=False, + push_to_hub=False, + ) + + # Check disk space + check_disk_space(required_gb=1.0) + + # Create trainer with clean validation set for checkpointing + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_ds, + eval_dataset=val_datasets['clean'], # Use clean for model selection + compute_metrics=create_compute_metrics_fn(name), + callbacks=[ + LossLoggerCallback(log_dir, training_mode, name), + WandbCallback(name, training_mode), + ], + ) + + # Log model to wandb + if model_type in HF_MODELS: + wandb.watch(model, log="all", log_freq=100) + + # Train + start_time = time.time() + peak_memory = get_gpu_memory() + + trainer.train() + + # Evaluate on all validation datasets + eval_start_time = time.time() + multi_eval_results = evaluate_all_datasets(trainer, val_datasets, name) + eval_time = time.time() - eval_start_time + train_time = time.time() - start_time - eval_time + + # Track peak memory + current_memory = get_gpu_memory() + peak_memory = max(peak_memory, current_memory) if peak_memory > 0 else current_memory + + # Prepare comprehensive results + results = { + "model_name": name, + "model_type": model_type, + "training_mode": training_mode, + "peak_memory_mb": peak_memory, + "flops_giga": flops, + "train_time_seconds": train_time, + "eval_time_seconds": eval_time, + "eval_results_by_degradation": multi_eval_results, + # Summary statistics + "clean_accuracy": multi_eval_results['clean']['accuracy'], + "avg_degraded_accuracy": np.mean([ + res['accuracy'] for key, res in multi_eval_results.items() + if key != 'clean' + ]), + "robustness_gap": multi_eval_results['clean']['accuracy'] - multi_eval_results['jpeg_20']['accuracy'], + } + + # Log summary to wandb + wandb.log({ + "summary/clean_accuracy": results["clean_accuracy"], + "summary/avg_degraded_accuracy": results["avg_degraded_accuracy"], + "summary/robustness_gap": results["robustness_gap"], + }) + + # Save model + model_dir = os.path.join( + env_path("MODEL_DIR", "."), + f"{name}_{model_type}_{training_mode}_lr{learning_rate}_bs{config.batch_size}" + ) + save_model(model, model_info, model_dir, preprocessor) + + # Save as wandb artifact + artifact = wandb.Artifact( + name=f"{name}_{training_mode}_model", + type="model", + description=f"Trained {name} model with {model_type} architecture in {training_mode} mode" + ) + artifact.add_dir(model_dir) + wandb.log_artifact(artifact) + + # Finish wandb run + wandb.finish() + + # Clear GPU memory + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return results + +def main(config: TrainingConfig): + """Main training loop with both fine-tuning and linear probing.""" + # Setup environment + setup_environment() + + # Load dataset + print("Loading dataset...") + dataset = load_dataset( + "MKZuziak/ISIC_2019_224", + cache_dir=os.environ["HF_DATASETS_CACHE"], + split="train", + ) + print(f"Initial dataset size: {len(dataset)} images") + + # Slice dataset for debug purposes + dataset = dataset[:50] + + # Filter for desired classes + filtered_indices = [ + i for i, label in enumerate(dataset["label"]) + if str(label) in FILTERED_CLASSES + ] + dataset = dataset.select(filtered_indices) + print(f"After filtering: {len(dataset)} images") + + # Cast labels to correct number of classes + dataset = dataset.cast_column("label", ClassLabel(num_classes=NUM_FILTERED_CLASSES)) + + # Balance dataset + balanced_dataset = balance_dataset(dataset, FILTERED_CLASSES, config.num_train_images) + + # Split into train and validation + split_dataset = balanced_dataset.train_test_split( + test_size=0.2, + stratify_by_column="label", + seed=42 + ) + train_dataset = split_dataset["train"] + val_dataset = split_dataset["test"] + + print(f"Training samples: {len(train_dataset)}") + print(f"Validation samples: {len(val_dataset)}") + + # Get degradation transforms + degradation_transforms = get_degradation_transforms() + + # Select models to train + models = [m for m in MODEL_REGISTRY if m["name"] in ["vit"]] # Modify as needed + + # Store all results + all_results = { + "finetune": {}, + "linear_probe": {} + } + + # Train each model with both strategies + for model_info in models: + model_name = model_info["name"] + + # Fine-tuning + try: + print(f"\n{'='*60}") + print(f"FINE-TUNING: {model_name}") + print(f"{'='*60}") + + finetune_results = train_model( + model_info, + train_dataset, + val_dataset, + config, + degradation_transforms, + training_mode="finetune" + ) + all_results["finetune"][model_name] = finetune_results + + except Exception as e: + print(f"Error fine-tuning {model_name}: {e}") + all_results["finetune"][model_name] = {"error": str(e)} + + # Linear probing + try: + print(f"\n{'='*60}") + print(f"LINEAR PROBING: {model_name}") + print(f"{'='*60}") + + probe_results = train_model( + model_info, + train_dataset, + val_dataset, + config, + degradation_transforms, + training_mode="linear_probe" + ) + all_results["linear_probe"][model_name] = probe_results + + except Exception as e: + print(f"Error linear probing {model_name}: {e}") + all_results["linear_probe"][model_name] = {"error": str(e)} + + # Save comprehensive results + output_filename = ( + f"results_comprehensive_lr{config.learning_rate}_" + f"bs{config.batch_size}_ep{config.num_epochs}.json" + ) + save_results( + all_results, + os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), output_filename) + ) + + # Print summary comparison + print_results_summary(all_results) + + print("\n" + "="*60) + print("Training complete!") + print("="*60) + +def print_results_summary(results: Dict[str, Any]): + """Print a formatted summary of results.""" + print("\n" + "="*60) + print("RESULTS SUMMARY") + print("="*60) + + # Create comparison table + print("\nClean Accuracy Comparison:") + print("-" * 40) + print(f"{'Model':<15} {'Fine-tune':<12} {'Linear Probe':<12}") + print("-" * 40) + + for model_name in results["finetune"].keys(): + ft_acc = results["finetune"][model_name].get("clean_accuracy", 0) + lp_acc = results["linear_probe"][model_name].get("clean_accuracy", 0) + print(f"{model_name:<15} {ft_acc:<12.3f} {lp_acc:<12.3f}") + + print("\nRobustness (Clean - JPEG20 Accuracy):") + print("-" * 40) + print(f"{'Model':<15} {'Fine-tune':<12} {'Linear Probe':<12}") + print("-" * 40) + + for model_name in results["finetune"].keys(): + ft_rob = results["finetune"][model_name].get("robustness_gap", 0) + lp_rob = results["linear_probe"][model_name].get("robustness_gap", 0) + print(f"{model_name:<15} {ft_rob:<12.3f} {lp_rob:<12.3f}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Model comparison with fine-tuning and linear probing") + parser.add_argument('--resolution', type=int, default=224, + help='Input image resolution (default: 224)') + parser.add_argument('--batch_size', type=int, default=128, + help='Batch size for training (default: 128)') + parser.add_argument('--num_train_images', type=int, default=500, + help='Number of training images (default: 500)') + parser.add_argument('--num_epochs', type=int, default=3, + help='Number of training epochs (default: 3)') + parser.add_argument('--eval_steps', type=int, default=100, + help='Steps between evaluations (default: 100)') + parser.add_argument('--learning_rate', type=float, default=1e-4, + help='Learning rate (default: 1e-4)') + parser.add_argument('--mode', type=str, default='both', + choices=['finetune', 'linear_probe', 'both'], + help='Training mode (default: both)') + + args = parser.parse_args() + + config = TrainingConfig( + num_train_images=args.num_train_images, + resolution=args.resolution, + batch_size=args.batch_size, + num_epochs=args.num_epochs, + eval_steps=args.eval_steps, + learning_rate=args.learning_rate, + ) + + main(config) diff --git a/src/wrappers/distill.py b/src/wrappers/distill.py deleted file mode 100644 index e69de29..0000000 From f4784604369ac7f89ea7bb83455222787f72a638 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 17 Oct 2025 19:21:10 -0700 Subject: [PATCH 18/26] update demo --- scripts/demo_transformation.py | 130 +++++++++++++++++-------------- src/transformation/transforms.py | 61 ++++++++------- 2 files changed, 105 insertions(+), 86 deletions(-) diff --git a/scripts/demo_transformation.py b/scripts/demo_transformation.py index 080dbd7..a392f9e 100644 --- a/scripts/demo_transformation.py +++ b/scripts/demo_transformation.py @@ -1,14 +1,16 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# pylint: disable=broad-exception-caught +# pylint: disable=broad-exception-caught,wrong-import-position,import-error """ -Demo script to apply ResolutionReductionTransform to ISIC dataset images and save results. +Apply ResolutionReductionTransform to a few ISIC images and save: +- original_.png +- resolution_reduced_.png (reduced-size, not upsampled) """ import sys from pathlib import Path -from PIL import Image +from typing import Iterable, Optional # Add src to Python path for imports current_dir = Path(__file__).parent @@ -19,81 +21,89 @@ from src.data.isic_loader import ISICBaseDataset from src.transformation.transforms import ResolutionReductionTransform -def main(): + +def load_isic_dataset() -> Optional[ISICBaseDataset]: + """Load ISIC dataset from HuggingFace.""" print("🔄 Loading ISIC dataset from HuggingFace...") try: hf_dataset = load_dataset("MKZuziak/ISIC_2019_224", split="train") - isic_dataset = ISICBaseDataset(hf_dataset) - print(f"✅ Loaded {len(isic_dataset)} samples") + ds = ISICBaseDataset(hf_dataset) + print(f"✅ Loaded {len(ds)} samples") + return ds except Exception as e: print(f"❌ Failed to load dataset: {e}") print("Make sure you have 'datasets' installed: pip install datasets") - return - - # Create output directory - output_dir = Path("outputs") - output_dir.mkdir(exist_ok=True) + return None + + +def process_images( + dataset: ISICBaseDataset, + indices: Iterable[int], + transform: ResolutionReductionTransform, + output_dir: Path, + save_prefix: str = "resolution_reduced", +) -> int: + """ + For each index: save original and one reduced image using `transform`. + Assumes transform returns the reduced-size image (no upsample). + """ + output_dir.mkdir(parents=True, exist_ok=True) print(f"📁 Output directory: {output_dir.absolute()}") - print("\nApplying ResolutionReductionTransform to first 5 images...") - - # Create the transform - resolution_transform = ResolutionReductionTransform() # Random reduction factor - - # Process first 5 images - num_images = min(5, len(isic_dataset)) - - for i in range(num_images): - print(f"\nProcessing image {i+1}/{num_images}...") - + processed = 0 + for idx in indices: try: - # Get the original image - sample = isic_dataset[i] - original_image = sample["image"] - label = sample["label"] - - print(f"Original size: {original_image.size}, Label: {label}") + sample = dataset[idx] + original = sample["image"] + label = sample.get("label", None) + print(f"\n🖼️ Image {idx} | Original size: {original.size} | Label: {label}") - # Save original image - original_path = output_dir / f"original_{i}.png" - original_image.save(original_path) - print(f" 💾 Saved original: {original_path}") + # Save original + orig_path = output_dir / f"original_{idx}.png" + original.save(orig_path) + print(f" 💾 Saved original → {orig_path.name}") - # Apply resolution reduction transform - transformed_image = resolution_transform(original_image) - print(f" 🔄 Transformed size: {transformed_image.size}") + # Save reduced (actual transformed size) + reduced = transform(original) + out_path = output_dir / f"{save_prefix}_{idx}.png" + reduced.save(out_path) + print(f" 🔧 Reduced to: {reduced.size}") + print(f" 💾 Saved reduced → {out_path.name}") - # Save transformed image - transformed_path = output_dir / f"resolution_reduced_{i}.png" - transformed_image.save(transformed_path) - print(f" 💾 Saved transformed: {transformed_path}") + processed += 1 except Exception as e: - print(f" ❌ Error processing image {i}: {e}") - continue + print(f" ❌ Error at index {idx}: {e}") - print(f"\n✅ All images saved to: {output_dir.absolute()}") - - # Show what reduction factors were used (they're random) - print("\nTesting with fixed reduction factors...") - for factor in [0.25, 0.5, 0.75]: - print(f"\nTesting reduction factor: {factor}") - try: - fixed_transform = ResolutionReductionTransform(reduction_factor=factor) + return processed - # Use first image for this demo - sample = isic_dataset[0] - original_image = sample["image"] - transformed = fixed_transform(original_image) - output_path = output_dir / f"fixed_reduction_{factor}_{0}.png" - transformed.save(output_path) - print(f" 💾 Saved: {output_path}") +def main(): + ds = load_isic_dataset() + if ds is None: + return - except Exception as e: - print(f" ❌ Error with factor {factor}: {e}") + output_dir = Path("outputs") + num_images = min(5, len(ds)) + indices = range(num_images) + + # Use target size and DO NOT upsample back in the transform + # Ensure your class has restore_original_size=False (default) as we discussed. + resolution_transform = ResolutionReductionTransform( + target_resolution=(54, 54), + restore_original_size=False + ) + + print("\n▶️ Saving originals and reduced versions for first 5 images...") + n = process_images( + dataset=ds, + indices=indices, + transform=resolution_transform, + output_dir=output_dir, + save_prefix="resolution_reduced", + ) + print(f"\n✅ Done. Saved {n} reduced images. Check: {output_dir.absolute()}") - print(f"\n🎉 Demo completed! Check {output_dir.absolute()} for results.") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/transformation/transforms.py b/src/transformation/transforms.py index fd39fe3..394977a 100644 --- a/src/transformation/transforms.py +++ b/src/transformation/transforms.py @@ -4,41 +4,50 @@ import numpy as np from PIL import Image, ImageFilter -class ResolutionReductionTransform: # pylint: disable=too-few-public-methods - """Reduce spatial resolution of images.""" +from typing import Optional, Tuple +import numpy as np +from PIL import Image - def __init__(self, reduction_factor: Optional[float] = None): - """ - Args: - reduction_factor: Factor to reduce resolution by (0.1-1.0). - For example, 0.5 reduces to half resolution. - If None, random factor is used. - """ +from typing import Optional, Tuple +import numpy as np +from PIL import Image + +class ResolutionReductionTransform: + """Reduce image resolution by factor or target resolution.""" + + def __init__( + self, + reduction_factor: Optional[float] = None, + target_resolution: Optional[Tuple[int, int]] = None, + restore_original_size: bool = False, + ): self.reduction_factor = reduction_factor + self.target_resolution = target_resolution + self.restore_original_size = restore_original_size def __call__(self, img: Image.Image) -> Image.Image: - """Apply resolution reduction.""" - if self.reduction_factor is None: - # Random reduction factor between 0.2 and 0.8 - reduction_factor = np.random.uniform(0.2, 0.8) + ow, oh = img.size + + if self.target_resolution is not None: + nw, nh = self.target_resolution else: - reduction_factor = self.reduction_factor + factor = ( + np.random.uniform(0.2, 0.8) + if self.reduction_factor is None + else self.reduction_factor + ) + factor = max(0.1, min(1.0, factor)) + nw, nh = max(1, int(ow * factor)), max(1, int(oh * factor)) - # Clamp reduction factor to valid range - reduction_factor = max(0.1, min(1.0, reduction_factor)) + # Downsample + reduced = img.resize((nw, nh), Image.Resampling.LANCZOS) - # Calculate new size - original_width, original_height = img.size - new_width = int(original_width * reduction_factor) - new_height = int(original_height * reduction_factor) + # Either return the reduced image as-is, or restore to original size + if self.restore_original_size: + return reduced.resize((ow, oh), Image.Resampling.LANCZOS) + return reduced - # Ensure minimum size of 1x1 - new_width = max(1, new_width) - new_height = max(1, new_height) - # Downsample and then upsample back to original size - downsampled = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - return downsampled.resize((original_width, original_height), Image.Resampling.LANCZOS) class JPEGCompressionTransform: # pylint: disable=too-few-public-methods """Apply JPEG compression to images.""" From 6cea1475620707bc4f0663caca96ed358e047007 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 17 Oct 2025 19:25:00 -0700 Subject: [PATCH 19/26] add license --- scripts/demo_transformation.py | 6 +++ src/cli/train.py | 6 +++ src/config.py | 11 +++- src/data/data_utils.py | 6 +++ src/data/datamodule.py | 6 +++ src/data/isic_loader.py | 6 +++ src/engines/train.py | 7 +++ src/evaluation/metrics.py | 13 +++++ src/evaluation/visualization.py | 6 +++ src/evaluation/visualize_results.py | 84 +++++++++++++++-------------- src/losses/classification.py | 6 +++ src/losses/distillation.py | 6 +++ src/models/factory.py | 6 +++ src/transformation/transforms.py | 6 +++ src/utils/callbacks_hf.py | 6 +++ src/utils/constants.py | 6 +++ src/utils/logging.py | 6 +++ src/utils/optim.py | 7 +++ src/utils/training_utils.py | 6 +++ src/utils/utils.py | 6 +++ src/wrappers/finetune.py | 6 +++ src/wrappers/probe.py | 6 +++ 22 files changed, 183 insertions(+), 41 deletions(-) diff --git a/scripts/demo_transformation.py b/scripts/demo_transformation.py index a392f9e..3eebc86 100644 --- a/scripts/demo_transformation.py +++ b/scripts/demo_transformation.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + #!/usr/bin/env python3 # -*- coding: utf-8 -*- # pylint: disable=broad-exception-caught,wrong-import-position,import-error diff --git a/src/cli/train.py b/src/cli/train.py index 5fe1460..462d7d3 100644 --- a/src/cli/train.py +++ b/src/cli/train.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/cli/train.py import os import sys diff --git a/src/config.py b/src/config.py index 88abf6f..10a4b6c 100644 --- a/src/config.py +++ b/src/config.py @@ -1,3 +1,10 @@ + +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + """Configuration and constants.""" from dataclasses import dataclass from typing import List, Dict, Any @@ -28,11 +35,11 @@ class TrainingConfig: eval_steps: int = 10 learning_rate: float = 1e-4 weight_decay: float = 0.01 - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for logging.""" return self.__dict__ - + def to_wandb_config(self) -> Dict[str, Any]: """Create wandb configuration.""" import torch diff --git a/src/data/data_utils.py b/src/data/data_utils.py index e7d0b75..5d28f9d 100644 --- a/src/data/data_utils.py +++ b/src/data/data_utils.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + """Dataset implementations and data utilities.""" import numpy as np import torch diff --git a/src/data/datamodule.py b/src/data/datamodule.py index b2e9a02..70321d3 100644 --- a/src/data/datamodule.py +++ b/src/data/datamodule.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/data/datamodule.py # -*- coding: utf-8 -*- from torch.utils.data import DataLoader, random_split diff --git a/src/data/isic_loader.py b/src/data/isic_loader.py index ee87d09..2a52569 100644 --- a/src/data/isic_loader.py +++ b/src/data/isic_loader.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/data/isic_loader.py from typing import Any, Dict, Union from torch.utils.data import Dataset, Subset diff --git a/src/engines/train.py b/src/engines/train.py index 594f720..2608594 100644 --- a/src/engines/train.py +++ b/src/engines/train.py @@ -1,3 +1,10 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + + """Main training script.""" import os import time diff --git a/src/evaluation/metrics.py b/src/evaluation/metrics.py index 459c275..de693be 100644 --- a/src/evaluation/metrics.py +++ b/src/evaluation/metrics.py @@ -1,3 +1,16 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + +"""Evaluation metrics for classification tasks.""" +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/metrics/metrics.py # -*- coding: utf-8 -*- from __future__ import annotations diff --git a/src/evaluation/visualization.py b/src/evaluation/visualization.py index 913f99d..216414b 100644 --- a/src/evaluation/visualization.py +++ b/src/evaluation/visualization.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/metrics/visualization.py # -*- coding: utf-8 -*- from __future__ import annotations diff --git a/src/evaluation/visualize_results.py b/src/evaluation/visualize_results.py index 3e3fb53..2da32d1 100644 --- a/src/evaluation/visualize_results.py +++ b/src/evaluation/visualize_results.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/visualize_results.py """Visualize robustness results across models and degradations.""" import json @@ -9,138 +15,138 @@ def create_robustness_heatmap(results_path: str): """Create heatmap showing model performance across degradations.""" - + # Load results with open(results_path, 'r') as f: results = json.load(f) - + # Prepare data for heatmap models = [] degradations = [] accuracies = [] - + for training_mode in ['finetune', 'linear_probe']: for model_name, model_results in results[training_mode].items(): if 'error' not in model_results: eval_results = model_results['eval_results_by_degradation'] - + for degradation, metrics in eval_results.items(): models.append(f"{model_name}_{training_mode[:2]}") degradations.append(degradation) accuracies.append(metrics['accuracy']) - + # Create DataFrame df = pd.DataFrame({ 'Model': models, 'Degradation': degradations, 'Accuracy': accuracies }) - + # Pivot for heatmap pivot_df = df.pivot(index='Model', columns='Degradation', values='Accuracy') - + # Reorder columns logically - column_order = ['clean', + column_order = ['clean', 'jpeg_90', 'jpeg_50', 'jpeg_20', 'blur_1.0', 'blur_3.0', 'blur_5.0', 'color_64', 'color_16', 'color_4'] pivot_df = pivot_df[column_order] - + # Create figure plt.figure(figsize=(14, 8)) - + # Create heatmap - sns.heatmap(pivot_df, - annot=True, - fmt='.3f', + sns.heatmap(pivot_df, + annot=True, + fmt='.3f', cmap='RdYlGn', - vmin=0.5, + vmin=0.5, vmax=1.0, cbar_kws={'label': 'Accuracy'}) - + plt.title('Model Robustness Across Different Degradations') plt.xlabel('Degradation Type') plt.ylabel('Model (ft=finetune, lp=linear probe)') plt.tight_layout() - + # Save figure output_path = Path(results_path).parent / 'robustness_heatmap.png' plt.savefig(output_path, dpi=300, bbox_inches='tight') plt.show() - + return pivot_df def create_degradation_curves(results_path: str): """Create line plots showing accuracy degradation.""" - + with open(results_path, 'r') as f: results = json.load(f) - + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) - + # JPEG degradation jpeg_qualities = [90, 50, 20] ax = axes[0] - + for training_mode in ['finetune', 'linear_probe']: for model_name, model_results in results[training_mode].items(): if 'error' not in model_results: eval_results = model_results['eval_results_by_degradation'] - + clean_acc = eval_results['clean']['accuracy'] jpeg_accs = [eval_results[f'jpeg_{q}']['accuracy'] for q in jpeg_qualities] - + label = f"{model_name} ({training_mode})" - ax.plot([100] + jpeg_qualities, [clean_acc] + jpeg_accs, + ax.plot([100] + jpeg_qualities, [clean_acc] + jpeg_accs, marker='o', label=label) - + ax.set_xlabel('JPEG Quality') ax.set_ylabel('Accuracy') ax.set_title('JPEG Compression Robustness') ax.legend() ax.grid(True, alpha=0.3) ax.invert_xaxis() - + # Blur degradation blur_radii = [0, 1.0, 3.0, 5.0] ax = axes[1] - + for training_mode in ['finetune', 'linear_probe']: for model_name, model_results in results[training_mode].items(): if 'error' not in model_results: eval_results = model_results['eval_results_by_degradation'] - + blur_accs = [] blur_accs.append(eval_results['clean']['accuracy']) for r in blur_radii[1:]: blur_accs.append(eval_results[f'blur_{r:.1f}']['accuracy']) - + label = f"{model_name} ({training_mode})" ax.plot(blur_radii, blur_accs, marker='o', label=label) - + ax.set_xlabel('Blur Radius') ax.set_ylabel('Accuracy') ax.set_title('Gaussian Blur Robustness') ax.legend() ax.grid(True, alpha=0.3) - + # Color quantization color_levels = [256, 64, 16, 4] ax = axes[2] - + for training_mode in ['finetune', 'linear_probe']: for model_name, model_results in results[training_mode].items(): if 'error' not in model_results: eval_results = model_results['eval_results_by_degradation'] - + color_accs = [] color_accs.append(eval_results['clean']['accuracy']) for c in color_levels[1:]: color_accs.append(eval_results[f'color_{c}']['accuracy']) - + label = f"{model_name} ({training_mode})" ax.plot(color_levels, color_accs, marker='o', label=label) - + ax.set_xlabel('Number of Colors') ax.set_ylabel('Accuracy') ax.set_title('Color Quantization Robustness') @@ -148,10 +154,10 @@ def create_degradation_curves(results_path: str): ax.grid(True, alpha=0.3) ax.set_xscale('log', base=2) ax.invert_xaxis() - + plt.suptitle('Model Robustness to Different Degradation Types') plt.tight_layout() - + # Save figure output_path = Path(results_path).parent / 'degradation_curves.png' plt.savefig(output_path, dpi=300, bbox_inches='tight') @@ -160,11 +166,11 @@ def create_degradation_curves(results_path: str): # Usage if __name__ == "__main__": results_path = "results/results_comprehensive_lr0.0001_bs128_ep3.json" - + # Create visualizations pivot_df = create_robustness_heatmap(results_path) create_degradation_curves(results_path) - + # Print summary statistics print("\nModel Rankings by Robustness:") print("-" * 40) diff --git a/src/losses/classification.py b/src/losses/classification.py index 4168dee..af9022d 100644 --- a/src/losses/classification.py +++ b/src/losses/classification.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/losses/classification.py # -*- coding: utf-8 -*- from typing import Optional diff --git a/src/losses/distillation.py b/src/losses/distillation.py index 27e985e..c6c4b7e 100644 --- a/src/losses/distillation.py +++ b/src/losses/distillation.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/losses/distillation.py # -*- coding: utf-8 -*- from typing import Dict diff --git a/src/models/factory.py b/src/models/factory.py index b7210a0..be9cce4 100644 --- a/src/models/factory.py +++ b/src/models/factory.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/models/factory.py # -*- coding: utf-8 -*- """Unified model factory: HF vision models (ViT/DINOv2), optional timm, diff --git a/src/transformation/transforms.py b/src/transformation/transforms.py index 394977a..1049fbb 100644 --- a/src/transformation/transforms.py +++ b/src/transformation/transforms.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + """Image transformation utilities.""" import io from typing import Optional diff --git a/src/utils/callbacks_hf.py b/src/utils/callbacks_hf.py index 630399a..ea8b3ba 100644 --- a/src/utils/callbacks_hf.py +++ b/src/utils/callbacks_hf.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/utils/callbacks_hf.py # -*- coding: utf-8 -*- from __future__ import annotations diff --git a/src/utils/constants.py b/src/utils/constants.py index 9f97406..a10cb33 100644 --- a/src/utils/constants.py +++ b/src/utils/constants.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/utils/constants.py # -*- coding: utf-8 -*- """Global constants and lightweight enums used across the training pipeline.""" diff --git a/src/utils/logging.py b/src/utils/logging.py index 5a5a69c..74c71da 100644 --- a/src/utils/logging.py +++ b/src/utils/logging.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/utils/logging.py # -*- coding: utf-8 -*- from __future__ import annotations diff --git a/src/utils/optim.py b/src/utils/optim.py index 7190423..e10bebb 100644 --- a/src/utils/optim.py +++ b/src/utils/optim.py @@ -1,3 +1,10 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + + # src/utils/optim.py # -*- coding: utf-8 -*- from __future__ import annotations diff --git a/src/utils/training_utils.py b/src/utils/training_utils.py index 9459c6c..2488e4a 100644 --- a/src/utils/training_utils.py +++ b/src/utils/training_utils.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/utils/training_utils.py # -*- coding: utf-8 -*- from __future__ import annotations diff --git a/src/utils/utils.py b/src/utils/utils.py index daa6c7b..6f15629 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + """General utilities for environment, GPU, and I/O operations.""" import os import json diff --git a/src/wrappers/finetune.py b/src/wrappers/finetune.py index f76c6fa..50816fb 100644 --- a/src/wrappers/finetune.py +++ b/src/wrappers/finetune.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + from typing import Dict, Any import torch diff --git a/src/wrappers/probe.py b/src/wrappers/probe.py index f1c7c12..d5ea342 100644 --- a/src/wrappers/probe.py +++ b/src/wrappers/probe.py @@ -1,3 +1,9 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT + # src/wrappers/probe.py # -*- coding: utf-8 -*- """Linear probing wrapper for training classification heads on frozen backbones.""" From 9f869d506b03dfd8c1efac7aa99360cc053cba8d Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Fri, 17 Oct 2025 19:28:06 -0700 Subject: [PATCH 20/26] add license --- .DS_Store | Bin 6148 -> 0 bytes requirements_demo.txt.license | 5 +++++ src/__init__.py | 3 --- src/requirements.txt | Bin 1716 -> 0 bytes 4 files changed, 5 insertions(+), 3 deletions(-) delete mode 100644 .DS_Store create mode 100644 requirements_demo.txt.license delete mode 100644 src/requirements.txt diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index d1b88db71b01a9339795b5bbeb2700f40e616132..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK!AiqG5Ph3gZ1Iv~kNE<@qqkU-LqR>*ACRVkLTa(5AfECi`ceHDkNRe3Es3dk z6Dcz=`!>7tW|OxelLg?$>-h-~0~pW+qam9eCiCJQi^9kO(fJ%#m?EE)%VJh`qHW+m zDj<7zhzm?`jvS5sTh2Hmuf;1p{qV*fFvK+@&yeAgc#H++y5nCA_s*!71|!A_w|K-F z_rAB_)dMO#kzL^#Rhu8xA0aIX46=2O48yqQGYYLbGrodJK`936c!8Bm4=sz87{1JfI=dd^S<)=|NNx(E< zt;iXgQz}uZE*vqO(m9@lxHMp`sC2k+_;6um7fvY7ug?654~I(?tu+NqfwlrYy^N*) z=RcqS+a$X(1x$gxQo!}IX*T8~h1%MAIH|P>{e~_kakb)>!j3P+jFnP6qI=_bA{}BH TuvX*<&3**D4Az(ef2zPIfu>j5 diff --git a/requirements_demo.txt.license b/requirements_demo.txt.license new file mode 100644 index 0000000..3cc951b --- /dev/null +++ b/requirements_demo.txt.license @@ -0,0 +1,5 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) +# +# SPDX-License-Identifier: MIT \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index 1a66f1d..e69de29 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,3 +0,0 @@ -""" -Root package for the project. -""" diff --git a/src/requirements.txt b/src/requirements.txt deleted file mode 100644 index 2d330480365fea94d5cdc9f0c7ab6c125710cd23..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1716 zcmZvcTaKGR5JmetQkLS6C|Sttz`};H@nB4V89x@ELrr>z0^UP55V4Z$$J<%!dCey<>0FE8her`_yl&EktOgV{vEp5h zRTJa&iyNBh)%uzvLtiTlPW$wONBYDpGlG_mbuzv2>0iSj2J@B>OP~*TZn`MviTBqGMVhvySSDksu{@*XlsOOEy zGp`fWp=Q~`e(!QW-!q1NR}1Z>KKK60z5=h$9IVeMzxHnp$^_FMj%CJsc2T*QniGFA zukf2W@V7;+3O1rf=d|`&)##aP=QyTgAkrLXZ=#F(f`^_i*Kl6YJG^tA1Sem;#fhDr zHkrxWSyUnQ3RMTP zoe?#ER^@TxS9w<_;+f-f#PcaqrQgIyuVSYOZtYU_3P#>M&aHKXuXN6HT``wdQv6Q8 zvsctBm@Vfq5PHyGdCD}Dn{2|Oj?p(~C+48k6}OO@mRl&!PMmlDt;Un*nY43TFrp`M zuQ+4gBk~B#maOO12lz56cc?dzlQ-Qb#`6+Y&1hG{CJ>BsQdDfdN7%LKSDcZFwf9oF zwIcTs^5fn2eaGK@leK_*h%*u`k6VMGxM6YyPVn8b6 Date: Mon, 20 Oct 2025 12:55:25 -0700 Subject: [PATCH 21/26] add demo --- scripts/demo_transformation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scripts/demo_transformation.py b/scripts/demo_transformation.py index 3eebc86..ddf0b4d 100644 --- a/scripts/demo_transformation.py +++ b/scripts/demo_transformation.py @@ -62,12 +62,12 @@ def process_images( sample = dataset[idx] original = sample["image"] label = sample.get("label", None) - print(f"\n🖼️ Image {idx} | Original size: {original.size} | Label: {label}") + print(f"\nImage {idx} | Original size: {original.size} | Label: {label}") # Save original orig_path = output_dir / f"original_{idx}.png" original.save(orig_path) - print(f" 💾 Saved original → {orig_path.name}") + print(f"Saved original → {orig_path.name}") # Save reduced (actual transformed size) reduced = transform(original) @@ -94,7 +94,6 @@ def main(): indices = range(num_images) # Use target size and DO NOT upsample back in the transform - # Ensure your class has restore_original_size=False (default) as we discussed. resolution_transform = ResolutionReductionTransform( target_resolution=(54, 54), restore_original_size=False From 7951ce188f74a7ed36b1030fab18f29373f99ebd Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Mon, 20 Oct 2025 17:05:33 -0700 Subject: [PATCH 22/26] implement training task abstraction and wrappers level for training --- requirements.txt | 12 +- requirements_demo.txt | 20 -- requirements_demo.txt.license | 5 - src/cli/train.py | 138 ++++++-- src/engines/finetune_engine.py | 212 +++++++++++ src/engines/linear_probe_engine.py | 184 ++++++++++ src/engines/train.py | 523 ---------------------------- src/engines/utils/training_loops.py | 78 +++++ src/models/factory.py | 89 +++-- src/wrappers/finetune.py | 171 +++++++-- 10 files changed, 776 insertions(+), 656 deletions(-) delete mode 100644 requirements_demo.txt delete mode 100644 requirements_demo.txt.license create mode 100644 src/engines/finetune_engine.py create mode 100644 src/engines/linear_probe_engine.py delete mode 100644 src/engines/train.py create mode 100644 src/engines/utils/training_loops.py diff --git a/requirements.txt b/requirements.txt index e779364..a817c5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,7 @@ fonttools h11 httpcore httpx +hydra idna importlib_metadata ipykernel @@ -36,6 +37,7 @@ matplotlib matplotlib-inline nest_asyncio numpy>=1.26.0 +omegaconf openai packaging pandas>=2.2.0 @@ -65,11 +67,11 @@ stack-data thop threadpoolctl timm -torch -torchvision -tqdm +torch>=2.0.0 +torchvision>=0.15.0 +tqdm>=4.65.0 traitlets -transformers +transformers>=4.30.0 triton typing-inspect typing_extensions @@ -78,4 +80,4 @@ urllib3 wandb wcwidth wheel -zipp +zipp \ No newline at end of file diff --git a/requirements_demo.txt b/requirements_demo.txt deleted file mode 100644 index 4971a53..0000000 --- a/requirements_demo.txt +++ /dev/null @@ -1,20 +0,0 @@ -# Requirements for running the demo_transformation.py script locally - -# Core ML libraries -torch>=2.0.0 -torchvision>=0.15.0 - -# HuggingFace datasets for loading ISIC data -datasets>=2.14.0 - -# Image processing -Pillow>=9.0.0 - -# Data manipulation -numpy>=1.24.0 - -# Optional: For progress bars and better user experience -tqdm>=4.65.0 - -# Optional: For HuggingFace transformers (if using real preprocessors) -transformers>=4.30.0 diff --git a/requirements_demo.txt.license b/requirements_demo.txt.license deleted file mode 100644 index 3cc951b..0000000 --- a/requirements_demo.txt.license +++ /dev/null @@ -1,5 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT \ No newline at end of file diff --git a/src/cli/train.py b/src/cli/train.py index 462d7d3..5a830f4 100644 --- a/src/cli/train.py +++ b/src/cli/train.py @@ -5,6 +5,15 @@ # SPDX-License-Identifier: MIT # src/cli/train.py + +"""CLI entry point for training models with different paradigms (probe/finetune/distill). + +Usage: +python -m src.cli.train train.mode=probe \ + dataset.name=isic2019 dataset.image_size=224 dataset.batch_size=128 \ + model.type=vit model.model_id=google/vit-base-patch16-224 +""" + import os import sys import time @@ -19,28 +28,25 @@ import hydra from omegaconf import DictConfig, OmegaConf -# ---- Optional: import wrappers (these implement run(cfg) and return a dict of metrics) +# ---- Wrappers implement run(cfg) and return a dict of metrics from src.wrappers import probe as probe_wrapper from src.wrappers import finetune as finetune_wrapper -from src.wrappers import distill as distill_wrapper log = logging.getLogger("train") +# ------------------------- helpers ------------------------- + + def _is_rank_zero() -> bool: - # Works for both torchrun (DDP) and single process local_rank = int(os.environ.get("LOCAL_RANK", "0")) return local_rank == 0 def _select_device(cfg: DictConfig) -> torch.device: - # Honor cfg.train.device if present; otherwise auto if "device" in cfg.train and cfg.train.device: - dev = cfg.train.device - return torch.device(dev) - if torch.cuda.is_available(): - return torch.device("cuda") - return torch.device("cpu") + return torch.device(cfg.train.device) + return torch.device("cuda" if torch.cuda.is_available() else "cpu") def _seed_everything(seed: int, deterministic: bool = False): @@ -59,7 +65,7 @@ def _save_resolved_config(cfg: DictConfig, run_dir: Path): if not _is_rank_zero(): return run_dir.mkdir(parents=True, exist_ok=True) - with open(run_dir / "resolved_config.yaml", "w") as f: + with open(run_dir / "resolved_config.yaml", "w", encoding="utf-8") as f: OmegaConf.save(config=cfg, f=f.name) @@ -70,7 +76,7 @@ def _print_run_header(cfg: DictConfig, run_dir: Path, device: torch.device): f"\n=== TRAIN START ===\n" f"mode : {cfg.train.mode}\n" f"dataset : {cfg.dataset.name}\n" - f"model : {getattr(cfg.model, 'name', 'N/A')}\n" + f"model : {getattr(cfg.model, 'name', getattr(cfg.model, 'type', 'N/A'))}\n" f"device : {device}\n" f"seed : {cfg.seed}\n" f"run_dir : {str(run_dir)}\n" @@ -80,33 +86,97 @@ def _print_run_header(cfg: DictConfig, run_dir: Path, device: torch.device): def _dispatch_wrapper(cfg: DictConfig) -> Dict[str, Any]: - mode = cfg.train.mode.lower() + mode = str(cfg.train.mode).lower() if mode == "probe": return probe_wrapper.run(cfg) - elif mode == "finetune": + if mode == "finetune": return finetune_wrapper.run(cfg) - elif mode == "distill": - return distill_wrapper.run(cfg) - else: - raise ValueError( - f"Unknown train.mode='{cfg.train.mode}'. " - f"Expected one of: probe | finetune | distill" - ) + # if mode == "distill": + # return distill_wrapper.run(cfg) + raise ValueError( + f"Unknown train.mode='{cfg.train.mode}'. " + f"Expected one of: probe | finetune | distill" + ) +def _ensure_keys(d: DictConfig, keys_defaults: Dict[str, Any]) -> None: + """Ensure keys exist in DictConfig with defaults (in-place).""" + for k, v in keys_defaults.items(): + if k not in d or d[k] is None: + d[k] = v + + +def _normalize_dataset_into_data(cfg: DictConfig) -> None: + """ + Map cfg.dataset → cfg.data so wrappers and datamodules get a consistent schema. + + Expected cfg.dataset fields (typical): + - name: str (e.g., "isic2019", "chexpert") + - data_dir: str or null (some loaders use HF hub instead) + - image_size: int + - num_classes: int + - batch_size: int + - num_workers: int + - pin_memory: bool (optional) + + After this, cfg.data will contain: + dataset_name, data_dir, image_size, batch_size, num_workers, pin_memory + """ + if "dataset" not in cfg: + raise ValueError("Config is missing 'dataset' group (cfg.dataset.*).") + + ds = cfg.dataset + # Create cfg.data if missing + if "data" not in cfg or cfg.data is None: + cfg.data = OmegaConf.create() + + # Copy/normalize fields + cfg.data.dataset_name = str(ds.get("name")) + cfg.data.data_dir = ds.get("data_dir") # can be None for HF datasets + cfg.data.image_size = int(ds.get("image_size", 224)) + cfg.data.batch_size = int( + ds.get("batch_size", getattr(cfg.train, "batch_size", 64)) + ) + cfg.data.num_workers = int(ds.get("num_workers", 4)) + cfg.data.pin_memory = bool(ds.get("pin_memory", True)) + + # Optional: enforce/propagate num_classes to model config + num_classes = ds.get("num_classes", None) + if num_classes is not None: + if "model" not in cfg: + cfg.model = OmegaConf.create() + if "config" not in cfg.model or cfg.model.config is None: + cfg.model.config = OmegaConf.create() + # Many factories look for "num_labels" + cfg.model.config.num_labels = int(num_classes) + + # Sanity checks + if not cfg.data.dataset_name: + raise ValueError("cfg.dataset.name must be set (e.g., 'isic2019').") + + +# ------------------------- main ------------------------- + @hydra.main(config_path="../../configs", config_name="defaults", version_base=None) def main(cfg: DictConfig): - # ---- Resolve and freeze config for safety - OmegaConf.set_struct(cfg, False) # allow wrappers to attach runtime fields if needed + """Main training CLI entry point.""" + # Allow wrappers to attach runtime fields if needed + OmegaConf.set_struct(cfg, False) + + # Normalize dataset selection into cfg.data for wrappers/datamodules + _normalize_dataset_into_data(cfg) - # ---- Set up run directory (Hydra changes CWD to a unique run dir automatically) + # Set up run directory (Hydra sets CWD to the unique run dir) run_dir = Path(os.getcwd()) - # ---- Device & seeding + # Device & seeding device = _select_device(cfg) - _seed_everything(seed=int(cfg.seed), deterministic=getattr(cfg.train, "deterministic", False)) + _seed_everything( + seed=int(cfg.seed), + deterministic=bool(getattr(cfg.train, "deterministic", False)), + ) - # ---- Optional: attach runtime info for wrappers/engines + # Attach runtime info used by wrappers/engines cfg.runtime = { "device": str(device), "start_time": time.strftime("%Y-%m-%d %H:%M:%S"), @@ -115,11 +185,11 @@ def main(cfg: DictConfig): "world_size": int(os.environ.get("WORLD_SIZE", "1")), } - # ---- Save resolved config + # Persist resolved config and print header _save_resolved_config(cfg, run_dir) _print_run_header(cfg, run_dir, device) - # ---- Kick off the selected training paradigm via wrapper + # Kick off the selected training paradigm try: metrics = _dispatch_wrapper(cfg) except KeyboardInterrupt: @@ -127,19 +197,21 @@ def main(cfg: DictConfig): print("\n⚠️ Training interrupted by user.", flush=True) raise except Exception as e: - # Surface a readable error at rank zero; still propagate for proper exit codes if _is_rank_zero(): print(f"\n❌ Training failed: {e}\n", flush=True) raise - # ---- Persist final metrics + # Save final metrics if _is_rank_zero(): metrics = metrics or {} - with open(run_dir / "final_metrics.json", "w") as f: + with open(run_dir / "final_metrics.json", "w", encoding="utf-8") as f: json.dump(metrics, f, indent=2) - print("✅ Train done. Final metrics written to final_metrics.json\n", flush=True) + print( + "✅ Train done. Final metrics written to final_metrics.json\n", flush=True + ) if __name__ == "__main__": - # Support `python -m src.cli.train train=probe dataset=isic2019` + # Example: python -m src.cli.train train.mode=probe dataset.name=isic2019 + # pylint: disable=no-value-for-parameter sys.exit(main()) diff --git a/src/engines/finetune_engine.py b/src/engines/finetune_engine.py new file mode 100644 index 0000000..b7c2333 --- /dev/null +++ b/src/engines/finetune_engine.py @@ -0,0 +1,212 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University +# SPDX-License-Identifier: MIT + +# -*- coding: utf-8 -*- +"""Training engine for full fine-tuning (end-to-end).""" +from __future__ import annotations +from typing import Dict, Any, Tuple, Optional + +import math +import torch +from torch import nn + +try: + # PyTorch 2.0+ unified AMP API + from torch.amp import autocast, GradScaler +except ImportError: + # Fallback for older PyTorch versions + from torch.cuda.amp import autocast, GradScaler + +# pylint: disable=import-error +from src.utils.logging import get_logger +from src.engines.utils.training_loops import ( + _maybe_scheduler_step, + _evaluate, +) + +log = get_logger(__name__) + + +def train_finetune( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + *, + model: nn.Module, + loaders: Dict[str, Any], # {"train": DataLoader, "val": DataLoader} + loss_fn, + optimizer: torch.optim.Optimizer, + scheduler: Optional[Tuple[Any, Dict[str, Any]]] = None, # (scheduler, meta) + device: torch.device, + epochs: int, + grad_clip: Optional[float] = None, + mixed_precision: bool = True, + log_interval: int = 50, + wandb_logger=None, + metric_key: str = "val_acc", + accumulation_steps: int = 1, + zero_grad_set_to_none: bool = True, +) -> Dict[str, Any]: + """ + Generic engine for full fine-tuning. Agnostic to datasets & transforms. + + Args: + model: torch.nn.Module to train end-to-end. + loaders: dict with "train" and "val" DataLoaders. + loss_fn: callable (logits, targets) -> loss tensor. + optimizer: torch optimizer over model parameters. + scheduler: optional (scheduler, meta) where meta may contain: + - step_per: {"batch","epoch","val"} (default "epoch") + - monitor: metric value for schedulers like ReduceLROnPlateau (filled automatically) + device: torch.device. + epochs: number of epochs. + grad_clip: optional float, max norm for gradient clipping. + mixed_precision: bool, enable autocast + GradScaler. + log_interval: int steps for logging. + wandb_logger: object with .log(dict) and optional .watch(...) methods. + metric_key: "val_acc" (maximize) or "...loss" (minimize) to pick best checkpoint. + accumulation_steps: gradient accumulation steps (>1 to simulate larger batch). + zero_grad_set_to_none: pass to optimizer.zero_grad for perf. + + Returns: + dict with "best_metric", "history", and "final_lr". + """ + assert accumulation_steps >= 1, "accumulation_steps must be >= 1" + + # Initialize GradScaler with backward compatibility + try: + scaler = GradScaler( + enabled=mixed_precision, + init_scale=2.0**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + ) + except TypeError: + scaler = GradScaler(enabled=mixed_precision) + + sched, sched_meta = scheduler or (None, {}) + best_metric = -math.inf if not metric_key.endswith("loss") else math.inf + best_state = None + + history = {"train_loss": [], "val_loss": [], "val_acc": [], "lr": []} + + for epoch in range(1, epochs + 1): + model.train() + running_loss, n_seen = 0.0, 0 + optimizer.zero_grad(set_to_none=zero_grad_set_to_none) + + for step, batch in enumerate(loaders["train"], start=1): + if isinstance(batch, dict): + x, y = batch.get("image"), batch.get("label") + else: + x, y = batch + x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) + + with autocast(device_type=device.type, enabled=mixed_precision): + logits = model(x) + loss = loss_fn(logits, y) + loss_to_backprop = loss / accumulation_steps + + if mixed_precision: + scaler.scale(loss_to_backprop).backward() + else: + loss_to_backprop.backward() + + # Step on accumulation boundary + if step % accumulation_steps == 0: + if grad_clip is not None: + if mixed_precision: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + + if mixed_precision: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + + optimizer.zero_grad(set_to_none=zero_grad_set_to_none) + + # Per-batch scheduler step (if configured) + if sched is not None: + _maybe_scheduler_step(sched_meta, sched, on="batch") + + bsz = y.size(0) + running_loss += float(loss.item()) * bsz + n_seen += bsz + + if step % log_interval == 0: + cur_lr = optimizer.param_groups[0]["lr"] + if wandb_logger: + wandb_logger.log({"train/loss": float(loss.item()), "lr": cur_lr}) + + # ---- validation + val_loss, val_acc = _evaluate( + model=model, + loader=loaders["val"], + loss_fn=loss_fn, + device=device, + mixed_precision=mixed_precision, + ) + + # Epoch or val-based scheduler step + if sched is not None: + if isinstance(sched, torch.optim.lr_scheduler.ReduceLROnPlateau): + sched_meta["monitor"] = ( + val_loss if metric_key.endswith("loss") else val_acc + ) + _maybe_scheduler_step(sched_meta, sched, on="val") + else: + _maybe_scheduler_step(sched_meta, sched, on="epoch") + + # Aggregate + log + train_loss = running_loss / max(n_seen, 1) + cur_lr = optimizer.param_groups[0]["lr"] + + history["train_loss"].append(train_loss) + history["val_loss"].append(val_loss) + history["val_acc"].append(val_acc) + history["lr"].append(cur_lr) + + if wandb_logger: + wandb_logger.log( + { + "epoch": epoch, + "train/loss_epoch": train_loss, + "val/loss": val_loss, + "val/acc": val_acc, + "lr": cur_lr, + } + ) + + log.info( + "Epoch %d | train_loss=%.4f | val_loss=%.4f | val_acc=%.4f | lr=%.2e", + epoch, + train_loss, + val_loss, + val_acc, + cur_lr, + ) + + monitor = val_loss if metric_key.endswith("loss") else val_acc + is_better = ( + (monitor < best_metric) + if metric_key.endswith("loss") + else (monitor > best_metric) + ) + + if is_better: + best_metric = monitor + best_state = { + k: v.detach().cpu().clone() for k, v in model.state_dict().items() + } + + # Restore best weights so caller can save/export + if best_state is not None: + model.load_state_dict(best_state) + + return { + "best_metric": best_metric, + "history": history, + "final_lr": optimizer.param_groups[0]["lr"], + } diff --git a/src/engines/linear_probe_engine.py b/src/engines/linear_probe_engine.py new file mode 100644 index 0000000..a2e74bd --- /dev/null +++ b/src/engines/linear_probe_engine.py @@ -0,0 +1,184 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University +# SPDX-License-Identifier: MIT + +# -*- coding: utf-8 -*- +"""Training engine for linear probing.""" +from __future__ import annotations +from typing import Dict, Any, Tuple, Optional + +import math +import torch +from torch import nn + +# --- AMP import (robust across versions) --- +try: + # PyTorch 2.0+ unified AMP API + from torch.amp import autocast, GradScaler +except ImportError: + # Fallback for older PyTorch versions + from torch.cuda.amp import autocast, GradScaler + +# pylint: disable=import-error +from src.utils.logging import get_logger +from src.engines.utils.training_loops import ( + _maybe_scheduler_step, + _evaluate, +) + +log = get_logger(__name__) + + +def train_probe( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + *, + model: nn.Module, + loaders: Dict[str, Any], # {"train": DataLoader, "val": DataLoader} + loss_fn, + optimizer: torch.optim.Optimizer, + scheduler: Optional[Tuple[Any, Dict[str, Any]]] = None, # (scheduler, meta) + device: torch.device, + epochs: int, + grad_clip: Optional[float] = None, + mixed_precision: bool = True, + log_interval: int = 50, + wandb_logger=None, + metric_key: str = "val_acc", +) -> Dict[str, Any]: + """ + Generic engine for linear probing. Agnostic to dataset & transforms. + Returns a dict with best metric, histories, and final lr. + """ + model.train() + + # Initialize GradScaler safely across torch versions + try: + scaler = GradScaler( + enabled=mixed_precision, + init_scale=2.0**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + ) + except TypeError: + scaler = GradScaler(enabled=mixed_precision) + + sched, sched_meta = scheduler or (None, {}) + best_metric = -math.inf + best_state_dict = None + + history = {"train_loss": [], "val_loss": [], "val_acc": [], "lr": []} + + for epoch in range(1, epochs + 1): + model.train() + running_loss, n_seen = 0.0, 0 + + for step, batch in enumerate(loaders["train"], start=1): + if isinstance(batch, dict): + x, y = batch["image"], batch["label"] + else: + x, y = batch + x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) + + optimizer.zero_grad(set_to_none=True) + with autocast(device_type=device.type, enabled=mixed_precision): + logits = model(x) + loss = loss_fn(logits, y) + + if mixed_precision: + scaler.scale(loss).backward() + + if grad_clip is not None: + # Unscale before clipping + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + + # Step and update scaler + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + optimizer.step() + + if sched is not None: + _maybe_scheduler_step(sched_meta, sched, on="batch") + + running_loss += float(loss.item()) * y.size(0) + n_seen += y.size(0) + + if step % log_interval == 0: + cur_lr = optimizer.param_groups[0]["lr"] + if wandb_logger: + wandb_logger.log({"train/loss": float(loss.item()), "lr": cur_lr}) + + train_loss = running_loss / max(n_seen, 1) + + # ---- validation + val_loss, val_acc = _evaluate( + model=model, + loader=loaders["val"], + loss_fn=loss_fn, + device=device, + mixed_precision=mixed_precision, + ) + + # scheduler on epoch or val metric + if sched is not None: + if isinstance(sched, torch.optim.lr_scheduler.ReduceLROnPlateau): + sched_meta["monitor"] = ( + val_loss if metric_key.endswith("loss") else val_acc + ) + _maybe_scheduler_step(sched_meta, sched, on="val") + else: + _maybe_scheduler_step(sched_meta, sched, on="epoch") + + # logging + cur_lr = optimizer.param_groups[0]["lr"] + history["train_loss"].append(train_loss) + history["val_loss"].append(val_loss) + history["val_acc"].append(val_acc) + history["lr"].append(cur_lr) + + if wandb_logger: + wandb_logger.log( + { + "epoch": epoch, + "train/loss_epoch": train_loss, + "val/loss": val_loss, + "val/acc": val_acc, + "lr": cur_lr, + } + ) + + log.info( + "Epoch %d | train_loss=%.4f | val_loss=%.4f | val_acc=%.4f | lr=%.2e", + epoch, + train_loss, + val_loss, + val_acc, + cur_lr, + ) + + monitor = val_loss if metric_key.endswith("loss") else val_acc + is_better = ( + (monitor < best_metric) + if metric_key.endswith("loss") + else (monitor > best_metric) + ) + if is_better: + best_metric = monitor + best_state_dict = { + k: v.detach().cpu().clone() for k, v in model.state_dict().items() + } + + # restore best (optional: caller can save now) + if best_state_dict is not None: + model.load_state_dict(best_state_dict) + + return { + "best_metric": best_metric, + "history": history, + "final_lr": optimizer.param_groups[0]["lr"], + } diff --git a/src/engines/train.py b/src/engines/train.py deleted file mode 100644 index 2608594..0000000 --- a/src/engines/train.py +++ /dev/null @@ -1,523 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - - -"""Main training script.""" -import os -import time -import argparse -import numpy as np -import torch -import wandb -from datasets import load_dataset, ClassLabel -from transformers import Trainer, TrainingArguments -from typing import Dict, Any - -from src.config import ( - TrainingConfig, MODEL_REGISTRY, FILTERED_CLASSES, - NUM_FILTERED_CLASSES, HF_MODELS -) -from src.utils import ( - setup_environment, env_path, get_gpu_memory, - check_disk_space, save_results -) -from src.models import ( - create_model, create_preprocessor, freeze_backbone, save_model -) -from src.data_utils import ( - ISICDataset, create_transformed_datasets, balance_dataset -) -from src.transformation.transforms import ( - get_degradation_transforms, ResolutionReductionTransform -) -from src.utils.training_utils import ( - LossLoggerCallback, - WandbCallback, - profile_model -) - -from src.evaluation.metrics import ( - create_compute_metrics_fn, -) - -def create_multi_validation_datasets( - val_dataset, - preprocessor, - resolution: int, - model_type: str -) -> Dict[str, Any]: - """ - Create validation datasets with different degradation levels. - - Returns: - Dictionary mapping degradation name to dataset - """ - val_datasets = {} - - # Clean (no degradation) - val_datasets['clean'] = ISICDataset( - val_dataset, - preprocessor, - resolution, - transform=None, - model_type=model_type - ) - - # JPEG compression at different quality levels - for quality in [90, 50, 20]: - val_datasets[f'jpeg_{quality}'] = ISICDataset( - val_dataset, - preprocessor, - resolution, - transform=JPEGCompressionTransform(quality=quality), - model_type=model_type - ) - - # Gaussian blur at different radii - for radius in [1.0, 3.0, 5.0]: - val_datasets[f'blur_{radius:.1f}'] = ISICDataset( - val_dataset, - preprocessor, - resolution, - transform=GaussianBlurTransform(radius=radius), - model_type=model_type - ) - - # Color quantization at different levels - for n_colors in [64, 16, 4]: - val_datasets[f'color_{n_colors}'] = ISICDataset( - val_dataset, - preprocessor, - resolution, - transform=ColorQuantizationTransform(n_colors=n_colors), - model_type=model_type - ) - - return val_datasets - -def evaluate_all_datasets(trainer, val_datasets: Dict[str, Any], model_name: str) -> Dict[str, Any]: - """ - Evaluate model on all validation datasets. - - Args: - trainer: HuggingFace Trainer object - val_datasets: Dictionary of validation datasets - model_name: Name of the model for logging - - Returns: - Dictionary of results for each dataset - """ - all_results = {} - - for val_name, val_dataset in val_datasets.items(): - print(f"Evaluating on {val_name}...") - - # Evaluate on this dataset - eval_results = trainer.evaluate( - eval_dataset=val_dataset, - metric_key_prefix=f"eval_{val_name}" - ) - - # Extract key metrics - accuracy = eval_results.get(f"eval_{val_name}_accuracy", 0) - f1 = eval_results.get(f"eval_{val_name}_f1", 0) - auc = eval_results.get(f"eval_{val_name}_auc", 0) - - # Store results - all_results[val_name] = { - "accuracy": accuracy, - "f1": f1, - "auc": auc, - "loss": eval_results.get(f"eval_{val_name}_loss", 0) - } - - # Log to wandb - wandb.log({ - f"{val_name}/accuracy": accuracy, - f"{val_name}/f1": f1, - f"{val_name}/auc": auc, - "model": model_name - }) - - print(f" {val_name}: Acc={accuracy:.3f}, F1={f1:.3f}, AUC={auc:.3f}") - - return all_results - -def train_model( - model_info: dict, - train_dataset, - val_dataset, - config: TrainingConfig, - degradation_transforms: list, - training_mode: str = "finetune" # "finetune" or "linear_probe" -) -> dict: - """ - Train a single model with specified training mode. - - Args: - model_info: Model configuration - train_dataset: Training dataset - val_dataset: Validation dataset - config: Training configuration - degradation_transforms: List of data augmentations - training_mode: "finetune" or "linear_probe" - - Returns: - Dictionary of training results - """ - name = model_info["name"] - model_type = model_info["type"] - - print(f"\n{'='*50}") - print(f"Training {name} ({model_type}) - Mode: {training_mode}") - print(f"{'='*50}") - - # Initialize wandb - wandb.init( - entity="sonnet-xu-stanford-university", - project="CS231N Test", - name=f"{name}_{config.resolution}_{config.num_epochs}_epochs_{training_mode}", - config={ - **config.to_wandb_config(), - "model_config": model_info["config"], - "training_mode": training_mode - }, - tags=["baseline", "model-comparison", training_mode, name, f"res_{config.resolution}"], - reinit=True - ) - - # Create model and preprocessor - model = create_model(model_info, config.resolution) - preprocessor = create_preprocessor(model_info, config.resolution) - - # Freeze backbone for linear probing - if training_mode == "linear_probe": - print(f"Freezing backbone for linear probing...") - freeze_backbone(model, model_type) - - # Count trainable parameters - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - total_params = sum(p.numel() for p in model.parameters()) - print(f"Trainable params: {trainable_params:,} / {total_params:,} " - f"({100 * trainable_params / total_params:.2f}%)") - - # Move model to device - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - - # Create training dataset with transformations - train_ds, _ = create_transformed_datasets( - train_dataset, - val_dataset, # Not used but required by function signature - degradation_transforms, - config.proportion_per_transform, - preprocessor, - config.resolution, - model_type - ) - - # Create multiple validation datasets - val_datasets = create_multi_validation_datasets( - val_dataset, - preprocessor, - config.resolution, - model_type - ) - - # Profile model - flops = profile_model(model, config.resolution) - - # Setup training arguments - output_dir = os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}_{training_mode}") - log_dir = env_path("LOG_DIR", "./logs") - - # Adjust learning rate for linear probing (typically higher) - learning_rate = config.learning_rate - if training_mode == "linear_probe": - learning_rate = config.learning_rate * 10 # Often need higher LR for linear probe - - training_args = TrainingArguments( - output_dir=output_dir, - num_train_epochs=config.num_epochs, - per_device_train_batch_size=config.batch_size, - per_device_eval_batch_size=config.batch_size, - learning_rate=learning_rate, - lr_scheduler_type="cosine", - weight_decay=config.weight_decay, - logging_dir=os.path.join(log_dir, f"{name}_{training_mode}"), - logging_steps=1, - evaluation_strategy="steps", - eval_steps=config.eval_steps, - save_strategy="steps", - save_steps=config.eval_steps, - load_best_model_at_end=True, # Load best model for final evaluation - metric_for_best_model="eval_clean_accuracy", # Use clean accuracy for model selection - greater_is_better=True, - save_total_limit=1, - save_safetensors=False, - push_to_hub=False, - ) - - # Check disk space - check_disk_space(required_gb=1.0) - - # Create trainer with clean validation set for checkpointing - trainer = Trainer( - model=model, - args=training_args, - train_dataset=train_ds, - eval_dataset=val_datasets['clean'], # Use clean for model selection - compute_metrics=create_compute_metrics_fn(name), - callbacks=[ - LossLoggerCallback(log_dir, training_mode, name), - WandbCallback(name, training_mode), - ], - ) - - # Log model to wandb - if model_type in HF_MODELS: - wandb.watch(model, log="all", log_freq=100) - - # Train - start_time = time.time() - peak_memory = get_gpu_memory() - - trainer.train() - - # Evaluate on all validation datasets - eval_start_time = time.time() - multi_eval_results = evaluate_all_datasets(trainer, val_datasets, name) - eval_time = time.time() - eval_start_time - train_time = time.time() - start_time - eval_time - - # Track peak memory - current_memory = get_gpu_memory() - peak_memory = max(peak_memory, current_memory) if peak_memory > 0 else current_memory - - # Prepare comprehensive results - results = { - "model_name": name, - "model_type": model_type, - "training_mode": training_mode, - "peak_memory_mb": peak_memory, - "flops_giga": flops, - "train_time_seconds": train_time, - "eval_time_seconds": eval_time, - "eval_results_by_degradation": multi_eval_results, - # Summary statistics - "clean_accuracy": multi_eval_results['clean']['accuracy'], - "avg_degraded_accuracy": np.mean([ - res['accuracy'] for key, res in multi_eval_results.items() - if key != 'clean' - ]), - "robustness_gap": multi_eval_results['clean']['accuracy'] - multi_eval_results['jpeg_20']['accuracy'], - } - - # Log summary to wandb - wandb.log({ - "summary/clean_accuracy": results["clean_accuracy"], - "summary/avg_degraded_accuracy": results["avg_degraded_accuracy"], - "summary/robustness_gap": results["robustness_gap"], - }) - - # Save model - model_dir = os.path.join( - env_path("MODEL_DIR", "."), - f"{name}_{model_type}_{training_mode}_lr{learning_rate}_bs{config.batch_size}" - ) - save_model(model, model_info, model_dir, preprocessor) - - # Save as wandb artifact - artifact = wandb.Artifact( - name=f"{name}_{training_mode}_model", - type="model", - description=f"Trained {name} model with {model_type} architecture in {training_mode} mode" - ) - artifact.add_dir(model_dir) - wandb.log_artifact(artifact) - - # Finish wandb run - wandb.finish() - - # Clear GPU memory - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return results - -def main(config: TrainingConfig): - """Main training loop with both fine-tuning and linear probing.""" - # Setup environment - setup_environment() - - # Load dataset - print("Loading dataset...") - dataset = load_dataset( - "MKZuziak/ISIC_2019_224", - cache_dir=os.environ["HF_DATASETS_CACHE"], - split="train", - ) - print(f"Initial dataset size: {len(dataset)} images") - - # Slice dataset for debug purposes - dataset = dataset[:50] - - # Filter for desired classes - filtered_indices = [ - i for i, label in enumerate(dataset["label"]) - if str(label) in FILTERED_CLASSES - ] - dataset = dataset.select(filtered_indices) - print(f"After filtering: {len(dataset)} images") - - # Cast labels to correct number of classes - dataset = dataset.cast_column("label", ClassLabel(num_classes=NUM_FILTERED_CLASSES)) - - # Balance dataset - balanced_dataset = balance_dataset(dataset, FILTERED_CLASSES, config.num_train_images) - - # Split into train and validation - split_dataset = balanced_dataset.train_test_split( - test_size=0.2, - stratify_by_column="label", - seed=42 - ) - train_dataset = split_dataset["train"] - val_dataset = split_dataset["test"] - - print(f"Training samples: {len(train_dataset)}") - print(f"Validation samples: {len(val_dataset)}") - - # Get degradation transforms - degradation_transforms = get_degradation_transforms() - - # Select models to train - models = [m for m in MODEL_REGISTRY if m["name"] in ["vit"]] # Modify as needed - - # Store all results - all_results = { - "finetune": {}, - "linear_probe": {} - } - - # Train each model with both strategies - for model_info in models: - model_name = model_info["name"] - - # Fine-tuning - try: - print(f"\n{'='*60}") - print(f"FINE-TUNING: {model_name}") - print(f"{'='*60}") - - finetune_results = train_model( - model_info, - train_dataset, - val_dataset, - config, - degradation_transforms, - training_mode="finetune" - ) - all_results["finetune"][model_name] = finetune_results - - except Exception as e: - print(f"Error fine-tuning {model_name}: {e}") - all_results["finetune"][model_name] = {"error": str(e)} - - # Linear probing - try: - print(f"\n{'='*60}") - print(f"LINEAR PROBING: {model_name}") - print(f"{'='*60}") - - probe_results = train_model( - model_info, - train_dataset, - val_dataset, - config, - degradation_transforms, - training_mode="linear_probe" - ) - all_results["linear_probe"][model_name] = probe_results - - except Exception as e: - print(f"Error linear probing {model_name}: {e}") - all_results["linear_probe"][model_name] = {"error": str(e)} - - # Save comprehensive results - output_filename = ( - f"results_comprehensive_lr{config.learning_rate}_" - f"bs{config.batch_size}_ep{config.num_epochs}.json" - ) - save_results( - all_results, - os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), output_filename) - ) - - # Print summary comparison - print_results_summary(all_results) - - print("\n" + "="*60) - print("Training complete!") - print("="*60) - -def print_results_summary(results: Dict[str, Any]): - """Print a formatted summary of results.""" - print("\n" + "="*60) - print("RESULTS SUMMARY") - print("="*60) - - # Create comparison table - print("\nClean Accuracy Comparison:") - print("-" * 40) - print(f"{'Model':<15} {'Fine-tune':<12} {'Linear Probe':<12}") - print("-" * 40) - - for model_name in results["finetune"].keys(): - ft_acc = results["finetune"][model_name].get("clean_accuracy", 0) - lp_acc = results["linear_probe"][model_name].get("clean_accuracy", 0) - print(f"{model_name:<15} {ft_acc:<12.3f} {lp_acc:<12.3f}") - - print("\nRobustness (Clean - JPEG20 Accuracy):") - print("-" * 40) - print(f"{'Model':<15} {'Fine-tune':<12} {'Linear Probe':<12}") - print("-" * 40) - - for model_name in results["finetune"].keys(): - ft_rob = results["finetune"][model_name].get("robustness_gap", 0) - lp_rob = results["linear_probe"][model_name].get("robustness_gap", 0) - print(f"{model_name:<15} {ft_rob:<12.3f} {lp_rob:<12.3f}") - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Model comparison with fine-tuning and linear probing") - parser.add_argument('--resolution', type=int, default=224, - help='Input image resolution (default: 224)') - parser.add_argument('--batch_size', type=int, default=128, - help='Batch size for training (default: 128)') - parser.add_argument('--num_train_images', type=int, default=500, - help='Number of training images (default: 500)') - parser.add_argument('--num_epochs', type=int, default=3, - help='Number of training epochs (default: 3)') - parser.add_argument('--eval_steps', type=int, default=100, - help='Steps between evaluations (default: 100)') - parser.add_argument('--learning_rate', type=float, default=1e-4, - help='Learning rate (default: 1e-4)') - parser.add_argument('--mode', type=str, default='both', - choices=['finetune', 'linear_probe', 'both'], - help='Training mode (default: both)') - - args = parser.parse_args() - - config = TrainingConfig( - num_train_images=args.num_train_images, - resolution=args.resolution, - batch_size=args.batch_size, - num_epochs=args.num_epochs, - eval_steps=args.eval_steps, - learning_rate=args.learning_rate, - ) - - main(config) diff --git a/src/engines/utils/training_loops.py b/src/engines/utils/training_loops.py new file mode 100644 index 0000000..8ff8058 --- /dev/null +++ b/src/engines/utils/training_loops.py @@ -0,0 +1,78 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University +# SPDX-License-Identifier: MIT + +# -*- coding: utf-8 -*- +"""Training engine utilities for fine-tuning and linear probing. + +This module provides helper functions used by other training engines: +- _accuracy: Computes top-1 accuracy +- _maybe_scheduler_step: Step schedulers based on configuration/meta +- _evaluate: Evaluate model on a loader with loss + accuracy + +Imported by: +- src.engines.linear_probe_engine +- src.engines.finetune_engine +""" +from __future__ import annotations +from typing import Dict, Any, Tuple + +import torch +from torch import nn + +try: + # PyTorch 2.0+ unified AMP API + from torch.amp import autocast +except ImportError: + # Fallback for older PyTorch versions + from torch.cuda.amp import autocast + + +def _accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float: + with torch.no_grad(): + preds = torch.argmax(logits, dim=1) + correct = (preds == targets).sum().item() + total = targets.numel() + return correct / max(total, 1) + + +def _maybe_scheduler_step(scheduler_meta: Dict[str, Any], scheduler, *, on: str): + """Step scheduler if meta says so. `on` ∈ {'batch','epoch','val'}.""" + step_when = str(scheduler_meta.get("step_per", "epoch")) + if step_when == on: + if "monitor" in scheduler_meta: + # e.g., ReduceLROnPlateau + scheduler.step(scheduler_meta["monitor"]) + else: + scheduler.step() + + +def _evaluate( + model: nn.Module, + loader, + loss_fn, + device: torch.device, + mixed_precision: bool, +) -> Tuple[float, float]: + model.eval() + running_loss, running_acc, n = 0.0, 0.0, 0 + with torch.no_grad(): + for batch in loader: + if isinstance(batch, dict): + x, y = batch.get("image"), batch.get("label") + else: + x, y = batch + x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) + + # Device-aware autocast (CPU/GPU) and version-safe + with autocast(device_type=getattr(device, "type", "cuda"), enabled=mixed_precision): + logits = model(x) + loss = loss_fn(logits, y) + + bsz = y.size(0) + running_loss += float(loss.item()) * bsz + running_acc += _accuracy(logits, y) * bsz + n += bsz + + return running_loss / max(n, 1), running_acc / max(n, 1) diff --git a/src/models/factory.py b/src/models/factory.py index be9cce4..b382598 100644 --- a/src/models/factory.py +++ b/src/models/factory.py @@ -9,10 +9,13 @@ """Unified model factory: HF vision models (ViT/DINOv2), optional timm, plus matching preprocessors and helpers (freeze_backbone, save_model).""" +from __future__ import annotations from typing import Dict, Any import os -import torch.nn as nn +import json +import torch +from torch import nn from PIL import Image # --- Hugging Face --- @@ -27,28 +30,35 @@ try: import timm # type: ignore _TIMM_AVAILABLE = True -except Exception: +except ImportError: _TIMM_AVAILABLE = False -# --- Project constants (small change from your code: avoid importing configs directly) --- +# --- Project constants (optional) --- try: from src.utils.constants import HF_MODELS # e.g., {"vit", "dinov2"} -except Exception: - # Fallback if constants not present yet +except ImportError: HF_MODELS = {"vit", "dinov2"} +# --- Pillow resampling constant (handles both new and old Pillow versions) --- +try: + # Pillow ≥9.1 uses Image.Resampling + RESAMPLING_LANCZOS = Image.Resampling.LANCZOS # type: ignore[attr-defined] +except AttributeError: + # Pillow <9.1 fallback; use getattr to avoid pylint false positives + RESAMPLING_LANCZOS = getattr(Image, "LANCZOS", None) or getattr(Image, "BICUBIC", None) + # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- -def create_model(model_info: Dict[str, Any], resolution: int = 224): +def create_model(model_info: Dict[str, Any], resolution: int = 224) -> nn.Module: """ Factory function to create models based on type. Args: - model_info: {"type": "vit"|"dinov2"|("timm"), "model_id": str, "config": {...}} - resolution: input image resolution + model_info: {"type": "vit"|"dinov2"|"timm", "model_id": str, "config": {...}} + resolution: input image resolution (passed to HF heads when supported) Returns: nn.Module @@ -62,30 +72,28 @@ def create_model(model_info: Dict[str, Any], resolution: int = 224): return ViTForImageClassification.from_pretrained( model_id, num_labels=config["num_labels"], - ignore_mismatched_sizes=config.get("ignore_mismatched_sizes", True), + ignore_mismatched_sizes=bool(config.get("ignore_mismatched_sizes", True)), image_size=resolution, ) # --- HuggingFace DINOv2 (AutoModel) --- - elif model_type == "dinov2": + if model_type == "dinov2": return AutoModelForImageClassification.from_pretrained( model_id, num_labels=config["num_labels"], - ignore_mismatched_sizes=config.get("ignore_mismatched_sizes", True), + ignore_mismatched_sizes=bool(config.get("ignore_mismatched_sizes", True)), image_size=resolution, ) - # --- Optional timm branch (kept minimal, no breaking changes) --- - elif model_type == "timm": + # --- timm --- + if model_type == "timm": if not _TIMM_AVAILABLE: raise RuntimeError("timm is not installed but model_type='timm' was requested.") num_classes = int(config.get("num_labels", 1000)) pretrained = bool(config.get("pretrained", True)) - # If you need image_size-specific config, many timm models accept it via `img_size` return timm.create_model(model_id, pretrained=pretrained, num_classes=num_classes) - else: - raise ValueError(f"Unknown model type: {model_type}") + raise ValueError(f"Unknown model type: {model_type}") def create_preprocessor(model_info: Dict[str, Any], resolution: int = 224): @@ -107,39 +115,40 @@ def create_preprocessor(model_info: Dict[str, Any], resolution: int = 224): model_id, size=resolution, do_resize=True, - resample=Image.LANCZOS, + resample=RESAMPLING_LANCZOS, do_normalize=True, + # ImageNet normalization statistics [red, green, blue] image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225], ) - elif model_type == "dinov2": + if model_type == "dinov2": return AutoImageProcessor.from_pretrained( model_id, size=resolution, do_resize=True, - resample=Image.LANCZOS, + resample=RESAMPLING_LANCZOS, do_normalize=True, + # ImageNet normalization statistics [red, green, blue] image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225], ) - elif model_type == "timm": + if model_type == "timm": # timm uses torchvision transforms; return None and build transforms in your datamodule return None - else: - raise ValueError(f"Unknown model type: {model_type}") + raise ValueError(f"Unknown model type: {model_type}") -def freeze_backbone(model: nn.Module, model_type: str): +def freeze_backbone(model: nn.Module, model_type: str) -> None: """ Freeze backbone parameters for transfer learning. For HF classifiers, keep 'classifier' or 'head' trainable; freeze the rest. Args: model: nn.Module - model_type: 'vit' | 'dinov2' | ('timm' if you wire it similarly) + model_type: 'vit' | 'dinov2' | 'timm' """ if model_type in HF_MODELS: for name, param in model.named_parameters(): @@ -148,18 +157,25 @@ def freeze_backbone(model: nn.Module, model_type: str): param.requires_grad = True else: param.requires_grad = False - elif model_type == "timm": - # Optional: implement project-specific rules (e.g., freeze all except last classifier) + return + + if model_type == "timm": for name, param in model.named_parameters(): if ("classifier" in name) or ("fc" in name) or ("head" in name): param.requires_grad = True else: param.requires_grad = False - else: - raise ValueError(f"Unsupported model_type: {model_type}") + return + raise ValueError(f"Unsupported model_type: {model_type}") -def save_model(model: nn.Module, model_info: Dict[str, Any], save_dir: str, preprocessor=None): + +def save_model( + model: nn.Module, + model_info: Dict[str, Any], + save_dir: str, + preprocessor=None +) -> None: """ Save model based on its type. @@ -176,22 +192,23 @@ def save_model(model: nn.Module, model_info: Dict[str, Any], save_dir: str, prep model.save_pretrained(save_dir) if preprocessor is not None: preprocessor.save_pretrained(save_dir) - elif model_type == "timm": + return + + if model_type == "timm": # Torch-style checkpoint for timm models - import torch ckpt_path = os.path.join(save_dir, "pytorch_model.bin") torch.save(model.state_dict(), ckpt_path) # Minimal config export - with open(os.path.join(save_dir, "config.json"), "w") as f: - import json + with open(os.path.join(save_dir, "config.json"), "w", encoding="utf-8") as f: json.dump( { "model_type": "timm", - "model_id": model_info["model_id"], + "model_id": model_info.get("model_id"), "num_labels": model_info.get("config", {}).get("num_labels", None), }, f, indent=2, ) - else: - raise ValueError(f"Unsupported model_type: {model_type}") + return + + raise ValueError(f"Unsupported model_type: {model_type}") diff --git a/src/wrappers/finetune.py b/src/wrappers/finetune.py index 50816fb..57c1054 100644 --- a/src/wrappers/finetune.py +++ b/src/wrappers/finetune.py @@ -1,44 +1,147 @@ # This source file is part of the Daneshjou Lab projects # -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# +# SPDX-FileCopyrightText: 2025 Stanford University # SPDX-License-Identifier: MIT -from typing import Dict, Any +# -*- coding: utf-8 -*- +"""Fine-tuning wrapper for end-to-end optimization.""" +from __future__ import annotations +from typing import Any, Dict + +import os import torch +from torch.utils.data import DataLoader -from src.engines.finetune_engine import train_finetune -from src.data.datamodule import build_datamodule -from src.models.factory import build_classifier -from src.losses.classification import cross_entropy_loss +# pylint: disable=import-error +from src.utils.logging import setup_logging, get_logger, WandbLogger from src.utils.optim import make_optimizer_and_scheduler -from src.utils.logging import get_logger # TODO +from src.losses.classification import cross_entropy_loss +from src.models.factory import create_model, create_preprocessor, save_model +from src.data.datamodule import BaseDataModule +from src.engines.finetune_engine import train_finetune +from src.utils.training_utils import profile_model log = get_logger(__name__) -def run(cfg) -> Dict[str, Any]: - device = torch.device(cfg.runtime["device"]) - dm = build_datamodule(cfg) - model = build_classifier(cfg).to(device) # backbone + classification head - - # Unfreeze all params for full finetuning - if hasattr(model, "unfreeze_all"): - model.unfreeze_all() - - loss_fn = cross_entropy_loss() - - optimizer, scheduler = make_optimizer_and_scheduler(cfg, model.parameters()) - - log.info("Starting end-to-end finetuning...") - metrics = train_finetune( - model=model, - loaders=dm.loaders(), - loss_fn=loss_fn, - optimizer=optimizer, - scheduler=scheduler, - device=device, - epochs=cfg.train.epochs, - grad_clip=getattr(cfg.train, "grad_clip_norm", None), - mixed_precision=getattr(cfg.train, "amp", True), - ) - return metrics + +class FinetuneWrapper: # pylint: disable=too-many-instance-attributes,too-few-public-methods + """ + Orchestrates full fine-tuning: + - builds model (+preprocessor if needed), + - prepares dataloaders via DataModule, + - creates optimizer/scheduler/loss, + - calls the finetune engine, + - saves best model. + """ + + def __init__(self, cfg: Any): + """ + Expected cfg fields (suggested): + cfg.model, cfg.data, cfg.train, cfg.loss, cfg.logging, cfg.runtime + """ + self.cfg = cfg + setup_logging() + + # Model + self.model_info = cfg.model + self.model = create_model(self.model_info, resolution=cfg.data.image_size) + + # Optional preprocessor + try: + self.preprocessor = create_preprocessor( + self.model_info, resolution=cfg.data.image_size + ) + except (ImportError, AttributeError, KeyError) as e: + log.debug(f"Preprocessor creation failed: {e}") + self.preprocessor = None + + # Data + self.dm = BaseDataModule( + cfg=cfg, + dataset_name=cfg.data.dataset_name, + data_dir=cfg.data.data_dir, + batch_size=cfg.data.batch_size, + num_workers=cfg.data.num_workers, + pin_memory=True, + ) + self.dm.setup("fit") + + # Optimizer & scheduler + self.optimizer, (self.scheduler, self.sched_meta) = ( + make_optimizer_and_scheduler(cfg, self.model.parameters()) + ) + + # Loss + self.loss_fn = cross_entropy_loss( + label_smoothing=float(getattr(cfg.loss, "label_smoothing", 0.0)), + class_weight=None, + ignore_index=int(getattr(cfg.loss, "ignore_index", -100)), + reduction=str(getattr(cfg.loss, "reduction", "mean")), + ) + + # W&B + self.wandb = WandbLogger( + project=getattr(cfg.logging, "project", "resolution-aware-finetune"), + run_name=getattr(cfg.logging, "run_name", None), + config=cfg, + enabled=bool(getattr(cfg.logging, "wandb_enabled", True)), + entity=getattr(cfg.logging, "entity", None), + tags=getattr(cfg.logging, "tags", ["finetune"]), + ) + + # Device + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model.to(self.device) + + def _make_loaders(self) -> Dict[str, DataLoader]: + return { + "train": self.dm.train_dataloader(), + "val": self.dm.val_dataloader(), + } + + def train(self) -> Dict[str, Any]: + """Run fine-tuning training.""" + log.info("Starting fine-tune...") + + gflops = profile_model(self.model, self.cfg.data.image_size) + if self.wandb: + self.wandb.log({"model/gflops": gflops}) + + results = train_finetune( + model=self.model, + loaders=self._make_loaders(), + loss_fn=self.loss_fn, + optimizer=self.optimizer, + scheduler=(self.scheduler, self.sched_meta), + device=self.device, + epochs=int(self.cfg.train.epochs), + grad_clip=getattr(self.cfg.train, "grad_clip", None), + mixed_precision=bool(getattr(self.cfg.train, "mixed_precision", True)), + log_interval=int(getattr(self.cfg.train, "log_interval", 50)), + wandb_logger=self.wandb, + metric_key=str(getattr(self.cfg.train, "metric_key", "val_acc")), + ) + + run_dir = getattr(self.cfg.runtime, "run_dir", "./runs/finetune") + os.makedirs(run_dir, exist_ok=True) + try: + save_model( + self.model, + self.model_info, + save_dir=run_dir, + preprocessor=self.preprocessor + ) + except (OSError, ValueError, AttributeError) as e: + log.warning(f"Failed to save model in HF format; error: {e}") + + if self.wandb: + self.wandb.log({"best/metric": results.get("best_metric", None)}) + self.wandb.finish() + + log.info("Fine-tune finished.") + return results + + +def run(cfg: Any) -> Dict[str, Any]: + wrapper = FinetuneWrapper(cfg) + return wrapper.train() From 079984fa293845ad2c877def21614f293dee33ed Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Mon, 20 Oct 2025 17:10:24 -0700 Subject: [PATCH 23/26] update hydra-core version --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a817c5d..1e41dc8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,7 @@ h11 httpcore httpx hydra +hydra-core==1.3.2 idna importlib_metadata ipykernel @@ -37,7 +38,7 @@ matplotlib matplotlib-inline nest_asyncio numpy>=1.26.0 -omegaconf +omegaconf==2.3.0 openai packaging pandas>=2.2.0 From 55776f7fda6473c07441960801976aff600853f0 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Tue, 21 Oct 2025 18:31:50 -0700 Subject: [PATCH 24/26] refactor(data): consolidate dataset logic into dataset_factory and improve transform handling --- .github/workflows/build-and-test.yml | 2 +- src/data/__init__.py | 0 src/data/datamodule.py | 157 +++++++--- .../{data_utils.py => dataset_factory.py} | 178 +++++++++-- src/data/isic_loader.py | 11 +- src/engines/finetune_engine.py | 97 +++--- src/engines/linear_probe_engine.py | 99 +++---- src/engines/utils/training_core.py | 276 ++++++++++++++++++ src/engines/utils/training_loops.py | 78 ----- 9 files changed, 622 insertions(+), 276 deletions(-) create mode 100644 src/data/__init__.py rename src/data/{data_utils.py => dataset_factory.py} (59%) create mode 100644 src/engines/utils/training_core.py delete mode 100644 src/engines/utils/training_loops.py diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 64cb848..1693d02 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.12"] + python-version: ["3.10", "3.12"] steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data/datamodule.py b/src/data/datamodule.py index 70321d3..e9a4967 100644 --- a/src/data/datamodule.py +++ b/src/data/datamodule.py @@ -4,33 +4,64 @@ # # SPDX-License-Identifier: MIT -# src/data/datamodule.py -# -*- coding: utf-8 -*- -from torch.utils.data import DataLoader, random_split -from typing import Optional -from .datasets import get_dataset +""" +DataModule for managing dataset loading, splitting, and dataloader creation. +Provides a unified interface for working with different datasets. +""" +# Standard library imports +from typing import Optional, Any + +# Third-party imports +import torch +from torch.utils.data import DataLoader, random_split, Subset + +# Local imports +# pylint: disable=import-error +from src.data.dataset_factory import get_dataset class BaseDataModule: """ Dataset-agnostic data module. - Responsibilities: - - Instantiate train/val/test datasets using `get_dataset()`. - - Build and cache dataloaders. - - Apply consistent transforms and batching options across datasets. + - If get_dataset(dataset_name, split=...) exists, we use provided splits. + - Otherwise we create train/val via random_split from a single 'train' split. """ + # pylint: disable=too-many-instance-attributes def __init__( self, - cfg, + cfg: Any, dataset_name: str, data_dir: str, + *, num_workers: int = 8, batch_size: int = 32, pin_memory: bool = True, drop_last: bool = False, - ): + split_seed: int = 42, + preprocessor: Any = None, + resolution: int = 224, + transform: Any = None, + model_type: str = "vit", + ): # pylint: disable=too-many-arguments + """ + Initialize the DataModule with dataset and dataloader parameters. + + Args: + cfg: Configuration object + dataset_name: Name of the dataset to load + data_dir: Directory where dataset is stored + num_workers: Number of worker processes for data loading + batch_size: Number of samples per batch + pin_memory: Whether to pin memory for faster GPU transfer + drop_last: Whether to drop the last incomplete batch + split_seed: Random seed for train/val splitting + preprocessor: Data preprocessing pipeline + resolution: Image resolution to use + transform: Data augmentation transforms + model_type: Type of model (e.g., 'vit', 'cnn') + """ self.cfg = cfg self.dataset_name = dataset_name self.data_dir = data_dir @@ -38,51 +69,103 @@ def __init__( self.batch_size = batch_size self.pin_memory = pin_memory self.drop_last = drop_last + self.split_seed = split_seed + + # Plumb-through for dataset construction + self.preprocessor = preprocessor + self.resolution = resolution + self.transform = transform + self.model_type = model_type self.train_set = None self.val_set = None self.test_set = None # ------------------------------------------------------------------ - def setup(self, stage: Optional[str] = None): + def setup(self, _stage: Optional[str] = None): """ - Called once to initialize datasets. - stage: 'fit' | 'validate' | 'test' | None + Initialize datasets. _stage: 'fit' | 'validate' | 'test' | None + Tries explicit splits first; falls back to random_split from a 'train' split. + + Args: + _stage: Current stage of training pipeline (unused but kept for API compatibility) """ - ds_full = get_dataset(self.dataset_name, self.data_dir, split="train", cfg=self.cfg) - n_total = len(ds_full) - n_val = int(0.1 * n_total) - n_train = n_total - n_val - self.train_set, self.val_set = random_split(ds_full, [n_train, n_val]) + # Try to fetch explicit train/val + ds_train = get_dataset( + self.dataset_name, self.data_dir, split="train", cfg=self.cfg, + preprocessor=self.preprocessor, resolution=self.resolution, + transform=self.transform, model_type=self.model_type, + ) + + try: + ds_val = get_dataset( + self.dataset_name, self.data_dir, split="val", cfg=self.cfg, + preprocessor=self.preprocessor, resolution=self.resolution, + transform=None, model_type=self.model_type, + ) + except Exception as e: # pylint: disable=broad-exception-caught, unused-variable + # Exception is broad to handle any dataset-specific errors + ds_val = None + + try: + ds_test = get_dataset( + self.dataset_name, self.data_dir, split="test", cfg=self.cfg, + preprocessor=self.preprocessor, resolution=self.resolution, + transform=None, model_type=self.model_type, + ) + except Exception as e: # pylint: disable=broad-exception-caught, unused-variable + # Exception is broad to handle any dataset-specific errors + ds_test = None - # test set - self.test_set = get_dataset(self.dataset_name, self.data_dir, split="test", cfg=self.cfg) + if ds_val is None: + # Fallback: split train into train/val + n_total = len(ds_train) + n_val = max(1, int(0.1 * n_total)) + n_train = n_total - n_val + g = torch.Generator().manual_seed(self.split_seed) + self.train_set, self.val_set = random_split(ds_train, [n_train, n_val], generator=g) + else: + self.train_set, self.val_set = ds_train, ds_val + + self.test_set = ds_test # ------------------------------------------------------------------ def train_dataloader(self): + """ + Create and return the training data loader. + + Returns: + DataLoader: Configured training data loader + """ return DataLoader( - self.train_set, - batch_size=self.batch_size, - shuffle=True, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - drop_last=self.drop_last, + self.train_set, batch_size=self.batch_size, shuffle=True, + num_workers=self.num_workers, pin_memory=self.pin_memory, drop_last=self.drop_last, ) def val_dataloader(self): + """ + Create and return the validation data loader. + + Returns: + DataLoader: Configured validation data loader + """ return DataLoader( - self.val_set, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - pin_memory=self.pin_memory, + self.val_set, batch_size=self.batch_size, shuffle=False, + num_workers=self.num_workers, pin_memory=self.pin_memory, ) def test_dataloader(self): + """ + Create and return the test data loader. + If no test set exists, returns an empty loader. + + Returns: + DataLoader: Configured test data loader + """ + if self.test_set is None: + # Provide an empty loader if test isn't defined + return DataLoader(Subset(self.val_set, []), batch_size=self.batch_size) return DataLoader( - self.test_set, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - pin_memory=self.pin_memory, + self.test_set, batch_size=self.batch_size, shuffle=False, + num_workers=self.num_workers, pin_memory=self.pin_memory, ) diff --git a/src/data/data_utils.py b/src/data/dataset_factory.py similarity index 59% rename from src/data/data_utils.py rename to src/data/dataset_factory.py index 5d28f9d..5ebe33e 100644 --- a/src/data/data_utils.py +++ b/src/data/dataset_factory.py @@ -5,13 +5,19 @@ # SPDX-License-Identifier: MIT """Dataset implementations and data utilities.""" +# Standard library imports +from typing import Optional, List, Dict, Any, Union + +# Third-party imports import numpy as np import torch from PIL import Image from torch.utils.data import Dataset, ConcatDataset, Subset -from typing import Optional, List, Dict, Any, Union +# Local imports +# pylint: disable=import-error,relative-beyond-top-level from src.config import HF_MODELS, DEFAULT_IMAGE_SIZE +from src.data.isic_loader import ISICBaseDataset # ============================================================================ # DATA LOADING @@ -26,6 +32,10 @@ def __init__(self, dataset: Union[Dataset, Subset]): def __len__(self) -> int: return len(self.dataset) + def __getitem__(self, idx: int) -> Dict[str, Any]: + """Get item from dataset, satisfying Dataset abstract method.""" + return self.get_raw_item(idx) + def get_raw_item(self, idx: int) -> Dict[str, Any]: """Get raw item from dataset, handling both Dataset and Subset.""" # Convert numpy types to Python int @@ -37,9 +47,8 @@ def get_raw_item(self, idx: int) -> Dict[str, Any]: # This is a Subset subset_idx = int(self.dataset.indices[idx]) return self.dataset.dataset[subset_idx] - else: - # Direct dataset access - return self.dataset[idx] + # Direct dataset access + return self.dataset[idx] # ============================================================================ @@ -72,23 +81,45 @@ def apply_transforms(self, image: Image.Image, transform: Optional[Any]) -> Imag class ModelPreprocessor: """Handles model-specific preprocessing.""" + # pylint: disable=too-few-public-methods def __init__(self, preprocessor: Optional[Any] = None, model_type: str = "vit"): self.preprocessor = preprocessor self.model_type = model_type + if self.preprocessor is None: + # Will no-op and return a tensorized fallback later + return + if self.model_type not in HF_MODELS: raise ValueError(f"Unsupported model_type: {self.model_type}") def preprocess(self, image: Image.Image, resolution: int) -> torch.Tensor: - """Apply model-specific preprocessing.""" - # For HuggingFace models - if hasattr(self.preprocessor, 'size'): - self.preprocessor.size = resolution + """Apply model-specific preprocessing; no-op if preprocessor is None.""" + if self.preprocessor is None: + return self._convert_pil_to_tensor(image) + + # HuggingFace path + if hasattr(self.preprocessor, "size") and self.preprocessor.size != resolution: + try: + self.preprocessor.size = resolution + except (AttributeError, TypeError) as e: # pylint: disable=unused-variable + # Some processors use tuples or different APIs for size + pass encoding = self.preprocessor(images=image, return_tensors="pt") return encoding["pixel_values"].squeeze(0) + def _convert_pil_to_tensor(self, image: Image.Image) -> torch.Tensor: + """Convert PIL image to tensor with proper format.""" + # Minimal safety: convert PIL -> tensor [C,H,W] in [0,1] + arr = np.asarray(image).astype(np.float32) / 255.0 + if arr.ndim == 2: # grayscale -> [H,W] -> [H,W,1] + arr = arr[:, :, None] + # HW(C) -> CHW + arr = np.transpose(arr, (2, 0, 1)) + return torch.from_numpy(arr) + # ============================================================================ # COMBINED DATASET @@ -100,11 +131,12 @@ class ISICDataset(Dataset): def __init__( self, dataset: Union[Dataset, Subset], + *, # Force keyword arguments after first positional arg preprocessor: Optional[Any] = None, resolution: int = DEFAULT_IMAGE_SIZE, transform: Optional[Any] = None, model_type: str = "vit", - ): + ): # pylint: disable=too-many-arguments # Data loading self.data_wrapper = DatasetWrapper(dataset) @@ -180,11 +212,13 @@ def create_transformed_datasets( val_dataset: Dataset, transforms_list: List[Any], proportion_per_transform: float, - preprocessor: Optional[Any], - resolution: int, - model_type: str + *, # Force keyword arguments + preprocessor: Optional[Any] = None, + resolution: int = DEFAULT_IMAGE_SIZE, + model_type: str = "vit" ) -> tuple[Dataset, Dataset]: """Create train and validation datasets with transformations.""" + # pylint: disable=too-many-arguments,too-many-locals # Split training data into subsets train_subsets = split_dataset_for_transforms( @@ -195,16 +229,24 @@ def create_transformed_datasets( transformed_datasets = [] # Apply each transform to corresponding subset - for i, (subset, transform) in enumerate(zip(train_subsets[:-1], transforms_list)): + for _, (subset, transform) in enumerate(zip(train_subsets[:-1], transforms_list)): transformed_ds = ISICDataset( - subset, preprocessor, resolution, transform, model_type + subset, + preprocessor=preprocessor, + resolution=resolution, + transform=transform, + model_type=model_type ) transformed_datasets.append(transformed_ds) # Add untransformed subset (if any remaining) if len(train_subsets) > len(transforms_list): untransformed_ds = ISICDataset( - train_subsets[-1], preprocessor, resolution, None, model_type + train_subsets[-1], + preprocessor=preprocessor, + resolution=resolution, + transform=None, + model_type=model_type ) transformed_datasets.append(untransformed_ds) @@ -212,7 +254,13 @@ def create_transformed_datasets( train_ds = ConcatDataset(transformed_datasets) # Create validation dataset (no transformations) - val_ds = ISICDataset(val_dataset, preprocessor, resolution, None, model_type) + val_ds = ISICDataset( + val_dataset, + preprocessor=preprocessor, + resolution=resolution, + transform=None, + model_type=model_type + ) return train_ds, val_ds @@ -222,8 +270,8 @@ def create_transformed_datasets( def get_class_distribution(dataset: Dataset, filtered_classes: List[str]) -> Dict[str, List[int]]: """Get class distribution and indices.""" - class_counts = {label: 0 for label in filtered_classes} - class_indices = {label: [] for label in filtered_classes} + class_counts = {class_label: 0 for class_label in filtered_classes} + class_indices = {class_label: [] for class_label in filtered_classes} for i, item in enumerate(dataset): label_str = str(item["label"]) @@ -252,7 +300,7 @@ def sample_balanced_indices( # Sample from each class balanced_indices = [] - for label, indices in class_indices.items(): + for _label, indices in class_indices.items(): sampled = np.random.choice(indices, images_per_class, replace=False) balanced_indices.extend(sampled) @@ -266,12 +314,92 @@ def balance_dataset( num_train_images: int, seed: int = 42 ) -> Dataset: - """Balance dataset by sampling equal numbers from each class.""" - # Get class distribution + """ + Balance dataset by sampling equal numbers from each class. + - If `dataset` is a PyTorch Dataset: returns a torch.utils.data.Subset + - If `dataset` is an HF Dataset (has `.select`): returns a selected HF Dataset + """ class_indices = get_class_distribution(dataset, filtered_classes) - - # Sample balanced indices balanced_indices = sample_balanced_indices(class_indices, num_train_images, seed) - # Return balanced dataset - return dataset.select(balanced_indices) \ No newline at end of file + # HF Dataset has .select; PyTorch does not + if hasattr(dataset, "select"): + return dataset.select(balanced_indices) + return Subset(dataset, balanced_indices) + +def _load_isic_split(data_dir: str, split: str) -> Dataset: + """ + Replace with your real ISIC split loader that yields dicts: + {"image": PIL.Image, "label": int} + + Expected interface: + class ISICRawSplit(Dataset): + def __init__(self, data_dir: str, split: str): ... + def __getitem__(self, i) -> {"image": PIL.Image, "label": int} + def __len__(self) -> int: ... + + If you already have it elsewhere, just import and return it here. + """ + # pylint: disable=import-outside-toplevel,relative-beyond-top-level + + try: + from src.data.isic_raw import ISICRawSplit # Use absolute import + except ImportError as e: + raise ImportError( + "ISICRawSplit not found. Create src/data/isic_raw.py with an ISICRawSplit " + "that returns {'image': PIL.Image, 'label': int}." + ) from e + + return ISICRawSplit(data_dir, split) + + +def get_dataset( + dataset_name: str, + data_dir: str, + split: str, + cfg=None, # Kept for API compatibility + *, + preprocessor=None, + resolution: int = DEFAULT_IMAGE_SIZE, + transform=None, + model_type: str = "vit", + mode: str = "model_ready", # "raw" or "model_ready" +) -> Dataset: + """ + Unified dataset factory used by BaseDataModule. + + Args: + dataset_name: e.g., "isic2019" + data_dir: root directory for the dataset + split: "train" | "val" | "test" (if "val" doesn't exist, DataModule can split) + cfg: optional config (unused but kept for API compatibility) + preprocessor: HF image processor or similar (used by model_ready) + resolution: model input resolution (used by model_ready) + transform: optional PIL->PIL degradation transform (applied before preprocess) + model_type: str key you use in HF_MODELS + mode: "raw" (no transforms/preprocessing) or "model_ready" (pipeline) + + Returns: + A torch.utils.data.Dataset + """ + # pylint: disable=too-many-arguments,unused-argument + + name = dataset_name.lower() + + if name in {"isic", "isic2019", "dermatology"}: + base_ds = _load_isic_split(data_dir, split) + + if mode == "raw": + # No resize/degradation or model preprocessing + return ISICBaseDataset(base_ds) + + # model_ready: resize + optional degradation + HF preprocessing + return ISICDataset( + dataset=base_ds, + preprocessor=preprocessor, + resolution=resolution, + transform=transform, + model_type=model_type, + ) + + raise ValueError(f"Unknown dataset_name: {dataset_name}") diff --git a/src/data/isic_loader.py b/src/data/isic_loader.py index 2a52569..f16403e 100644 --- a/src/data/isic_loader.py +++ b/src/data/isic_loader.py @@ -4,7 +4,7 @@ # # SPDX-License-Identifier: MIT -# src/data/isic_loader.py +"""ISIC dataset loader implementation for dermatology image datasets.""" from typing import Any, Dict, Union from torch.utils.data import Dataset, Subset @@ -27,7 +27,7 @@ def __len__(self) -> int: return len(self.dataset) def __getitem__(self, idx: int) -> Dict[str, Any]: - # Handle Subset wrapping transparently + # Handle Subset wrapping base = self.dataset if hasattr(base, "dataset") and hasattr(base, "indices"): # Subset case @@ -35,14 +35,13 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: else: item = base[idx] - # Do NOT touch the image (no resize, no cast, no transforms) + # no resize, no cast, no transforms image = item["image"] label = item["label"] - # Make sure label is int-like, but don't coerce image try: label = int(label) - except Exception: - raise TypeError("Label must be convertible to int.") + except Exception as exc: + raise TypeError("Label must be convertible to int.") from exc return {"image": image, "label": label} diff --git a/src/engines/finetune_engine.py b/src/engines/finetune_engine.py index b7c2333..7b0394f 100644 --- a/src/engines/finetune_engine.py +++ b/src/engines/finetune_engine.py @@ -5,6 +5,7 @@ # -*- coding: utf-8 -*- """Training engine for full fine-tuning (end-to-end).""" +# pylint: disable=duplicate-code from __future__ import annotations from typing import Dict, Any, Tuple, Optional @@ -14,16 +15,20 @@ try: # PyTorch 2.0+ unified AMP API - from torch.amp import autocast, GradScaler + from torch.amp import autocast except ImportError: # Fallback for older PyTorch versions - from torch.cuda.amp import autocast, GradScaler + from torch.cuda.amp import autocast # pylint: disable=import-error from src.utils.logging import get_logger -from src.engines.utils.training_loops import ( +from src.engines.utils.training_core import ( _maybe_scheduler_step, - _evaluate, + _create_grad_scaler, + _update_history_and_log, + _preprocess_batch, + _run_validation_and_scheduler, + _update_best_model_state, ) log = get_logger(__name__) @@ -73,16 +78,7 @@ def train_finetune( # pylint: disable=too-many-arguments,too-many-locals,too-ma assert accumulation_steps >= 1, "accumulation_steps must be >= 1" # Initialize GradScaler with backward compatibility - try: - scaler = GradScaler( - enabled=mixed_precision, - init_scale=2.0**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - ) - except TypeError: - scaler = GradScaler(enabled=mixed_precision) + scaler = _create_grad_scaler(mixed_precision) sched, sched_meta = scheduler or (None, {}) best_metric = -math.inf if not metric_key.endswith("loss") else math.inf @@ -96,11 +92,7 @@ def train_finetune( # pylint: disable=too-many-arguments,too-many-locals,too-ma optimizer.zero_grad(set_to_none=zero_grad_set_to_none) for step, batch in enumerate(loaders["train"], start=1): - if isinstance(batch, dict): - x, y = batch.get("image"), batch.get("label") - else: - x, y = batch - x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) + x, y = _preprocess_batch(batch, device) with autocast(device_type=device.type, enabled=mixed_precision): logits = model(x) @@ -140,66 +132,43 @@ def train_finetune( # pylint: disable=too-many-arguments,too-many-locals,too-ma if wandb_logger: wandb_logger.log({"train/loss": float(loss.item()), "lr": cur_lr}) - # ---- validation - val_loss, val_acc = _evaluate( + # ---- validation and scheduler step + val_loss, val_acc = _run_validation_and_scheduler( model=model, - loader=loaders["val"], + loaders=loaders, loss_fn=loss_fn, device=device, mixed_precision=mixed_precision, + sched=sched, + sched_meta=sched_meta, + metric_key=metric_key, ) - # Epoch or val-based scheduler step - if sched is not None: - if isinstance(sched, torch.optim.lr_scheduler.ReduceLROnPlateau): - sched_meta["monitor"] = ( - val_loss if metric_key.endswith("loss") else val_acc - ) - _maybe_scheduler_step(sched_meta, sched, on="val") - else: - _maybe_scheduler_step(sched_meta, sched, on="epoch") - # Aggregate + log train_loss = running_loss / max(n_seen, 1) cur_lr = optimizer.param_groups[0]["lr"] - history["train_loss"].append(train_loss) - history["val_loss"].append(val_loss) - history["val_acc"].append(val_acc) - history["lr"].append(cur_lr) - - if wandb_logger: - wandb_logger.log( - { - "epoch": epoch, - "train/loss_epoch": train_loss, - "val/loss": val_loss, - "val/acc": val_acc, - "lr": cur_lr, - } - ) - - log.info( - "Epoch %d | train_loss=%.4f | val_loss=%.4f | val_acc=%.4f | lr=%.2e", - epoch, - train_loss, - val_loss, - val_acc, - cur_lr, + _update_history_and_log( + history=history, + epoch=epoch, + train_loss=train_loss, + val_loss=val_loss, + val_acc=val_acc, + cur_lr=cur_lr, + wandb_logger=wandb_logger, + log=log, ) - monitor = val_loss if metric_key.endswith("loss") else val_acc - is_better = ( - (monitor < best_metric) - if metric_key.endswith("loss") - else (monitor > best_metric) + updated_state, best_metric, is_better = _update_best_model_state( + model=model, + metric_key=metric_key, + val_loss=val_loss, + val_acc=val_acc, + best_metric=best_metric, ) if is_better: - best_metric = monitor - best_state = { - k: v.detach().cpu().clone() for k, v in model.state_dict().items() - } + best_state = updated_state # Restore best weights so caller can save/export if best_state is not None: diff --git a/src/engines/linear_probe_engine.py b/src/engines/linear_probe_engine.py index a2e74bd..f0e6a69 100644 --- a/src/engines/linear_probe_engine.py +++ b/src/engines/linear_probe_engine.py @@ -5,6 +5,7 @@ # -*- coding: utf-8 -*- """Training engine for linear probing.""" +# pylint: disable=duplicate-code from __future__ import annotations from typing import Dict, Any, Tuple, Optional @@ -15,16 +16,20 @@ # --- AMP import (robust across versions) --- try: # PyTorch 2.0+ unified AMP API - from torch.amp import autocast, GradScaler + from torch.amp import autocast except ImportError: # Fallback for older PyTorch versions - from torch.cuda.amp import autocast, GradScaler + from torch.cuda.amp import autocast # pylint: disable=import-error from src.utils.logging import get_logger -from src.engines.utils.training_loops import ( +from src.engines.utils.training_core import ( _maybe_scheduler_step, - _evaluate, + _create_grad_scaler, + _update_history_and_log, + _preprocess_batch, + _run_validation_and_scheduler, + _update_best_model_state, ) log = get_logger(__name__) @@ -52,19 +57,10 @@ def train_probe( # pylint: disable=too-many-arguments,too-many-locals,too-many- model.train() # Initialize GradScaler safely across torch versions - try: - scaler = GradScaler( - enabled=mixed_precision, - init_scale=2.0**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - ) - except TypeError: - scaler = GradScaler(enabled=mixed_precision) + scaler = _create_grad_scaler(mixed_precision) sched, sched_meta = scheduler or (None, {}) - best_metric = -math.inf + best_metric = -math.inf if not metric_key.endswith("loss") else math.inf best_state_dict = None history = {"train_loss": [], "val_loss": [], "val_acc": [], "lr": []} @@ -74,11 +70,7 @@ def train_probe( # pylint: disable=too-many-arguments,too-many-locals,too-many- running_loss, n_seen = 0.0, 0 for step, batch in enumerate(loaders["train"], start=1): - if isinstance(batch, dict): - x, y = batch["image"], batch["label"] - else: - x, y = batch - x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) + x, y = _preprocess_batch(batch, device) optimizer.zero_grad(set_to_none=True) with autocast(device_type=device.type, enabled=mixed_precision): @@ -115,63 +107,40 @@ def train_probe( # pylint: disable=too-many-arguments,too-many-locals,too-many- train_loss = running_loss / max(n_seen, 1) - # ---- validation - val_loss, val_acc = _evaluate( + # ---- validation and scheduler step + val_loss, val_acc = _run_validation_and_scheduler( model=model, - loader=loaders["val"], + loaders=loaders, loss_fn=loss_fn, device=device, mixed_precision=mixed_precision, + sched=sched, + sched_meta=sched_meta, + metric_key=metric_key, ) - # scheduler on epoch or val metric - if sched is not None: - if isinstance(sched, torch.optim.lr_scheduler.ReduceLROnPlateau): - sched_meta["monitor"] = ( - val_loss if metric_key.endswith("loss") else val_acc - ) - _maybe_scheduler_step(sched_meta, sched, on="val") - else: - _maybe_scheduler_step(sched_meta, sched, on="epoch") - # logging cur_lr = optimizer.param_groups[0]["lr"] - history["train_loss"].append(train_loss) - history["val_loss"].append(val_loss) - history["val_acc"].append(val_acc) - history["lr"].append(cur_lr) - - if wandb_logger: - wandb_logger.log( - { - "epoch": epoch, - "train/loss_epoch": train_loss, - "val/loss": val_loss, - "val/acc": val_acc, - "lr": cur_lr, - } - ) - - log.info( - "Epoch %d | train_loss=%.4f | val_loss=%.4f | val_acc=%.4f | lr=%.2e", - epoch, - train_loss, - val_loss, - val_acc, - cur_lr, + _update_history_and_log( + history=history, + epoch=epoch, + train_loss=train_loss, + val_loss=val_loss, + val_acc=val_acc, + cur_lr=cur_lr, + wandb_logger=wandb_logger, + log=log, ) - monitor = val_loss if metric_key.endswith("loss") else val_acc - is_better = ( - (monitor < best_metric) - if metric_key.endswith("loss") - else (monitor > best_metric) + updated_state, best_metric, is_better = _update_best_model_state( + model=model, + metric_key=metric_key, + val_loss=val_loss, + val_acc=val_acc, + best_metric=best_metric, ) if is_better: - best_metric = monitor - best_state_dict = { - k: v.detach().cpu().clone() for k, v in model.state_dict().items() - } + best_state_dict = updated_state # restore best (optional: caller can save now) if best_state_dict is not None: diff --git a/src/engines/utils/training_core.py b/src/engines/utils/training_core.py new file mode 100644 index 0000000..d914703 --- /dev/null +++ b/src/engines/utils/training_core.py @@ -0,0 +1,276 @@ +# This source file is part of the Daneshjou Lab projects +# +# SPDX-FileCopyrightText: 2025 Stanford University +# SPDX-License-Identifier: MIT + +# -*- coding: utf-8 -*- +# pylint: disable=duplicate-code +"""Training engine utilities for fine-tuning and linear probing. + +This module provides helper functions used by other training engines: +- _accuracy: Computes top-1 accuracy +- _maybe_scheduler_step: Step schedulers based on configuration/meta +- _evaluate: Evaluate model on a loader with loss + accuracy +- _create_grad_scaler: Create a GradScaler with common configuration +- _update_history_and_log: Update training history and log results +- _is_metric_better: Check if current metric is better than best metric + +Imported by: +- src.engines.linear_probe_engine +- src.engines.finetune_engine +""" +from __future__ import annotations +from typing import Dict, Any, Tuple, List, Optional + +import torch +from torch import nn + +try: + # PyTorch 2.0+ unified AMP API + from torch.amp import autocast, GradScaler +except ImportError: + # Fallback for older PyTorch versions + from torch.cuda.amp import autocast, GradScaler + + +def _accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float: + with torch.no_grad(): + preds = torch.argmax(logits, dim=1) + correct = (preds == targets).sum().item() + total = targets.numel() + return correct / max(total, 1) + + +def _maybe_scheduler_step(scheduler_meta: Dict[str, Any], scheduler, *, on: str): + """Step scheduler if meta says so. `on` ∈ {'batch','epoch','val'}.""" + step_when = str(scheduler_meta.get("step_per", "epoch")) + if step_when == on: + if "monitor" in scheduler_meta: + # e.g., ReduceLROnPlateau + scheduler.step(scheduler_meta["monitor"]) + else: + scheduler.step() + + +def _evaluate( + model: nn.Module, + loader, + loss_fn, + device: torch.device, + mixed_precision: bool, +) -> Tuple[float, float]: + model.eval() + running_loss, running_acc, n = 0.0, 0.0, 0 + with torch.no_grad(): + for batch in loader: + x, y = _preprocess_batch(batch, device) + + # Device-aware autocast (CPU/GPU) and version-safe + with autocast( + device_type=getattr(device, "type", "cuda"), enabled=mixed_precision + ): + logits = model(x) + loss = loss_fn(logits, y) + + bsz = y.size(0) + running_loss += float(loss.item()) * bsz + running_acc += _accuracy(logits, y) * bsz + n += bsz + + return running_loss / max(n, 1), running_acc / max(n, 1) + + +def _create_grad_scaler(mixed_precision: bool = True) -> GradScaler: + """Create a GradScaler with common configuration for mixed precision training. + + Args: + mixed_precision: Whether to enable mixed precision training. + + Returns: + A configured GradScaler instance. + """ + try: + return GradScaler( + enabled=mixed_precision, + init_scale=2.0**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + ) + except TypeError: + # Fallback for older PyTorch versions that don't support all parameters + return GradScaler(enabled=mixed_precision) + + +def _update_history_and_log( # pylint: disable=too-many-arguments + *, + history: Dict[str, List[float]], + epoch: int, + train_loss: float, + val_loss: float, + val_acc: float, + cur_lr: float, + wandb_logger: Optional[Any] = None, + log: Optional[Any] = None, +) -> None: + """Update training history and log results. + + Args: + history: Dictionary to store training history. + epoch: Current epoch number. + train_loss: Training loss for current epoch. + val_loss: Validation loss for current epoch. + val_acc: Validation accuracy for current epoch. + cur_lr: Current learning rate. + wandb_logger: Optional wandb logger instance. + log: Optional logger instance. + """ + # Update history + history["train_loss"].append(train_loss) + history["val_loss"].append(val_loss) + history["val_acc"].append(val_acc) + history["lr"].append(cur_lr) + + # Log to wandb if available + if wandb_logger: + wandb_logger.log( + { + "epoch": epoch, + "train/loss_epoch": train_loss, + "val/loss": val_loss, + "val/acc": val_acc, + "lr": cur_lr, + } + ) + + # Log to console if logger is available + if log: + log.info( + "Epoch %d | train_loss=%.4f | val_loss=%.4f | val_acc=%.4f | lr=%.2e", + epoch, + train_loss, + val_loss, + val_acc, + cur_lr, + ) + + +def _is_metric_better( + metric_key: str, current_metric: float, best_metric: float +) -> Tuple[bool, float]: + """Check if current metric is better than best metric. + + Args: + metric_key: Metric name, if ends with "loss" will minimize, otherwise maximize. + current_metric: Current metric value. + best_metric: Best metric value so far. + + Returns: + Tuple of (is_better, best_metric_value) + """ + minimize = metric_key.endswith("loss") + is_better = ( + (current_metric < best_metric) if minimize else (current_metric > best_metric) + ) + + if is_better: + return True, current_metric + return False, best_metric + + +def _preprocess_batch(batch, device): + """Preprocess a batch by moving data to the appropriate device. + + Args: + batch: Either a dictionary with 'image'/'label' keys or a tuple (x, y). + device: The target device for tensors. + + Returns: + Tuple of (input_tensor, target_tensor) on the specified device. + """ + if isinstance(batch, dict): + x, y = batch.get("image", batch.get("pixel_values")), batch.get( + "label", batch.get("labels") + ) + else: + x, y = batch + + return x.to(device, non_blocking=True), y.to(device, non_blocking=True) + + +def _run_validation_and_scheduler( # pylint: disable=too-many-arguments + *, + model: nn.Module, + loaders: Dict[str, Any], + loss_fn: Any, + device: torch.device, + mixed_precision: bool, + sched: Any, + sched_meta: Dict[str, Any], + metric_key: str, +) -> Tuple[float, float]: + """Run validation and handle scheduler steps. + + Args: + model: The model to evaluate. + loaders: Dictionary of data loaders. + loss_fn: Loss function. + device: Target device. + mixed_precision: Whether to use mixed precision. + sched: Optional scheduler. + sched_meta: Scheduler metadata. + metric_key: Metric key to monitor. + + Returns: + Tuple of (validation_loss, validation_accuracy). + """ + # Run validation + val_loss, val_acc = _evaluate( + model=model, + loader=loaders["val"], + loss_fn=loss_fn, + device=device, + mixed_precision=mixed_precision, + ) + + # Handle scheduler steps + if sched is not None: + if isinstance(sched, torch.optim.lr_scheduler.ReduceLROnPlateau): + sched_meta["monitor"] = val_loss if metric_key.endswith("loss") else val_acc + _maybe_scheduler_step(sched_meta, sched, on="val") + else: + _maybe_scheduler_step(sched_meta, sched, on="epoch") + + return val_loss, val_acc + + +def _update_best_model_state( + *, + model: nn.Module, + metric_key: str, + val_loss: float, + val_acc: float, + best_metric: float, +) -> Tuple[Dict[str, torch.Tensor], float, bool]: + """Update best model state if current metric is better. + + Args: + model: The model to get state from. + metric_key: Metric key to monitor. + val_loss: Validation loss. + val_acc: Validation accuracy. + best_metric: Best metric value so far. + + Returns: + Tuple of (best_state_dict, best_metric, is_better). + """ + monitor = val_loss if metric_key.endswith("loss") else val_acc + is_better, updated_best_metric = _is_metric_better(metric_key, monitor, best_metric) + + best_state_dict = None + if is_better: + best_state_dict = { + k: v.detach().cpu().clone() for k, v in model.state_dict().items() + } + + return best_state_dict, updated_best_metric, is_better diff --git a/src/engines/utils/training_loops.py b/src/engines/utils/training_loops.py deleted file mode 100644 index 8ff8058..0000000 --- a/src/engines/utils/training_loops.py +++ /dev/null @@ -1,78 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University -# SPDX-License-Identifier: MIT - -# -*- coding: utf-8 -*- -"""Training engine utilities for fine-tuning and linear probing. - -This module provides helper functions used by other training engines: -- _accuracy: Computes top-1 accuracy -- _maybe_scheduler_step: Step schedulers based on configuration/meta -- _evaluate: Evaluate model on a loader with loss + accuracy - -Imported by: -- src.engines.linear_probe_engine -- src.engines.finetune_engine -""" -from __future__ import annotations -from typing import Dict, Any, Tuple - -import torch -from torch import nn - -try: - # PyTorch 2.0+ unified AMP API - from torch.amp import autocast -except ImportError: - # Fallback for older PyTorch versions - from torch.cuda.amp import autocast - - -def _accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float: - with torch.no_grad(): - preds = torch.argmax(logits, dim=1) - correct = (preds == targets).sum().item() - total = targets.numel() - return correct / max(total, 1) - - -def _maybe_scheduler_step(scheduler_meta: Dict[str, Any], scheduler, *, on: str): - """Step scheduler if meta says so. `on` ∈ {'batch','epoch','val'}.""" - step_when = str(scheduler_meta.get("step_per", "epoch")) - if step_when == on: - if "monitor" in scheduler_meta: - # e.g., ReduceLROnPlateau - scheduler.step(scheduler_meta["monitor"]) - else: - scheduler.step() - - -def _evaluate( - model: nn.Module, - loader, - loss_fn, - device: torch.device, - mixed_precision: bool, -) -> Tuple[float, float]: - model.eval() - running_loss, running_acc, n = 0.0, 0.0, 0 - with torch.no_grad(): - for batch in loader: - if isinstance(batch, dict): - x, y = batch.get("image"), batch.get("label") - else: - x, y = batch - x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) - - # Device-aware autocast (CPU/GPU) and version-safe - with autocast(device_type=getattr(device, "type", "cuda"), enabled=mixed_precision): - logits = model(x) - loss = loss_fn(logits, y) - - bsz = y.size(0) - running_loss += float(loss.item()) * bsz - running_acc += _accuracy(logits, y) * bsz - n += bsz - - return running_loss / max(n, 1), running_acc / max(n, 1) From b59150d1455fcf2012e4fb4350b0326c63a3960d Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Tue, 21 Oct 2025 18:41:42 -0700 Subject: [PATCH 25/26] docs(train): document support for Hydra-based hyperparameter sweeps - hyperparameter tuning is natively supported via Hydra multiruns --- src/cli/train.py | 190 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 156 insertions(+), 34 deletions(-) diff --git a/src/cli/train.py b/src/cli/train.py index 5a830f4..6123efb 100644 --- a/src/cli/train.py +++ b/src/cli/train.py @@ -1,19 +1,53 @@ # This source file is part of the Daneshjou Lab projects # # SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# # SPDX-License-Identifier: MIT # src/cli/train.py +# -*- coding: utf-8 -*- -"""CLI entry point for training models with different paradigms (probe/finetune/distill). - -Usage: -python -m src.cli.train train.mode=probe \ - dataset.name=isic2019 dataset.image_size=224 dataset.batch_size=128 \ +""" +CLI entry point for training/evaluating models across paradigms (probe/finetune). +It normalizes dataset config → data config, builds a BaseDataModule using the +dataset factory, and dispatches to the selected training wrapper. + +Usage examples: + python -m src.cli.train train.mode=probe \ + dataset.name=isic2019 dataset.data_dir=/data/ISIC \ + dataset.image_size=224 dataset.batch_size=128 \ + model.type=vit model.model_id=google/vit-base-patch16-224 + + # With controlled degradation: downsample to 112px, then pipeline resizes to 224 + python -m src.cli.train train.mode=probe \ + dataset.name=isic2019 dataset.data_dir=/data/ISIC \ + dataset.image_size=224 dataset.batch_size=128 \ + dataset.degradation.target_resolution=112 \ + model.type=vit model.model_id=google/vit-base-patch16-224 + + # Random search (build-in basic sweep) + python -m src.cli.train -m \ + hydra.sweeper=basic \ + hydra.sweeper.n_trials=20 \ + hydra.sweeper.params='optim.lr=log(1e-5,1e-3); optim.weight_decay=uniform(0,0.1); dataset.batch_size=choice(64,128)' \ + train.mode=probe \ + dataset.name=isic2019 dataset.data_dir=/data/ISIC \ + dataset.image_size=224 \ model.type=vit model.model_id=google/vit-base-patch16-224 + + OR + + python -m src.cli.train -m \ + train.mode=probe \ + dataset.name=isic2019 dataset.data_dir=/data/ISIC \ + dataset.image_size=224 dataset.batch_size=128 \ + model.type=vit model.model_id=google/vit-base-patch16-224 \ + optim.lr=1e-5,3e-5,1e-4,3e-4 \ + optim.weight_decay=0.0,0.01 + """ +from __future__ import annotations + import os import sys import time @@ -21,22 +55,33 @@ import random import logging from pathlib import Path -from typing import Dict, Any +from typing import Dict, Any, Optional, Tuple, Union import torch import numpy as np import hydra from omegaconf import DictConfig, OmegaConf -# ---- Wrappers implement run(cfg) and return a dict of metrics +# ---- Training wrappers (each provides run(cfg) -> dict of metrics) from src.wrappers import probe as probe_wrapper from src.wrappers import finetune as finetune_wrapper -log = logging.getLogger("train") +# ---- Data pipeline +from src.data.datamodule import BaseDataModule +from src.transformations.transforms import ResolutionReductionTransform + +# ---- Optional HF preprocessor (only needed when actually running a HF backbone) +try: + from transformers import AutoImageProcessor # noqa: F401 +except Exception: # pylint: broad-exception-caught # pragma: no cover + AutoImageProcessor = None # type: ignore +log = logging.getLogger("train") -# ------------------------- helpers ------------------------- +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- def _is_rank_zero() -> bool: local_rank = int(os.environ.get("LOCAL_RANK", "0")) @@ -44,12 +89,12 @@ def _is_rank_zero() -> bool: def _select_device(cfg: DictConfig) -> torch.device: - if "device" in cfg.train and cfg.train.device: + if "train" in cfg and "device" in cfg.train and cfg.train.device: return torch.device(cfg.train.device) return torch.device("cuda" if torch.cuda.is_available() else "cpu") -def _seed_everything(seed: int, deterministic: bool = False): +def _seed_everything(seed: int, deterministic: bool = False) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) @@ -61,7 +106,7 @@ def _seed_everything(seed: int, deterministic: bool = False): torch.backends.cudnn.benchmark = True -def _save_resolved_config(cfg: DictConfig, run_dir: Path): +def _save_resolved_config(cfg: DictConfig, run_dir: Path) -> None: if not _is_rank_zero(): return run_dir.mkdir(parents=True, exist_ok=True) @@ -69,14 +114,15 @@ def _save_resolved_config(cfg: DictConfig, run_dir: Path): OmegaConf.save(config=cfg, f=f.name) -def _print_run_header(cfg: DictConfig, run_dir: Path, device: torch.device): +def _print_run_header(cfg: DictConfig, run_dir: Path, device: torch.device) -> None: if not _is_rank_zero(): return + model_name = getattr(cfg.model, "name", getattr(cfg.model, "type", "N/A")) banner = ( f"\n=== TRAIN START ===\n" f"mode : {cfg.train.mode}\n" f"dataset : {cfg.dataset.name}\n" - f"model : {getattr(cfg.model, 'name', getattr(cfg.model, 'type', 'N/A'))}\n" + f"model : {model_name}\n" f"device : {device}\n" f"seed : {cfg.seed}\n" f"run_dir : {str(run_dir)}\n" @@ -94,8 +140,7 @@ def _dispatch_wrapper(cfg: DictConfig) -> Dict[str, Any]: # if mode == "distill": # return distill_wrapper.run(cfg) raise ValueError( - f"Unknown train.mode='{cfg.train.mode}'. " - f"Expected one of: probe | finetune | distill" + f"Unknown train.mode='{cfg.train.mode}'. Expected one of: probe | finetune | distill" ) @@ -113,54 +158,130 @@ def _normalize_dataset_into_data(cfg: DictConfig) -> None: Expected cfg.dataset fields (typical): - name: str (e.g., "isic2019", "chexpert") - data_dir: str or null (some loaders use HF hub instead) - - image_size: int - - num_classes: int + - image_size: int (model input size, e.g., 224) + - num_classes: int (optional; forwarded to model.config.num_labels) - batch_size: int - num_workers: int - pin_memory: bool (optional) + - degradation: (optional group) + target_resolution: int | [w,h] + restore_original_size: bool + # or + reduction_factor: float in (0,1] - After this, cfg.data will contain: + After this, cfg.data contains: dataset_name, data_dir, image_size, batch_size, num_workers, pin_memory """ if "dataset" not in cfg: raise ValueError("Config is missing 'dataset' group (cfg.dataset.*).") ds = cfg.dataset - # Create cfg.data if missing if "data" not in cfg or cfg.data is None: cfg.data = OmegaConf.create() - # Copy/normalize fields cfg.data.dataset_name = str(ds.get("name")) cfg.data.data_dir = ds.get("data_dir") # can be None for HF datasets cfg.data.image_size = int(ds.get("image_size", 224)) - cfg.data.batch_size = int( - ds.get("batch_size", getattr(cfg.train, "batch_size", 64)) - ) + cfg.data.batch_size = int(ds.get("batch_size", getattr(cfg.train, "batch_size", 64))) cfg.data.num_workers = int(ds.get("num_workers", 4)) cfg.data.pin_memory = bool(ds.get("pin_memory", True)) - # Optional: enforce/propagate num_classes to model config + # Optional: propagate num_classes → model.config.num_labels num_classes = ds.get("num_classes", None) if num_classes is not None: if "model" not in cfg: cfg.model = OmegaConf.create() if "config" not in cfg.model or cfg.model.config is None: cfg.model.config = OmegaConf.create() - # Many factories look for "num_labels" cfg.model.config.num_labels = int(num_classes) - # Sanity checks if not cfg.data.dataset_name: raise ValueError("cfg.dataset.name must be set (e.g., 'isic2019').") -# ------------------------- main ------------------------- +def _build_degradation_transform(degr_cfg: DictConfig) -> Optional[ResolutionReductionTransform]: + """ + Build a ResolutionReductionTransform from a degradation config group. + Supports: + - target_resolution: int | [w,h] + - restore_original_size: bool + - reduction_factor: float + """ + if not degr_cfg: + return None + + # normalize target_resolution + target_res: Optional[Union[int, Tuple[int, int]]] = getattr(degr_cfg, "target_resolution", None) + restore: bool = bool(getattr(degr_cfg, "restore_original_size", False)) + reduction_factor = getattr(degr_cfg, "reduction_factor", None) + + if target_res is not None: + if isinstance(target_res, int): + target_res = (int(target_res), int(target_res)) + else: + target_res = tuple(int(x) for x in target_res) + return ResolutionReductionTransform( + target_resolution=target_res, restore_original_size=restore + ) + + if reduction_factor is not None: + return ResolutionReductionTransform( + reduction_factor=float(reduction_factor), restore_original_size=restore + ) + + return None + + +def _build_datamodule(cfg: DictConfig) -> BaseDataModule: + """ + Construct the BaseDataModule using normalized cfg.data and optional + degradation transforms. Keeps HF preprocessor optional. + """ + # Optional degradation transform + degr_cfg = getattr(cfg.dataset, "degradation", None) + transform = _build_degradation_transform(degr_cfg) if degr_cfg else None + + # Optional HF preprocessor (only if actually running a HF backbone) + preproc = None + model_id = getattr(cfg.model, "model_id", None) + if model_id and AutoImageProcessor is not None: + try: + preproc = AutoImageProcessor.from_pretrained(model_id) + except Exception: + preproc = None # safe no-op path + + image_size = int(getattr(cfg.data, "image_size", 224)) + batch_size = int(getattr(cfg.data, "batch_size", 64)) + num_workers = int(getattr(cfg.data, "num_workers", 4)) + pin_memory = bool(getattr(cfg.data, "pin_memory", True)) + model_type = str(getattr(cfg.model, "type", "vit")) + + dm = BaseDataModule( + cfg=cfg, + dataset_name=str(cfg.data.dataset_name), + data_dir=getattr(cfg.data, "data_dir", None), + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + preprocessor=preproc, # None -> ModelPreprocessor no-op + resolution=image_size, # model input size (e.g., 224) + transform=transform, # may be None + model_type=model_type, + ) + + # Expose to wrappers + cfg.runtime = dict(getattr(cfg, "runtime", {})) + cfg.runtime["datamodule"] = dm + return dm + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- @hydra.main(config_path="../../configs", config_name="defaults", version_base=None) def main(cfg: DictConfig): """Main training CLI entry point.""" - # Allow wrappers to attach runtime fields if needed OmegaConf.set_struct(cfg, False) # Normalize dataset selection into cfg.data for wrappers/datamodules @@ -185,6 +306,10 @@ def main(cfg: DictConfig): "world_size": int(os.environ.get("WORLD_SIZE", "1")), } + # Build & setup datamodule using dataset factory logic + dm = _build_datamodule(cfg) + dm.setup(stage="fit") # prepares train/val (or splits train if no val split) + # Persist resolved config and print header _save_resolved_config(cfg, run_dir) _print_run_header(cfg, run_dir, device) @@ -206,12 +331,9 @@ def main(cfg: DictConfig): metrics = metrics or {} with open(run_dir / "final_metrics.json", "w", encoding="utf-8") as f: json.dump(metrics, f, indent=2) - print( - "✅ Train done. Final metrics written to final_metrics.json\n", flush=True - ) + print("✅ Train done. Final metrics written to final_metrics.json\n", flush=True) if __name__ == "__main__": - # Example: python -m src.cli.train train.mode=probe dataset.name=isic2019 # pylint: disable=no-value-for-parameter sys.exit(main()) From 681986ccd12841a14aa0e24911114b4cbfbf2987 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Tue, 21 Oct 2025 19:35:29 -0700 Subject: [PATCH 26/26] =?UTF-8?q?-=20Updated=20trial.py=20to=20load=20ISIC?= =?UTF-8?q?BaseDataset=20without=20transform=20for=20correct=20original=20?= =?UTF-8?q?(224=C3=97224)=20images?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + requirements.txt | 1 - scripts/demo_transformation.py | 114 -------------- scripts/trial.py | 56 +++++++ src/cli/train.py | 46 ++++-- src/config.py | 31 ++-- src/data/datamodule.py | 4 +- src/data/dataset_factory.py | 10 +- src/data/isic_loader.py | 179 ++++++++++++++++++---- src/engines/finetune_engine.py | 4 +- src/engines/linear_probe_engine.py | 4 +- src/engines/{utils => }/training_core.py | 0 src/losses/classification.py | 11 +- src/losses/distillation.py | 23 ++- src/transformation/transforms.py | 17 +- src/utils/callbacks_hf.py | 4 +- src/utils/constants.py | 2 - src/utils/{logging.py => logging_core.py} | 109 +++++++++++-- src/utils/optim.py | 78 +++++++--- src/utils/training_utils.py | 2 + src/utils/utils.py | 1 + src/wrappers/finetune.py | 2 +- src/wrappers/probe.py | 2 +- 23 files changed, 456 insertions(+), 245 deletions(-) delete mode 100644 scripts/demo_transformation.py create mode 100644 scripts/trial.py rename src/engines/{utils => }/training_core.py (100%) rename src/utils/{logging.py => logging_core.py} (51%) diff --git a/.gitignore b/.gitignore index 2afc05e..070e362 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ __pycache__/ *.py[codz] *$py.class +.python-version # C extensions *.so diff --git a/requirements.txt b/requirements.txt index 1e41dc8..c224c16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,6 @@ fonttools h11 httpcore httpx -hydra hydra-core==1.3.2 idna importlib_metadata diff --git a/scripts/demo_transformation.py b/scripts/demo_transformation.py deleted file mode 100644 index ddf0b4d..0000000 --- a/scripts/demo_transformation.py +++ /dev/null @@ -1,114 +0,0 @@ -# This source file is part of the Daneshjou Lab projects -# -# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) -# -# SPDX-License-Identifier: MIT - -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# pylint: disable=broad-exception-caught,wrong-import-position,import-error - -""" -Apply ResolutionReductionTransform to a few ISIC images and save: -- original_.png -- resolution_reduced_.png (reduced-size, not upsampled) -""" - -import sys -from pathlib import Path -from typing import Iterable, Optional - -# Add src to Python path for imports -current_dir = Path(__file__).parent -project_root = current_dir.parent -sys.path.insert(0, str(project_root)) - -from datasets import load_dataset -from src.data.isic_loader import ISICBaseDataset -from src.transformation.transforms import ResolutionReductionTransform - - -def load_isic_dataset() -> Optional[ISICBaseDataset]: - """Load ISIC dataset from HuggingFace.""" - print("🔄 Loading ISIC dataset from HuggingFace...") - try: - hf_dataset = load_dataset("MKZuziak/ISIC_2019_224", split="train") - ds = ISICBaseDataset(hf_dataset) - print(f"✅ Loaded {len(ds)} samples") - return ds - except Exception as e: - print(f"❌ Failed to load dataset: {e}") - print("Make sure you have 'datasets' installed: pip install datasets") - return None - - -def process_images( - dataset: ISICBaseDataset, - indices: Iterable[int], - transform: ResolutionReductionTransform, - output_dir: Path, - save_prefix: str = "resolution_reduced", -) -> int: - """ - For each index: save original and one reduced image using `transform`. - Assumes transform returns the reduced-size image (no upsample). - """ - output_dir.mkdir(parents=True, exist_ok=True) - print(f"📁 Output directory: {output_dir.absolute()}") - - processed = 0 - for idx in indices: - try: - sample = dataset[idx] - original = sample["image"] - label = sample.get("label", None) - print(f"\nImage {idx} | Original size: {original.size} | Label: {label}") - - # Save original - orig_path = output_dir / f"original_{idx}.png" - original.save(orig_path) - print(f"Saved original → {orig_path.name}") - - # Save reduced (actual transformed size) - reduced = transform(original) - out_path = output_dir / f"{save_prefix}_{idx}.png" - reduced.save(out_path) - print(f" 🔧 Reduced to: {reduced.size}") - print(f" 💾 Saved reduced → {out_path.name}") - - processed += 1 - - except Exception as e: - print(f" ❌ Error at index {idx}: {e}") - - return processed - - -def main(): - ds = load_isic_dataset() - if ds is None: - return - - output_dir = Path("outputs") - num_images = min(5, len(ds)) - indices = range(num_images) - - # Use target size and DO NOT upsample back in the transform - resolution_transform = ResolutionReductionTransform( - target_resolution=(54, 54), - restore_original_size=False - ) - - print("\n▶️ Saving originals and reduced versions for first 5 images...") - n = process_images( - dataset=ds, - indices=indices, - transform=resolution_transform, - output_dir=output_dir, - save_prefix="resolution_reduced", - ) - print(f"\n✅ Done. Saved {n} reduced images. Check: {output_dir.absolute()}") - - -if __name__ == "__main__": - main() diff --git a/scripts/trial.py b/scripts/trial.py new file mode 100644 index 0000000..f967db0 --- /dev/null +++ b/scripts/trial.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: MIT +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# pylint: disable=all +""" +Trial: load ISIC (HF, 224×224), then save original vs. 112×112 downsampled. +""" + +import sys +from pathlib import Path + +current_dir = Path(__file__).resolve().parent +project_root = current_dir.parent +sys.path.insert(0, str(project_root)) + +from src.data.isic_loader import ISICBaseDataset # HF-backed +from src.transformation.transforms import ResolutionReductionTransform + +def main(): + out_dir = Path("outputs/trial_isic112") + out_dir.mkdir(parents=True, exist_ok=True) + + # IMPORTANT: no transform here, so we get true originals from HF (224×224) + ds = ISICBaseDataset( + repo_id="MKZuziak/ISIC_2019_224", + split="train", + transform=None, + ) + print(f"Loaded {len(ds)} samples") + + # Build the reduction transform to 112×112 + reduce112 = ResolutionReductionTransform( + target_resolution=(112, 112), + restore_original_size=False, + ) + + for i in range(5): + sample = ds[i] + img = sample["image"] + lbl = sample["label"] + + orig_path = out_dir / f"original_{i}.png" + img.save(orig_path) + + reduced = reduce112(img) + red_path = out_dir / f"reduced_{i}.png" + reduced.save(red_path) + + print(f"[{i}] label={lbl} | original={img.size} → reduced={reduced.size} | " + f"saved: {orig_path.name}, {red_path.name}") + + print(f"\n✅ Done. Check: {out_dir.resolve()}") + +if __name__ == "__main__": + main() diff --git a/src/cli/train.py b/src/cli/train.py index 6123efb..1a9e5e5 100644 --- a/src/cli/train.py +++ b/src/cli/train.py @@ -5,7 +5,7 @@ # src/cli/train.py # -*- coding: utf-8 -*- - +# pylint: disable=import-error, broad-exception-caught """ CLI entry point for training/evaluating models across paradigms (probe/finetune). It normalizes dataset config → data config, builds a BaseDataModule using the @@ -28,7 +28,8 @@ python -m src.cli.train -m \ hydra.sweeper=basic \ hydra.sweeper.n_trials=20 \ - hydra.sweeper.params='optim.lr=log(1e-5,1e-3); optim.weight_decay=uniform(0,0.1); dataset.batch_size=choice(64,128)' \ + hydra.sweeper.params='optim.lr=log(1e-5,1e-3); optim.weight_decay=uniform(0,0.1); \ +dataset.batch_size=choice(64,128)' \ train.mode=probe \ dataset.name=isic2019 dataset.data_dir=/data/ISIC \ dataset.image_size=224 \ @@ -59,21 +60,24 @@ import torch import numpy as np -import hydra -from omegaconf import DictConfig, OmegaConf +import hydra # pylint: disable=import-error +from omegaconf import DictConfig, OmegaConf # pylint: disable=import-error # ---- Training wrappers (each provides run(cfg) -> dict of metrics) -from src.wrappers import probe as probe_wrapper -from src.wrappers import finetune as finetune_wrapper +from src.wrappers import probe as probe_wrapper # pylint: disable=import-error +from src.wrappers import finetune as finetune_wrapper # pylint: disable=import-error # ---- Data pipeline -from src.data.datamodule import BaseDataModule -from src.transformations.transforms import ResolutionReductionTransform +from src.data.datamodule import BaseDataModule # pylint: disable=import-error +from src.transformations.transforms import ( + ResolutionReductionTransform, +) # pylint: disable=import-error # ---- Optional HF preprocessor (only needed when actually running a HF backbone) try: from transformers import AutoImageProcessor # noqa: F401 -except Exception: # pylint: broad-exception-caught # pragma: no cover +except ImportError: # pragma: no cover + # If transformers is not installed, we continue without it AutoImageProcessor = None # type: ignore log = logging.getLogger("train") @@ -83,6 +87,7 @@ # Helper functions # --------------------------------------------------------------------------- + def _is_rank_zero() -> bool: local_rank = int(os.environ.get("LOCAL_RANK", "0")) return local_rank == 0 @@ -182,7 +187,9 @@ def _normalize_dataset_into_data(cfg: DictConfig) -> None: cfg.data.dataset_name = str(ds.get("name")) cfg.data.data_dir = ds.get("data_dir") # can be None for HF datasets cfg.data.image_size = int(ds.get("image_size", 224)) - cfg.data.batch_size = int(ds.get("batch_size", getattr(cfg.train, "batch_size", 64))) + cfg.data.batch_size = int( + ds.get("batch_size", getattr(cfg.train, "batch_size", 64)) + ) cfg.data.num_workers = int(ds.get("num_workers", 4)) cfg.data.pin_memory = bool(ds.get("pin_memory", True)) @@ -199,7 +206,9 @@ def _normalize_dataset_into_data(cfg: DictConfig) -> None: raise ValueError("cfg.dataset.name must be set (e.g., 'isic2019').") -def _build_degradation_transform(degr_cfg: DictConfig) -> Optional[ResolutionReductionTransform]: +def _build_degradation_transform( + degr_cfg: DictConfig, +) -> Optional[ResolutionReductionTransform]: """ Build a ResolutionReductionTransform from a degradation config group. Supports: @@ -211,7 +220,9 @@ def _build_degradation_transform(degr_cfg: DictConfig) -> Optional[ResolutionRed return None # normalize target_resolution - target_res: Optional[Union[int, Tuple[int, int]]] = getattr(degr_cfg, "target_resolution", None) + target_res: Optional[Union[int, Tuple[int, int]]] = getattr( + degr_cfg, "target_resolution", None + ) restore: bool = bool(getattr(degr_cfg, "restore_original_size", False)) reduction_factor = getattr(degr_cfg, "reduction_factor", None) @@ -263,9 +274,9 @@ def _build_datamodule(cfg: DictConfig) -> BaseDataModule: batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, - preprocessor=preproc, # None -> ModelPreprocessor no-op - resolution=image_size, # model input size (e.g., 224) - transform=transform, # may be None + preprocessor=preproc, # None -> ModelPreprocessor no-op + resolution=image_size, # model input size (e.g., 224) + transform=transform, # may be None model_type=model_type, ) @@ -279,6 +290,7 @@ def _build_datamodule(cfg: DictConfig) -> BaseDataModule: # Main # --------------------------------------------------------------------------- + @hydra.main(config_path="../../configs", config_name="defaults", version_base=None) def main(cfg: DictConfig): """Main training CLI entry point.""" @@ -331,7 +343,9 @@ def main(cfg: DictConfig): metrics = metrics or {} with open(run_dir / "final_metrics.json", "w", encoding="utf-8") as f: json.dump(metrics, f, indent=2) - print("✅ Train done. Final metrics written to final_metrics.json\n", flush=True) + print( + "✅ Train done. Final metrics written to final_metrics.json\n", flush=True + ) if __name__ == "__main__": diff --git a/src/config.py b/src/config.py index 10a4b6c..00c50c6 100644 --- a/src/config.py +++ b/src/config.py @@ -1,4 +1,3 @@ - # This source file is part of the Daneshjou Lab projects # # SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see AUTHORS.md) @@ -7,7 +6,15 @@ """Configuration and constants.""" from dataclasses import dataclass -from typing import List, Dict, Any +from typing import Dict, Any + +# Optional import - used for hardware detection +try: + import torch # pylint: disable=import-error + CUDA_AVAILABLE = torch.cuda.is_available() +except ImportError: + # If torch is not available, assume no CUDA + CUDA_AVAILABLE = False # Model constants HF_MODELS = ["vit", "dinov2"] @@ -24,9 +31,11 @@ "std": [0.229, 0.224, 0.225], } + @dataclass -class TrainingConfig: +class TrainingConfig: # pylint: disable=too-many-instance-attributes """Training configuration.""" + num_train_images: int = 100 proportion_per_transform: float = 0.2 resolution: int = 224 @@ -42,30 +51,24 @@ def to_dict(self) -> Dict[str, Any]: def to_wandb_config(self) -> Dict[str, Any]: """Create wandb configuration.""" - import torch return { **self.to_dict(), - "gpu_available": torch.cuda.is_available(), + "gpu_available": CUDA_AVAILABLE, } + # Model configurations MODEL_REGISTRY = [ { "name": "vit", "model_id": "google/vit-base-patch16-224", "type": "vit", - "config": { - "num_labels": NUM_FILTERED_CLASSES, - "ignore_mismatched_sizes": True - } + "config": {"num_labels": NUM_FILTERED_CLASSES, "ignore_mismatched_sizes": True}, }, { "name": "dinov2", "model_id": "facebook/dinov2-base", "type": "dinov2", - "config": { - "num_labels": NUM_FILTERED_CLASSES, - "ignore_mismatched_sizes": True - } + "config": {"num_labels": NUM_FILTERED_CLASSES, "ignore_mismatched_sizes": True}, }, -] \ No newline at end of file +] diff --git a/src/data/datamodule.py b/src/data/datamodule.py index e9a4967..ddf9dcb 100644 --- a/src/data/datamodule.py +++ b/src/data/datamodule.py @@ -13,8 +13,8 @@ from typing import Optional, Any # Third-party imports -import torch -from torch.utils.data import DataLoader, random_split, Subset +import torch # pylint: disable=import-error +from torch.utils.data import DataLoader, random_split, Subset # pylint: disable=import-error # Local imports # pylint: disable=import-error diff --git a/src/data/dataset_factory.py b/src/data/dataset_factory.py index 5ebe33e..5101b21 100644 --- a/src/data/dataset_factory.py +++ b/src/data/dataset_factory.py @@ -9,10 +9,10 @@ from typing import Optional, List, Dict, Any, Union # Third-party imports -import numpy as np -import torch -from PIL import Image -from torch.utils.data import Dataset, ConcatDataset, Subset +import numpy as np # pylint: disable=import-error +import torch # pylint: disable=import-error +from PIL import Image # pylint: disable=import-error +from torch.utils.data import Dataset, ConcatDataset, Subset # pylint: disable=import-error # Local imports # pylint: disable=import-error,relative-beyond-top-level @@ -340,7 +340,7 @@ def __len__(self) -> int: ... If you already have it elsewhere, just import and return it here. """ - # pylint: disable=import-outside-toplevel,relative-beyond-top-level + # pylint: disable=import-outside-toplevel,relative-beyond-top-level,import-error try: from src.data.isic_raw import ISICRawSplit # Use absolute import diff --git a/src/data/isic_loader.py b/src/data/isic_loader.py index f16403e..7d0b7f6 100644 --- a/src/data/isic_loader.py +++ b/src/data/isic_loader.py @@ -4,44 +4,167 @@ # # SPDX-License-Identifier: MIT -"""ISIC dataset loader implementation for dermatology image datasets.""" -from typing import Any, Dict, Union -from torch.utils.data import Dataset, Subset +"""ISIC dataset loader implementation for dermatology image datasets. -class ISICBaseDataset(Dataset): +This module provides a Hugging Face–backed loader for ISIC that returns +dicts shaped like: + {"image": PIL.Image, "label": int} + +Typical usage (HF-only, no DataModule required): + from src.data.isic_loader import ISICHFRawSplit + ds = ISICHFRawSplit(repo_id="MKZuziak/ISIC_2019_224", split="train") + sample = ds[0] + image, label = sample["image"], sample["label"] + +You may pass a PIL->PIL transform (e.g., ResolutionReductionTransform) via `transform=...`. +""" + +from __future__ import annotations + +from typing import Any, Dict, Optional, Sequence +from torch.utils.data import Dataset # pylint: disable=import-error +from PIL import Image + +try: + # Hugging Face datasets is required for this loader + from datasets import load_dataset, Image as HFImageFeature # type: ignore +except Exception as _e: # pragma: no cover + load_dataset = None + HFImageFeature = None + _HF_IMPORT_ERROR = _e + + +def _to_pil(x: Any) -> Image.Image: + """Best-effort conversion to PIL.Image.""" + if isinstance(x, Image.Image): + return x + try: + # HF Image feature usually yields PIL already; if not, try convert + return x.convert("RGB") + except Exception: + import numpy as np + return Image.fromarray(np.asarray(x)).convert("RGB") + + +class ISICHFRawSplit(Dataset): """ - Minimal, transformation-free wrapper for ISIC (or ISIC-like) datasets. + Hugging Face–backed ISIC split. Returns dicts: + {"image": PIL.Image, "label": int} - Expects the backing dataset (or Subset) to yield items with: - item["image"] : PIL.Image (or array/tensor if your source uses that) - item["label"] : int-like + Parameters + ---------- + repo_id : str + HF dataset repo id. Default: "MKZuziak/ISIC_2019_224" + split : str + Split name to load (e.g., "train"). Depends on the repo. + cache_dir : Optional[str] + Local HF cache directory. + image_column : str + Column name for image. Default: "image" + label_column : str + Column name for label. Default: "label" + transform : Optional[callable] + Optional PIL->PIL transform (e.g., ResolutionReductionTransform). + filter_fn : Optional[callable] + Optional row-level filter: receives an item dict and returns True/False. + If provided, the dataset builds an index of rows where filter_fn(item) is True. + keep_indices : Optional[Sequence[int]] + Optional explicit list of indices to keep (applied after filter_fn if both are given). - Returns each sample unchanged: - {"image": , "label": } + Notes + ----- + - Model-specific preprocessing (tensor conversion, normalization) is intentionally not applied here. + That belongs in your model-ready pipeline/DataModule. """ - def __init__(self, dataset: Union[Dataset, Subset]): - self.dataset = dataset + def __init__( + self, + *, + repo_id: str = "MKZuziak/ISIC_2019_224", + split: str = "train", + cache_dir: Optional[str] = None, + image_column: str = "image", + label_column: str = "label", + transform: Optional[Any] = None, + filter_fn: Optional[Any] = None, + keep_indices: Optional[Sequence[int]] = None, + ): + if load_dataset is None: # pragma: no cover + raise ImportError( + "Hugging Face `datasets` is required for ISICHFRawSplit. " + "Install via `pip install datasets`. " + f"Original import error: {_HF_IMPORT_ERROR}" + ) + + self.repo_id = repo_id + self.split = split + self.cache_dir = cache_dir + self.image_column = image_column + self.label_column = label_column + self.transform = transform + + # Load HF dataset split + self.ds = load_dataset(repo_id, split=split, cache_dir=cache_dir) + + # Validate image column if possible + try: + feats = self.ds.features + if image_column in feats and HFImageFeature is not None: + # If it's an HF Image feature, decoding to PIL happens on access. + pass + except Exception: + pass # proceed; we'll coerce per-sample in __getitem__ + + # Build index (filter + keep_indices) + idx = list(range(len(self.ds))) + if filter_fn is not None: + filtered = [] + for i in idx: + try: + if filter_fn(self.ds[i]): + filtered.append(i) + except Exception: + # Skip rows that fail the filter function + continue + idx = filtered + if keep_indices is not None: + # Intersect in original order + keep_set = set(int(k) for k in keep_indices) + idx = [i for i in idx if i in keep_set] + + self._indices = idx + + self.class_names = None + try: + label_feat = self.ds.features.get(label_column, None) + names = getattr(label_feat, "names", None) + if names: + self.class_names = tuple(names) + except Exception: + self.class_names = None def __len__(self) -> int: - return len(self.dataset) + return len(self._indices) def __getitem__(self, idx: int) -> Dict[str, Any]: - # Handle Subset wrapping - base = self.dataset - if hasattr(base, "dataset") and hasattr(base, "indices"): - # Subset case - item = base.dataset[int(base.indices[idx])] - else: - item = base[idx] - - # no resize, no cast, no transforms - image = item["image"] - label = item["label"] + real_idx = int(self._indices[idx]) + item = self.ds[real_idx] - try: - label = int(label) - except Exception as exc: - raise TypeError("Label must be convertible to int.") from exc + image = _to_pil(item[self.image_column]) + label = int(item[self.label_column]) + + if self.transform is not None: + image = self.transform(image) return {"image": image, "label": label} + + +class ISICBaseDataset(ISICHFRawSplit): # type: ignore[misc] + """ + Backwards-compatible alias for the previous ISICBaseDataset, now backed by HF. + + Usage (unchanged import path): + from src.data.isic_loader import ISICBaseDataset + ds = ISICBaseDataset(repo_id="MKZuziak/ISIC_2019_224", split="train") + """ + pass \ No newline at end of file diff --git a/src/engines/finetune_engine.py b/src/engines/finetune_engine.py index 7b0394f..f014bd3 100644 --- a/src/engines/finetune_engine.py +++ b/src/engines/finetune_engine.py @@ -21,8 +21,8 @@ from torch.cuda.amp import autocast # pylint: disable=import-error -from src.utils.logging import get_logger -from src.engines.utils.training_core import ( +from src.utils.logging_core import get_logger +from src.engines.training_core import ( _maybe_scheduler_step, _create_grad_scaler, _update_history_and_log, diff --git a/src/engines/linear_probe_engine.py b/src/engines/linear_probe_engine.py index f0e6a69..9b5d5e4 100644 --- a/src/engines/linear_probe_engine.py +++ b/src/engines/linear_probe_engine.py @@ -22,8 +22,8 @@ from torch.cuda.amp import autocast # pylint: disable=import-error -from src.utils.logging import get_logger -from src.engines.utils.training_core import ( +from src.utils.logging_core import get_logger +from src.engines.training_core import ( _maybe_scheduler_step, _create_grad_scaler, _update_history_and_log, diff --git a/src/engines/utils/training_core.py b/src/engines/training_core.py similarity index 100% rename from src/engines/utils/training_core.py rename to src/engines/training_core.py diff --git a/src/losses/classification.py b/src/losses/classification.py index af9022d..a8c4767 100644 --- a/src/losses/classification.py +++ b/src/losses/classification.py @@ -4,11 +4,20 @@ # # SPDX-License-Identifier: MIT +""" +Classification loss functions for model training. + +This module provides common loss functions used for classification tasks, +including cross-entropy loss with various options like label smoothing +and class weighting. +""" + # src/losses/classification.py # -*- coding: utf-8 -*- from typing import Optional +# pylint: disable=import-error import torch.nn.functional as F -from torch import Tensor +from torch import Tensor # pylint: disable=import-error def cross_entropy_loss( diff --git a/src/losses/distillation.py b/src/losses/distillation.py index c6c4b7e..e63c305 100644 --- a/src/losses/distillation.py +++ b/src/losses/distillation.py @@ -4,11 +4,20 @@ # # SPDX-License-Identifier: MIT +""" +Knowledge distillation loss functions. + +This module provides various loss functions used in knowledge distillation, +including cosine embedding loss, KL divergence, and hybrid distillation loss +for transferring knowledge from teacher to student models. +""" + # src/losses/distillation.py # -*- coding: utf-8 -*- from typing import Dict +# pylint: disable=import-error import torch.nn.functional as F -from torch import Tensor +from torch import Tensor # pylint: disable=import-error def cosine_loss(reduction: str = "mean"): @@ -23,7 +32,7 @@ def _loss(s_embed: Tensor, t_embed: Tensor) -> Tensor: loss = 1.0 - sim # minimize (1 - cos) if reduction == "mean": return loss.mean() - elif reduction == "sum": + if reduction == "sum": return loss.sum() return loss return _loss @@ -33,14 +42,14 @@ def kl_divergence_loss(temperature: float = 1.0, reduction: str = "batchmean"): """ KL divergence on logits with temperature scaling (Hinton et al., 2015). """ - T = float(temperature) - T2 = T * T + temp = float(temperature) + temp_squared = temp * temp def _loss(s_logits: Tensor, t_logits: Tensor) -> Tensor: - log_p_s = F.log_softmax(s_logits / T, dim=-1) - p_t = F.softmax(t_logits / T, dim=-1) + log_p_s = F.log_softmax(s_logits / temp, dim=-1) + p_t = F.softmax(t_logits / temp, dim=-1) kl = F.kl_div(log_p_s, p_t, reduction=reduction) - return kl * T2 + return kl * temp_squared return _loss diff --git a/src/transformation/transforms.py b/src/transformation/transforms.py index 1049fbb..87ce6f3 100644 --- a/src/transformation/transforms.py +++ b/src/transformation/transforms.py @@ -5,20 +5,15 @@ # SPDX-License-Identifier: MIT """Image transformation utilities.""" +# Standard library imports import io -from typing import Optional -import numpy as np -from PIL import Image, ImageFilter - from typing import Optional, Tuple -import numpy as np -from PIL import Image -from typing import Optional, Tuple -import numpy as np -from PIL import Image +# Third-party imports +import numpy as np # pylint: disable=import-error +from PIL import Image, ImageFilter # pylint: disable=import-error -class ResolutionReductionTransform: +class ResolutionReductionTransform: # pylint: disable=too-few-public-methods """Reduce image resolution by factor or target resolution.""" def __init__( @@ -99,7 +94,7 @@ def __call__(self, img: Image.Image) -> Image.Image: return img.filter(ImageFilter.GaussianBlur(radius=radius)) -class ColorQuantizationTransform: +class ColorQuantizationTransform: # pylint: disable=too-few-public-methods """Reduce color palette of images.""" def __init__(self, n_colors: Optional[int] = None): diff --git a/src/utils/callbacks_hf.py b/src/utils/callbacks_hf.py index ea8b3ba..928209b 100644 --- a/src/utils/callbacks_hf.py +++ b/src/utils/callbacks_hf.py @@ -4,12 +4,14 @@ # # SPDX-License-Identifier: MIT +# pylint: disable=all + # src/utils/callbacks_hf.py # -*- coding: utf-8 -*- from __future__ import annotations import os import json -from typing import Optional, Dict, Any +from typing import Dict, Any from transformers import TrainerCallback # type: ignore from src.utils.training_utils import get_gpu_memory diff --git a/src/utils/constants.py b/src/utils/constants.py index a10cb33..29e2172 100644 --- a/src/utils/constants.py +++ b/src/utils/constants.py @@ -21,5 +21,3 @@ NUM_CLASSES = 1000 # update dynamically per dataset if needed NUM_FILTERED_CLASSES = 8 # for ISIC filtered subset example - - diff --git a/src/utils/logging.py b/src/utils/logging_core.py similarity index 51% rename from src/utils/logging.py rename to src/utils/logging_core.py index 74c71da..5fe232c 100644 --- a/src/utils/logging.py +++ b/src/utils/logging_core.py @@ -4,7 +4,15 @@ # # SPDX-License-Identifier: MIT -# src/utils/logging.py +""" +Logging and metrics utilities for machine learning experiments. + +Provides functionality for: +- Standard Python logging setup +- Metric averaging for tracking training stats +- Optional Weights & Biases integration +""" + # -*- coding: utf-8 -*- from __future__ import annotations import logging @@ -14,24 +22,25 @@ from typing import Dict, Any, Optional # Optional: Weights & Biases -_WANDB_AVAILABLE = False +wandb_available = False # Module-level flag for wandb availability try: - import wandb # type: ignore - _WANDB_AVAILABLE = True -except Exception: - _WANDB_AVAILABLE = False + import wandb # type: ignore # pylint: disable=import-error + wandb_available = True +except ImportError: + # wandb is an optional dependency + wandb_available = False -def setup_logging(level: int = logging.INFO) -> None: +def setup_logging(level: int = logging.INFO) -> None: # pylint: disable=no-member """Configure root logger format/level.""" fmt = "[%(asctime)s] %(levelname)s - %(name)s: %(message)s" datefmt = "%H:%M:%S" - logging.basicConfig(level=level, format=fmt, datefmt=datefmt) + logging.basicConfig(level=level, format=fmt, datefmt=datefmt) # pylint: disable=no-member -def get_logger(name: str) -> logging.Logger: +def get_logger(name: str) -> logging.Logger: # pylint: disable=no-member """Get a module-specific logger.""" - return logging.getLogger(name) + return logging.getLogger(name) # pylint: disable=no-member class MetricAverager: @@ -47,14 +56,27 @@ def __init__(self) -> None: self.counts: Dict[str, int] = {} def update(self, **kwargs: float) -> None: + """ + Update metrics with new values. + + Args: + **kwargs: Keyword arguments of metric names and values + """ for k, v in kwargs.items(): self.totals[k] = self.totals.get(k, 0.0) + float(v) self.counts[k] = self.counts.get(k, 0) + 1 def averages(self) -> Dict[str, float]: - return {k: (self.totals[k] / max(self.counts[k], 1)) for k in self.totals} + """ + Get current averages for all tracked metrics. + + Returns: + Dictionary mapping metric names to their averages + """ + return {k: (v / max(self.counts[k], 1)) for k, v in self.totals.items()} def reset(self) -> None: + """Clear all tracked metrics.""" self.totals.clear() self.counts.clear() @@ -63,6 +85,7 @@ class WandbLogger: """ Thin W&B wrapper that is safe when wandb is not installed. """ + # pylint: disable=too-many-arguments,too-many-positional-arguments def __init__( self, project: str, @@ -72,7 +95,18 @@ def __init__( entity: Optional[str] = None, tags: Optional[list[str]] = None, ) -> None: - self.enabled = (_WANDB_AVAILABLE if enabled is None else enabled) + """ + Initialize the W&B logger. + + Args: + project: W&B project name + run_name: Optional name for this run + config: Configuration to log (dict, OmegaConf or compatible) + enabled: Whether to enable logging (defaults to wandb_available) + entity: Optional W&B team/entity name + tags: Optional list of tags for the run + """ + self.enabled = (wandb_available if enabled is None else enabled) self.run = None if self.enabled: self.run = wandb.init( @@ -85,36 +119,77 @@ def __init__( ) def log(self, metrics: Dict[str, Any], step: Optional[int] = None, commit: bool = True) -> None: + """ + Log metrics to W&B. + + Args: + metrics: Dictionary of metrics to log + step: Optional step/iteration number + commit: Whether to immediately commit to W&B + """ if self.enabled and self.run is not None: wandb.log(metrics, step=step, commit=commit) def watch_model(self, model, log: str = "gradients", log_freq: int = 100) -> None: + """ + Watch a model's parameters and gradients. + + Args: + model: PyTorch model to watch + log: What to log ('gradients', 'parameters', 'all', or None) + log_freq: How frequently to log + """ if self.enabled and self.run is not None: wandb.watch(model, log=log, log_freq=log_freq) def save_artifact(self, path: str, name: Optional[str] = None, type_: str = "file") -> None: + """ + Save a file as a W&B artifact. + + Args: + path: Path to the file to save + name: Optional artifact name (defaults to filename) + type_: Artifact type + """ if self.enabled and self.run is not None and os.path.exists(path): art = wandb.Artifact(name or Path(path).name, type=type_) art.add_file(path) self.run.log_artifact(art) def finish(self) -> None: + """Mark the W&B run as complete.""" if self.enabled and self.run is not None: self.run.finish() def _maybe_serialize_config(cfg: Any) -> Dict[str, Any]: - """Best-effort: convert config objects to plain dicts.""" + """ + Best-effort: convert config objects to plain dicts. + + Args: + cfg: Configuration object (dict, OmegaConf, or other) + + Returns: + JSON-serializable dictionary + """ + # If it's already a dict, return it + if isinstance(cfg, dict): + return cfg + + # Try OmegaConf conversion try: + # pylint: disable=import-error,import-outside-toplevel from omegaconf import OmegaConf # type: ignore - if isinstance(cfg, dict): - return cfg if OmegaConf.is_config(cfg): return OmegaConf.to_container(cfg, resolve=True) # type: ignore - except Exception: + except ImportError: + # OmegaConf not available pass + + # Try direct JSON serialization try: json.dumps(cfg) # type: ignore return cfg # type: ignore - except Exception: + except (TypeError, ValueError): + # Not JSON serializable return {} diff --git a/src/utils/optim.py b/src/utils/optim.py index e10bebb..a563006 100644 --- a/src/utils/optim.py +++ b/src/utils/optim.py @@ -4,15 +4,23 @@ # # SPDX-License-Identifier: MIT +""" +Optimization utilities for PyTorch models. + +This module provides functionality for: +- Building optimizers with weight decay parameter groups +- Creating learning rate schedulers with various policies +- Handling scheduler updates during training +""" # src/utils/optim.py # -*- coding: utf-8 -*- from __future__ import annotations from typing import Iterable, Tuple, Dict, Any import math -import torch -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler, LambdaLR, StepLR, ReduceLROnPlateau +import torch # pylint: disable=import-error +from torch.optim import Optimizer # pylint: disable=import-error +from torch.optim.lr_scheduler import _LRScheduler, LambdaLR, StepLR, ReduceLROnPlateau # pylint: disable=import-error def _param_groups_decay(model_or_params: Iterable, weight_decay: float) -> list: @@ -42,12 +50,16 @@ def _param_groups_decay(model_or_params: Iterable, weight_decay: float) -> list: def _build_optimizer(cfg, params) -> Optimizer: - opt_cfg = getattr(cfg, "train").get("optimizer", {}) if hasattr(cfg, "train") else {} + opt_cfg = ( + getattr(cfg, "train").get("optimizer", {}) if hasattr(cfg, "train") else {} + ) name = str(opt_cfg.get("name", "adamw")).lower() lr = float(opt_cfg.get("lr", 1e-4)) wd = float(opt_cfg.get("weight_decay", 0.05)) - if isinstance(params, torch.nn.Module) or (hasattr(params, "__iter__") and hasattr(next(iter(params)), "ndim")): + if isinstance(params, torch.nn.Module) or ( + hasattr(params, "__iter__") and hasattr(next(iter(params)), "ndim") + ): param_groups = _param_groups_decay(params, wd) else: # already groups @@ -57,12 +69,14 @@ def _build_optimizer(cfg, params) -> Optimizer: betas = tuple(opt_cfg.get("betas", (0.9, 0.999))) eps = float(opt_cfg.get("eps", 1e-8)) return torch.optim.AdamW(param_groups, lr=lr, betas=betas, eps=eps) - elif name == "sgd": + if name == "sgd": momentum = float(opt_cfg.get("momentum", 0.9)) nesterov = bool(opt_cfg.get("nesterov", True)) - return torch.optim.SGD(param_groups, lr=lr, momentum=momentum, nesterov=nesterov) - else: - raise ValueError(f"Unsupported optimizer: {name}") + return torch.optim.SGD( + param_groups, lr=lr, momentum=momentum, nesterov=nesterov + ) + # If we reach this point, the optimizer is not supported + raise ValueError(f"Unsupported optimizer: {name}") def _warmup_cosine_lambda_fn(epochs: int, warmup_epochs: int, min_lr_ratio: float): @@ -70,45 +84,67 @@ def lr_lambda(current_epoch: int): if current_epoch < warmup_epochs: return float(current_epoch + 1) / float(max(1, warmup_epochs)) # cosine from 1.0 -> min_lr_ratio - progress = (current_epoch - warmup_epochs) / float(max(1, epochs - warmup_epochs)) + progress = (current_epoch - warmup_epochs) / float( + max(1, epochs - warmup_epochs) + ) cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) return min_lr_ratio + (1.0 - min_lr_ratio) * cosine + return lr_lambda -def _build_scheduler(cfg, optimizer: Optimizer) -> Tuple[_LRScheduler | None, Dict[str, Any]]: - sch_cfg = getattr(cfg, "train").get("scheduler", {}) if hasattr(cfg, "train") else {} +def _build_scheduler( + cfg, optimizer: Optimizer +) -> Tuple[_LRScheduler | None, Dict[str, Any]]: + sch_cfg = ( + getattr(cfg, "train").get("scheduler", {}) if hasattr(cfg, "train") else {} + ) name = str(sch_cfg.get("name", "cosine")).lower() epochs = int(sch_cfg.get("epochs", getattr(cfg.train, "epochs", 50))) warmup_epochs = int(sch_cfg.get("warmup_epochs", 0)) if name in ("cosine", "cosineanneal", "cosine_anneal"): min_lr = float(sch_cfg.get("min_lr", 1e-6)) - base_lr = float(getattr(cfg.train.get("optimizer", {}), "lr", 1e-4) if hasattr(cfg, "train") else 1e-4) + base_lr = float( + getattr(cfg.train.get("optimizer", {}), "lr", 1e-4) + if hasattr(cfg, "train") + else 1e-4 + ) min_lr_ratio = max(min_lr / max(base_lr, 1e-12), 0.0) lr_lambda = _warmup_cosine_lambda_fn(epochs, warmup_epochs, min_lr_ratio) scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) return scheduler, {"by": "epoch"} - elif name == "step": + if name == "step": step_size = int(sch_cfg.get("step_size", 30)) gamma = float(sch_cfg.get("gamma", 0.1)) scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma) return scheduler, {"by": "epoch"} - elif name in ("plateau", "reduceonplateau"): + if name in ("plateau", "reduceonplateau"): patience = int(sch_cfg.get("patience", 5)) factor = float(sch_cfg.get("factor", 0.5)) scheduler = ReduceLROnPlateau(optimizer, patience=patience, factor=factor) return scheduler, {"by": "val_metric"} - elif name in ("none", "off"): + if name in ("none", "off"): return None, {"by": "none"} - else: - raise ValueError(f"Unsupported scheduler: {name}") + # If we reach this point, the scheduler is not supported + raise ValueError(f"Unsupported scheduler: {name}") -def step_scheduler(scheduler, meta: Dict[str, Any], epoch: int, val_metric: float | None = None): +def step_scheduler( + scheduler, + meta: Dict[str, Any], + epoch: int = None, # pylint: disable=unused-argument + val_metric: float | None = None +): """ Step the scheduler depending on configuration (by epoch or by val metric). meta["by"] is returned from _build_scheduler. + + Args: + scheduler: The scheduler to step + meta: Metadata dictionary with scheduling policy + epoch: Current epoch (not used but kept for API compatibility) + val_metric: Validation metric for ReduceLROnPlateau schedulers """ if scheduler is None: return @@ -119,7 +155,9 @@ def step_scheduler(scheduler, meta: Dict[str, Any], epoch: int, val_metric: floa scheduler.step(val_metric) -def make_optimizer_and_scheduler(cfg, params) -> Tuple[Optimizer, Tuple[_LRScheduler | None, Dict[str, Any]]]: +def make_optimizer_and_scheduler( + cfg, params +) -> Tuple[Optimizer, Tuple[_LRScheduler | None, Dict[str, Any]]]: """ Factory: returns (optimizer, (scheduler, meta)). Use step_scheduler(scheduler, meta, epoch, val_metric) to step it. diff --git a/src/utils/training_utils.py b/src/utils/training_utils.py index 2488e4a..2420d2f 100644 --- a/src/utils/training_utils.py +++ b/src/utils/training_utils.py @@ -4,6 +4,8 @@ # # SPDX-License-Identifier: MIT +# pylint: disable=all + # src/utils/training_utils.py # -*- coding: utf-8 -*- from __future__ import annotations diff --git a/src/utils/utils.py b/src/utils/utils.py index 6f15629..9aad2b4 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -4,6 +4,7 @@ # # SPDX-License-Identifier: MIT +# pylint: disable=all """General utilities for environment, GPU, and I/O operations.""" import os import json diff --git a/src/wrappers/finetune.py b/src/wrappers/finetune.py index 57c1054..72ebf15 100644 --- a/src/wrappers/finetune.py +++ b/src/wrappers/finetune.py @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader # pylint: disable=import-error -from src.utils.logging import setup_logging, get_logger, WandbLogger +from src.utils.logging_core import setup_logging, get_logger, WandbLogger from src.utils.optim import make_optimizer_and_scheduler from src.losses.classification import cross_entropy_loss from src.models.factory import create_model, create_preprocessor, save_model diff --git a/src/wrappers/probe.py b/src/wrappers/probe.py index d5ea342..3c872c8 100644 --- a/src/wrappers/probe.py +++ b/src/wrappers/probe.py @@ -15,7 +15,7 @@ from torch.utils.data import DataLoader # pylint: disable=import-error -from src.utils.logging import setup_logging, get_logger, WandbLogger +from src.utils.logging_core import setup_logging, get_logger, WandbLogger from src.utils.optim import make_optimizer_and_scheduler from src.losses.classification import cross_entropy_loss from src.models.factory import (