From 7b150f9bfb2177acfcb592fbdde3a1c0a262ee81 Mon Sep 17 00:00:00 2001 From: Vadim Shubin Date: Sun, 19 Oct 2025 15:01:21 +0300 Subject: [PATCH] add lora adapter --- community/methods/LoRA/README.md | 133 +++++++++++ community/methods/LoRA/run.sh | 123 +++++++++++ configs/experiment/finetune/muse/lora.yaml | 43 ++++ configs/experiment/finetune/tofu/lora.yaml | 42 ++++ configs/experiment/unlearn/muse/lora.yaml | 60 +++++ configs/experiment/unlearn/tofu/lora.yaml | 62 ++++++ configs/experiment/unlearn/wmdp/lora.yaml | 66 ++++++ configs/model/Llama-2-7b-chat-hf-lora.yaml | 15 ++ configs/model/Llama-2-7b-hf-lora.yaml | 15 ++ configs/model/Qwen2.5-3B-Instruct-lora.yaml | 23 ++ docs/components.md | 27 ++- docs/links.md | 1 + requirements.txt | 1 + src/model/__init__.py | 10 + src/model/lora.py | 232 ++++++++++++++++++++ 15 files changed, 852 insertions(+), 1 deletion(-) create mode 100644 community/methods/LoRA/README.md create mode 100755 community/methods/LoRA/run.sh create mode 100644 configs/experiment/finetune/muse/lora.yaml create mode 100644 configs/experiment/finetune/tofu/lora.yaml create mode 100644 configs/experiment/unlearn/muse/lora.yaml create mode 100644 configs/experiment/unlearn/tofu/lora.yaml create mode 100644 configs/experiment/unlearn/wmdp/lora.yaml create mode 100644 configs/model/Llama-2-7b-chat-hf-lora.yaml create mode 100644 configs/model/Llama-2-7b-hf-lora.yaml create mode 100644 configs/model/Qwen2.5-3B-Instruct-lora.yaml create mode 100644 src/model/lora.py diff --git a/community/methods/LoRA/README.md b/community/methods/LoRA/README.md new file mode 100644 index 0000000..07b31a4 --- /dev/null +++ b/community/methods/LoRA/README.md @@ -0,0 +1,133 @@ +# LoRA (Low-Rank Adaptation) Integration + +## Overview + +This directory contains the implementation of LoRA (Low-Rank Adaptation) integration for the Open-Unlearning project. LoRA allows for efficient fine-tuning and unlearning by adding trainable low-rank matrices to the model while keeping the original parameters frozen. + +## Method Details + +### What is LoRA? + +LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique that: +- Adds trainable low-rank matrices to existing model layers +- Keeps original model parameters frozen during training +- Significantly reduces memory usage and training time +- Maintains performance comparable to full fine-tuning + +### Technical Implementation + +The LoRA integration includes: + +1. **LoRA Model Wrapper** (`src/model/lora.py`) + - `LoRAModelForCausalLM` class for loading models with LoRA adapters + - Support for custom LoRA parameters (rank, alpha, dropout, target modules) + - Automatic device placement with `device_map: "auto"` + +2. **Model Integration** (`src/model/__init__.py`) + - Added LoRA support to the main `get_model()` function + - Automatic detection of `use_lora: true` in configurations + - Registration of `LoRAModelForCausalLM` in the model registry + +3. 
**Configuration Files**
+   - Model configurations with LoRA parameters
+   - Experiment configurations for fine-tuning and unlearning
+   - Automatic device placement configuration
+
+### LoRA Parameters
+
+Default LoRA configuration:
+```yaml
+lora_config:
+  target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj", "lm_head"]
+  lora_alpha: 128
+  lora_dropout: 0.05
+  r: 128
+  bias: "none"
+  task_type: "CAUSAL_LM"
+```
+
+### Supported Models
+
+- `Qwen2.5-3B-Instruct-lora`
+- `Llama-2-7b-hf-lora`
+- `Llama-2-7b-chat-hf-lora`
+
+## Hyperparameters and Strategy
+
+### Fine-tuning with LoRA
+- **Learning Rate**: `2e-4` (higher than the standard `1e-5`)
+- **Training Epochs**: `3` (fewer than the standard `5-10`)
+- **Warmup**: `0.1` epochs (shorter than the standard `1.0`)
+- **Batch Size**: `4` with gradient accumulation steps `4`
+
+### Unlearning with LoRA
+- **Learning Rate**: `1e-4` (higher than the standard `1e-5`)
+- **Training Epochs**: `5` (fewer than the standard `10`)
+- **Warmup**: `0.1` epochs (shorter than the standard `1.0`)
+- **Batch Size**: `4` with gradient accumulation steps `4` (TOFU; the MUSE and WMDP configs use their own batch settings)
+
+### Strategy for Selecting the Best Model
+
+Model selection with LoRA is guided by the following considerations:
+
+1. **Memory Efficiency**: LoRA trains only ~1% of model parameters
+2. **Faster Convergence**: Higher learning rates work well with LoRA
+3. **Modularity**: Easy to switch between different LoRA configurations
+4. **Device Optimization**: Automatic device placement for optimal GPU/CPU usage
+
+## Benefits
+
+1. **Memory Efficiency**: Only a small number of parameters are trained (typically <1% of the original model)
+2. **Faster Training**: Reduced computational requirements
+3. **Modularity**: Easy to switch between different LoRA adapters
+4. **Storage**: Smaller checkpoint sizes
+5. **No Authentication Required for Open Models**: The Qwen configuration works without HuggingFace tokens; the gated Llama-2 models still require approved access
+6. **Automatic Device Placement**: Uses `device_map: "auto"` for optimal performance
+
+## Usage
+
+### Fine-tuning with LoRA
+```bash
+# TOFU dataset
+python src/train.py --config-name=train @experiment=finetune/tofu/lora
+
+# MUSE dataset
+python src/train.py --config-name=train @experiment=finetune/muse/lora
+```
+
+### Unlearning with LoRA
+```bash
+# TOFU dataset
+python src/train.py --config-name=unlearn @experiment=unlearn/tofu/lora
+
+# MUSE dataset
+python src/train.py --config-name=unlearn @experiment=unlearn/muse/lora
+
+# WMDP dataset
+python src/train.py --config-name=unlearn @experiment=unlearn/wmdp/lora
+```
+
+### Custom Model Selection
+```bash
+python src/train.py --config-name=train @experiment=finetune/tofu/lora model=Llama-2-7b-hf-lora
+```
+
+## Dependencies
+
+- `peft==0.17.1` - Parameter-Efficient Fine-Tuning library
+- Standard HuggingFace ecosystem (transformers, torch, etc.)
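+
+The snippet below is a minimal sketch of what the wrapper in `src/model/lora.py` does with `peft` under the hood, using the default `lora_config` shown above (the base model name here is just an example):
+
+```python
+from transformers import AutoModelForCausalLM
+from peft import LoraConfig, TaskType, get_peft_model
+
+# Load the frozen base model
+base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
+
+# Mirror the default lora_config shipped with the model configs
+peft_config = LoraConfig(
+    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
+                    "gate_proj", "down_proj", "up_proj", "lm_head"],
+    lora_alpha=128,
+    lora_dropout=0.05,
+    r=128,
+    bias="none",
+    task_type=TaskType.CAUSAL_LM,
+)
+
+# Only the low-rank adapter weights remain trainable
+model = get_peft_model(base_model, peft_config)
+model.print_trainable_parameters()
+```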
+ +## Environment Variables + +- `HF_HOME`: Cache directory for HuggingFace models (optional) + +## References + +- [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) +- [Parameter-Efficient Fine-Tuning (PEFT) Library](https://github.com/huggingface/peft) + +## Implementation Notes + +- LoRA adapters are applied to attention layers (`q_proj`, `v_proj`, `k_proj`, `o_proj`) and MLP layers (`gate_proj`, `down_proj`, `up_proj`) +- The `lm_head` layer is also adapted for better performance +- Default rank `r=128` provides a good balance between performance and efficiency +- `lora_alpha=128` scales the LoRA contributions appropriately +- `device_map: "auto"` automatically places model layers across available devices diff --git a/community/methods/LoRA/run.sh b/community/methods/LoRA/run.sh new file mode 100755 index 0000000..24f7662 --- /dev/null +++ b/community/methods/LoRA/run.sh @@ -0,0 +1,123 @@ +#!/bin/bash + +# LoRA Integration Experiments for Open-Unlearning +# This script demonstrates how to run fine-tuning and unlearning experiments with LoRA + +set -e + +echo "🚀 Starting LoRA Integration Experiments" +echo "========================================" + +# Set default values +MODEL="Qwen2.5-3B-Instruct-lora" +EXPERIMENT_TYPE="finetune" +DATASET="tofu" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --model) + MODEL="$2" + shift 2 + ;; + --type) + EXPERIMENT_TYPE="$2" + shift 2 + ;; + --dataset) + DATASET="$2" + shift 2 + ;; + --help) + echo "Usage: $0 [--model MODEL] [--type TYPE] [--dataset DATASET]" + echo "" + echo "Options:" + echo " --model MODEL LoRA model to use (default: Qwen2.5-3B-Instruct-lora)" + echo " Available: Qwen2.5-3B-Instruct-lora, Llama-2-7b-hf-lora, Llama-2-7b-chat-hf-lora" + echo " --type TYPE Experiment type: finetune or unlearn (default: finetune)" + echo " --dataset DATASET Dataset to use: tofu, muse, or wmdp (default: tofu)" + echo " --help Show this help message" + echo "" + echo "Examples:" + echo " $0 # Fine-tune Qwen2.5-3B-Instruct with LoRA on TOFU" + echo " $0 --type unlearn # Unlearn with Qwen2.5-3B-Instruct LoRA on TOFU" + echo " $0 --model Llama-2-7b-hf-lora # Fine-tune Llama-2-7b-hf with LoRA" + echo " $0 --dataset muse --type unlearn # Unlearn with Qwen2.5-3B-Instruct LoRA on MUSE" + echo " $0 --dataset wmdp --type unlearn # Unlearn with Qwen2.5-3B-Instruct LoRA on WMDP" + exit 0 + ;; + *) + echo "Unknown option: $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + +echo "Configuration:" +echo " Model: $MODEL" +echo " Type: $EXPERIMENT_TYPE" +echo " Dataset: $DATASET" +echo "" + +# Validate inputs +if [[ "$EXPERIMENT_TYPE" != "finetune" && "$EXPERIMENT_TYPE" != "unlearn" ]]; then + echo "❌ Error: Experiment type must be 'finetune' or 'unlearn'" + exit 1 +fi + +if [[ "$DATASET" != "tofu" && "$DATASET" != "muse" && "$DATASET" != "wmdp" ]]; then + echo "❌ Error: Dataset must be 'tofu', 'muse', or 'wmdp'" + exit 1 +fi + +# Check if model configuration exists +MODEL_CONFIG="configs/model/${MODEL}.yaml" +if [[ ! 
-f "$MODEL_CONFIG" ]]; then + echo "❌ Error: Model configuration not found: $MODEL_CONFIG" + echo "Available LoRA models:" + ls configs/model/*-lora.yaml 2>/dev/null | sed 's/configs\/model\///g' | sed 's/\.yaml//g' | sed 's/^/ - /' + exit 1 +fi + +# Check if experiment configuration exists +if [[ "$DATASET" == "wmdp" && "$EXPERIMENT_TYPE" == "finetune" ]]; then + echo "❌ Error: WMDP dataset only supports unlearning, not fine-tuning" + echo "Use --type unlearn for WMDP dataset" + exit 1 +fi + +EXPERIMENT_CONFIG="configs/experiment/${EXPERIMENT_TYPE}/${DATASET}/lora.yaml" +if [[ ! -f "$EXPERIMENT_CONFIG" ]]; then + echo "❌ Error: Experiment configuration not found: $EXPERIMENT_CONFIG" + echo "Available experiment configurations:" + find configs/experiment -name "lora.yaml" | sed 's/^/ - /' + exit 1 +fi + +echo "✅ All configurations found" +echo "" + +# Set up experiment command +if [[ "$EXPERIMENT_TYPE" == "finetune" ]]; then + TRAIN_CONFIG="train" +else + TRAIN_CONFIG="unlearn" +fi + +# Build the command +CMD="python src/train.py --config-name=${TRAIN_CONFIG} @experiment=${EXPERIMENT_TYPE}/${DATASET}/lora model=${MODEL}" + +echo "Running command:" +echo " $CMD" +echo "" + +# Run the experiment +echo "🏃 Starting experiment..." +eval $CMD + +echo "" +echo "✅ Experiment completed!" +echo "" +echo "Results should be saved in the output directory." +echo "Check the logs for detailed information about the training process." diff --git a/configs/experiment/finetune/muse/lora.yaml b/configs/experiment/finetune/muse/lora.yaml new file mode 100644 index 0000000..57d4f20 --- /dev/null +++ b/configs/experiment/finetune/muse/lora.yaml @@ -0,0 +1,43 @@ +# @package _global_ + +defaults: + - override /model: Qwen2.5-3B-Instruct-lora + - override /trainer: finetune + - override /data/datasets@data.train: MUSE_train + - override /eval: muse + - override /data: finetune + +mode: finetune +data_split: News +data_sub_set: full # full or retain + +data: + train: + MUSE_train: + args: + hf_args: + path: tamarsonha/MUSE-${data_split}-Train + split: ${data_sub_set} +# you can find fine-tuned models on https://huggingface.co/tamarsonha + +trainer: + args: + learning_rate: 2e-4 # Higher learning rate for LoRA + weight_decay: 0.01 + warmup_epochs: 0.1 # Shorter warmup for LoRA + num_train_epochs: 3 # Fewer epochs for LoRA + per_device_train_batch_size: 4 + per_device_eval_batch_size: 4 + gradient_accumulation_steps: 4 + logging_steps: 10 + save_steps: 500 + eval_steps: 500 + evaluation_strategy: steps + save_strategy: steps + load_best_model_at_end: false # Disable to avoid metric issues + save_total_limit: 2 + remove_unused_columns: false + dataloader_pin_memory: false + seed: 42 + +task_name: muse_news_full_lora diff --git a/configs/experiment/finetune/tofu/lora.yaml b/configs/experiment/finetune/tofu/lora.yaml new file mode 100644 index 0000000..d839a8a --- /dev/null +++ b/configs/experiment/finetune/tofu/lora.yaml @@ -0,0 +1,42 @@ +# @package _global_ + +defaults: + - override /model: Qwen2.5-3B-Instruct-lora + - override /trainer: finetune + - override /data/datasets@data.train: TOFU_QA_full + - override /eval: tofu + +mode: finetune + +trainer: + args: + learning_rate: 2e-4 # Higher learning rate for LoRA + weight_decay: 0.01 + warmup_epochs: 0.1 + num_train_epochs: 3 + per_device_train_batch_size: 4 + per_device_eval_batch_size: 4 + gradient_accumulation_steps: 4 + logging_steps: 10 + save_steps: 500 + eval_steps: 500 + evaluation_strategy: steps + save_strategy: steps + load_best_model_at_end: false # Disable 
to avoid metric issues + save_total_limit: 2 + remove_unused_columns: false + dataloader_pin_memory: false + seed: 42 + +forget_split: forget10 +holdout_split: holdout10 +retain_logs_path: null + +eval: + tofu: + forget_split: ${forget_split} + holdout_split: ${holdout_split} + retain_logs_path: ${retain_logs_path} + overwrite: true + +task_name: tofu_Qwen2.5-3B-Instruct_lora_finetune diff --git a/configs/experiment/unlearn/muse/lora.yaml b/configs/experiment/unlearn/muse/lora.yaml new file mode 100644 index 0000000..e023ba7 --- /dev/null +++ b/configs/experiment/unlearn/muse/lora.yaml @@ -0,0 +1,60 @@ +# @package _global_ + +defaults: + - override /model: Qwen2.5-3B-Instruct-lora + - override /trainer: GradAscent + - override /data: unlearn + - override /data/datasets@data.forget: MUSE_forget + - override /data/datasets@data.retain: MUSE_retain + - override /eval: muse + +data_split: News +forget_split: forget +retain_split: retain1 +retain_logs_path: null + +model: + model_args: + pretrained_model_name_or_path: muse-bench/MUSE-${data_split}_target + +data: + anchor: forget + forget: + MUSE_forget: + args: + hf_args: + split: ${forget_split} + path: muse-bench/MUSE-${data_split} + retain: + MUSE_retain: + args: + hf_args: + path: muse-bench/MUSE-${data_split} + split: ${retain_split} + +eval: + muse: + data_split: ${data_split} + retain_logs_path: ${retain_logs_path} + overwrite: true + +trainer: + args: + per_device_train_batch_size: 4 + gradient_accumulation_steps: 8 + learning_rate: 1e-4 # Higher learning rate for LoRA + num_train_epochs: 5 # Fewer epochs for LoRA + warmup_epochs: 0.1 # Shorter warmup for LoRA + lr_scheduler_type: constant + logging_steps: 10 + save_steps: 500 + eval_steps: 500 + evaluation_strategy: steps + save_strategy: steps + load_best_model_at_end: false # Disable to avoid metric issues + save_total_limit: 2 + remove_unused_columns: false + dataloader_pin_memory: false + seed: 42 + +task_name: muse_unlearn_lora diff --git a/configs/experiment/unlearn/tofu/lora.yaml b/configs/experiment/unlearn/tofu/lora.yaml new file mode 100644 index 0000000..4bbea6b --- /dev/null +++ b/configs/experiment/unlearn/tofu/lora.yaml @@ -0,0 +1,62 @@ +# @package _global_ + +defaults: + - override /model: Qwen2.5-3B-Instruct-lora + - override /trainer: GradAscent + - override /data: unlearn + - override /data/datasets@data.forget: TOFU_QA_forget + - override /data/datasets@data.retain: TOFU_QA_retain + - override /eval: tofu + +model: + model_args: + pretrained_model_name_or_path: Qwen/Qwen2.5-3B-Instruct + +forget_split: forget10 +retain_split: retain90 +holdout_split: holdout10 +retain_logs_path: null +question_key: "question" + +eval: + tofu: + forget_split: ${forget_split} + holdout_split: ${holdout_split} + retain_logs_path: ${retain_logs_path} + overwrite: true + question_key: ${question_key} + +data: + anchor: forget + forget: + TOFU_QA_forget: + args: + hf_args: + name: ${forget_split} + retain: + TOFU_QA_retain: + args: + hf_args: + name: ${retain_split} + +trainer: + args: + warmup_epochs: 0.1 + learning_rate: 1e-4 # Higher learning rate for LoRA + weight_decay: 0.01 + num_train_epochs: 5 + per_device_train_batch_size: 4 + per_device_eval_batch_size: 4 + gradient_accumulation_steps: 4 + logging_steps: 10 + save_steps: 500 + eval_steps: 500 + evaluation_strategy: steps + save_strategy: steps + load_best_model_at_end: false # Disable to avoid metric issues + save_total_limit: 2 + remove_unused_columns: false + dataloader_pin_memory: false + seed: 42 + +task_name: 
unlearn_tofu_lora
diff --git a/configs/experiment/unlearn/wmdp/lora.yaml b/configs/experiment/unlearn/wmdp/lora.yaml
new file mode 100644
index 0000000..d52d706
--- /dev/null
+++ b/configs/experiment/unlearn/wmdp/lora.yaml
@@ -0,0 +1,66 @@
+# @package _global_
+
+defaults:
+  - override /model: Qwen2.5-3B-Instruct-lora
+  - override /trainer: RMU
+  - override /data: unlearn
+  - override /data/datasets@data.forget: WMDP_forget
+  - override /data/datasets@data.retain: WMDP_retain
+  - override /eval: lm_eval
+
+data_split: cyber
+
+data:
+  anchor: forget
+  forget:
+    WMDP_forget:
+      args:
+        hf_args:
+          data_files: data/wmdp/wmdp-corpora/${data_split}-forget-corpus.jsonl
+  retain:
+    WMDP_retain:
+      args:
+        hf_args:
+          data_files: data/wmdp/wmdp-corpora/${data_split}-retain-corpus.jsonl
+
+eval:
+  lm_eval:
+    tasks:
+      - wmdp_${data_split}
+      - mmlu
+
+collator:
+  DataCollatorForSupervisedDataset:
+    args:
+      padding_side: left # Usually left, but for Mistral and Zephyr it's right (https://github.com/hongshi97/CAD/issues/2)
+
+trainer:
+  args:
+    per_device_train_batch_size: 1
+    gradient_accumulation_steps: 16
+    learning_rate: 1e-4 # Higher learning rate for LoRA
+    eval_strategy: steps
+    eval_steps: 0.5
+    max_steps: 80
+    lr_scheduler_type: constant
+    warmup_epochs: 0.1 # Shorter warmup for LoRA
+    logging_steps: 10
+    save_steps: 500
+    save_strategy: steps
+    load_best_model_at_end: false # Disable to avoid metric issues
+    save_total_limit: 2
+    remove_unused_columns: false
+    dataloader_pin_memory: false
+    seed: 42
+
+  method_args:
+    # These params depend heavily on the model and dataset; tune them carefully
+    gamma: 1.0
+    steering_coeff: 2
+    retain_loss_type: EMBED_DIFF
+    alpha: 1
+    module_regex: model\.layers\.7
+    trainable_params_regex:
+      - model\.layers\.(5|6|7)\.mlp\.down_proj\.weight # To update only these weights (as done in https://github.com/centerforaisafety/wmdp/blob/bc5e1ba0367ea826caeeeaa50656336a1e87acfb/rmu/unlearn.py#L26)
+
+task_name: wmdp_unlearn_lora
diff --git a/configs/model/Llama-2-7b-chat-hf-lora.yaml b/configs/model/Llama-2-7b-chat-hf-lora.yaml
new file mode 100644
index 0000000..e3d5b55
--- /dev/null
+++ b/configs/model/Llama-2-7b-chat-hf-lora.yaml
@@ -0,0 +1,15 @@
+use_lora: true
+model_args:
+  pretrained_model_name_or_path: "meta-llama/Llama-2-7b-chat-hf"
+  attn_implementation: 'eager'
+  torch_dtype: bfloat16
+  device_map: "auto"
+tokenizer_args:
+  pretrained_model_name_or_path: "meta-llama/Llama-2-7b-chat-hf"
+lora_config:
+  target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj", "lm_head"]
+  lora_alpha: 128
+  lora_dropout: 0.05
+  r: 128
+  bias: "none"
+  task_type: "CAUSAL_LM"
diff --git a/configs/model/Llama-2-7b-hf-lora.yaml b/configs/model/Llama-2-7b-hf-lora.yaml
new file mode 100644
index 0000000..00be939
--- /dev/null
+++ b/configs/model/Llama-2-7b-hf-lora.yaml
@@ -0,0 +1,15 @@
+use_lora: true
+model_args:
+  pretrained_model_name_or_path: "meta-llama/Llama-2-7b-hf"
+  attn_implementation: 'eager'
+  torch_dtype: bfloat16
+  device_map: "auto"
+tokenizer_args:
+  pretrained_model_name_or_path: "meta-llama/Llama-2-7b-hf"
+lora_config:
+  target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj", "lm_head"]
+  lora_alpha: 128
+  lora_dropout: 0.05
+  r: 128
+  bias: "none"
+  task_type: "CAUSAL_LM"
diff --git a/configs/model/Qwen2.5-3B-Instruct-lora.yaml b/configs/model/Qwen2.5-3B-Instruct-lora.yaml
new file mode 100644
index 0000000..0e539aa
--- /dev/null
+++
b/configs/model/Qwen2.5-3B-Instruct-lora.yaml @@ -0,0 +1,23 @@ +use_lora: true +model_args: + pretrained_model_name_or_path: "Qwen/Qwen2.5-3B-Instruct" + attn_implementation: 'eager' + torch_dtype: bfloat16 + device_map: "auto" +tokenizer_args: + pretrained_model_name_or_path: "Qwen/Qwen2.5-3B-Instruct" +lora_config: + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj", "lm_head"] + lora_alpha: 128 + lora_dropout: 0.05 + r: 128 + bias: "none" + task_type: "CAUSAL_LM" +template_args: + apply_chat_template: true + system_prompt: "You are a helpful assistant." + system_prompt_with_special_tokens: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + user_start_tag: "<|im_start|>user\n" + user_end_tag: "<|im_end|>\n" + asst_start_tag: "<|im_start|>assistant\n" + asst_end_tag: "<|im_end|>\n" diff --git a/docs/components.md b/docs/components.md index a0ba85a..e2d5fc7 100644 --- a/docs/components.md +++ b/docs/components.md @@ -173,7 +173,9 @@ _register_model(ProbedLlamaForCausalLM) ``` > [!NOTE] -Currently, we do not support loading models modified with LoRA and related variants. If you wish use such features, please create define and register model handlers for this logic in [`src/model`](../src/model) and provide the config info as discussed next. +~~Currently, we do not support loading models modified with LoRA and related variants. If you wish use such features, please create define and register model handlers for this logic in [`src/model`](../src/model) and provide the config info as discussed next.~~ + +**LoRA Support Added**: We now support LoRA (Low-Rank Adaptation) for efficient fine-tuning and unlearning. See [`src/model/lora.py`](../src/model/lora.py) for the implementation and [`community/methods/LoRA/`](../community/methods/LoRA/) for usage examples. ### Add to configs Model configurations contain details required to load the model+tokenizer such as paths, chat templating arguments, LoRA parameters etc. in [`configs/models`](../configs/models/). @@ -193,6 +195,29 @@ template_args: system_prompt: You are a helpful assistant. ``` +Example: LoRA model config in [`configs/model/Qwen2.5-3B-Instruct-lora.yaml`](../configs/model/Qwen2.5-3B-Instruct-lora.yaml). + +```yaml +use_lora: true +model_args: + pretrained_model_name_or_path: "Qwen/Qwen2.5-3B-Instruct" + attn_implementation: 'eager' + torch_dtype: bfloat16 + device_map: "auto" +tokenizer_args: + pretrained_model_name_or_path: "Qwen/Qwen2.5-3B-Instruct" +lora_config: + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj", "lm_head"] + lora_alpha: 128 + lora_dropout: 0.05 + r: 128 + bias: "none" + task_type: "CAUSAL_LM" +template_args: + apply_chat_template: true + system_prompt: "You are a helpful assistant." 
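+  # The remaining chat-template tags (system/user/assistant start and end
+  # markers) are listed in the full config file linked above.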
+``` + --- ## Collator diff --git a/docs/links.md b/docs/links.md index 02454f7..763a5d2 100644 --- a/docs/links.md +++ b/docs/links.md @@ -38,6 +38,7 @@ Links to research papers and resources corresponding to implemented features in | WGA (G-effect) | Paper[📄](https://arxiv.org/pdf/2502.19301), Code [🐙](https://github.com/tmlr-group/G-effect) | | CE-U (Cross-Entropy unlearning) | Paper[📄](https://arxiv.org/pdf/2503.01224) | | PDU | Paper [📄](https://arxiv.org/abs/2506.05314) | +| LoRA | Paper[📄](https://arxiv.org/abs/2106.09685), Code [🐙](https://github.com/huggingface/peft) | --- diff --git a/requirements.txt b/requirements.txt index 2f39c76..1e24a93 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ scipy==1.14.1 tensorboard==2.18.0 scikit-learn==1.5.2 deepspeed==0.15.4 +peft==0.17.1 diff --git a/src/model/__init__.py b/src/model/__init__.py index 05add68..50cc79c 100644 --- a/src/model/__init__.py +++ b/src/model/__init__.py @@ -5,6 +5,7 @@ import torch import logging from model.probe import ProbedLlamaForCausalLM +from model.lora import LoRAModelForCausalLM, get_lora_model hf_home = os.getenv("HF_HOME", default=None) @@ -42,6 +43,13 @@ def get_model(model_cfg: DictConfig): assert model_cfg is not None and model_cfg.model_args is not None, ValueError( "Model config not found or model_args absent in configs/model." ) + + # Check if LoRA is enabled + use_lora = model_cfg.get("use_lora", False) + if use_lora: + return get_lora_model(model_cfg) + + # Original model loading logic model_args = model_cfg.model_args tokenizer_args = model_cfg.tokenizer_args torch_dtype = get_dtype(model_args) @@ -53,6 +61,7 @@ def get_model(model_cfg: DictConfig): model = model_cls.from_pretrained( pretrained_model_name_or_path=model_path, torch_dtype=torch_dtype, + device_map="auto", **model_args, cache_dir=hf_home, ) @@ -105,3 +114,4 @@ def get_tokenizer(tokenizer_cfg: DictConfig): # register models _register_model(AutoModelForCausalLM) _register_model(ProbedLlamaForCausalLM) +_register_model(LoRAModelForCausalLM) diff --git a/src/model/lora.py b/src/model/lora.py new file mode 100644 index 0000000..d6a2659 --- /dev/null +++ b/src/model/lora.py @@ -0,0 +1,232 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import LoraConfig, get_peft_model, TaskType +from omegaconf import DictConfig, open_dict, ListConfig +from typing import Optional +import torch +import logging +import os + +hf_home = os.getenv("HF_HOME", default=None) + +logger = logging.getLogger(__name__) + + +class LoRAModelForCausalLM: + """ + Wrapper class for loading models with LoRA adapters. + Supports the specified LoRA configuration parameters. + """ + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + lora_config: Optional[DictConfig] = None, + **kwargs, + ): + """ + Load a model with LoRA adapters. 
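+
+        If `lora_config` is None, a default configuration is used (r=128,
+        lora_alpha=128, dropout=0.05, with adapters on the attention
+        projections, the MLP projections, and lm_head).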
+
+        Args:
+            pretrained_model_name_or_path: Path to the pretrained model
+            lora_config: LoRA configuration parameters
+            **kwargs: Additional arguments for model loading
+        """
+        # Default LoRA configuration
+        default_lora_config = {
+            "target_modules": [
+                "q_proj",
+                "v_proj",
+                "k_proj",
+                "o_proj",
+                "gate_proj",
+                "down_proj",
+                "up_proj",
+                "lm_head",
+            ],
+            "lora_alpha": 128,
+            "lora_dropout": 0.05,
+            "r": 128,
+            "bias": "none",
+            "task_type": "CAUSAL_LM",
+        }
+
+        def convert_omegaconf_to_python(obj):
+            """Convert OmegaConf objects to regular Python types."""
+            if isinstance(obj, ListConfig):
+                return [convert_omegaconf_to_python(item) for item in obj]
+            elif isinstance(obj, DictConfig):
+                return {k: convert_omegaconf_to_python(v) for k, v in obj.items()}
+            elif hasattr(obj, "_content"):  # Fallback for other OmegaConf types
+                if isinstance(obj._content, list):
+                    return [convert_omegaconf_to_python(item) for item in obj._content]
+                elif isinstance(obj._content, dict):
+                    return {
+                        k: convert_omegaconf_to_python(v)
+                        for k, v in obj._content.items()
+                    }
+                else:
+                    return obj._content
+            else:
+                return obj
+
+        # Merge the provided config over the defaults so that a partial
+        # lora_config (e.g. only `r` and `lora_alpha`) falls back to default
+        # values instead of raising a KeyError below. OmegaConf objects are
+        # converted to plain Python types for JSON serialization compatibility.
+        lora_params = default_lora_config.copy()
+        if lora_config:
+            lora_params.update(convert_omegaconf_to_python(lora_config))
+
+        # Normalize value types explicitly
+        lora_params = {
+            "target_modules": list(lora_params["target_modules"]),
+            "lora_alpha": int(lora_params["lora_alpha"]),
+            "lora_dropout": float(lora_params["lora_dropout"]),
+            "r": int(lora_params["r"]),
+            "bias": str(lora_params["bias"]),
+            "task_type": str(lora_params["task_type"]),
+        }
+        logger.info(f"Resolved LoRA parameters: {lora_params}")
+
+        # Verify JSON serializability (required when the adapter config is saved)
+        import json
+
+        try:
+            json.dumps(lora_params)
+        except (TypeError, ValueError) as e:
+            raise ValueError(
+                f"LoRA parameters cannot be serialized to JSON: {e}"
+            ) from e
+
+        # Load the base model
+        logger.info(f"Loading base model from {pretrained_model_name_or_path}")
+        base_model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path, **kwargs
+        )
+
+        # Create the LoRA configuration from the resolved parameters
+        peft_config = LoraConfig(
+            target_modules=lora_params["target_modules"],
+            lora_alpha=lora_params["lora_alpha"],
+            lora_dropout=lora_params["lora_dropout"],
+            r=lora_params["r"],
+            bias=lora_params["bias"],
+            task_type=TaskType(lora_params["task_type"]),
+        )
+
+        # Apply LoRA to the model (the original weights stay frozen)
+        logger.info(f"Applying LoRA with config: {peft_config}")
+        model = get_peft_model(base_model, peft_config)
+
+        # Print trainable parameters
+        model.print_trainable_parameters()
+
+        return model
+
+
+def get_lora_model(model_cfg: DictConfig):
+    """
+    Load a model with LoRA adapters using the model configuration.
+
+    Args:
+        model_cfg: Model configuration containing model_args, tokenizer_args, and lora_config
+
+    Returns:
+        Tuple of (model, tokenizer)
+    """
+    assert model_cfg is not None and model_cfg.model_args is not None, ValueError(
+        "Model config not found or model_args absent in configs/model."
+ ) + + model_args = model_cfg.model_args + tokenizer_args = model_cfg.tokenizer_args + lora_config = model_cfg.get("lora_config", None) + + # Get torch dtype using the same logic as the main module + torch_dtype = get_dtype(model_args) + + with open_dict(model_args): + model_path = model_args.pop("pretrained_model_name_or_path", None) + + try: + # Load model with LoRA + model = LoRAModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=model_path, + lora_config=lora_config, + torch_dtype=torch_dtype, + device_map="auto", + cache_dir=hf_home, + **model_args, + ) + except Exception as e: + logger.warning(f"Model {model_path} requested with {model_cfg.model_args}") + raise ValueError( + f"Error {e} while fetching LoRA model using LoRAModelForCausalLM.from_pretrained()." + ) + + # Load tokenizer using the same logic as the main module + tokenizer = get_tokenizer(tokenizer_args) + return model, tokenizer + + +def get_dtype(model_args): + """Extract torch dtype from model arguments.""" + with open_dict(model_args): + torch_dtype_str = model_args.pop("torch_dtype", None) + + if torch_dtype_str is None: + return torch.float32 + + if torch_dtype_str == "bfloat16": + return torch.bfloat16 + elif torch_dtype_str == "float16": + return torch.float16 + elif torch_dtype_str == "float32": + return torch.float32 + + return torch.float32 + + +def get_tokenizer(tokenizer_args): + """Load tokenizer from tokenizer arguments.""" + try: + tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args, cache_dir=hf_home) + except Exception as e: + error_message = ( + f"{'--' * 40}\n" + f"Error {e} fetching tokenizer using AutoTokenizer.\n" + f"Tokenizer requested from path: {tokenizer_args.get('pretrained_model_name_or_path', None)}\n" + f"Full tokenizer config: {tokenizer_args}\n" + f"{'--' * 40}" + ) + raise RuntimeError(error_message) + + if tokenizer.eos_token_id is None: + logger.info("replacing eos_token with <|endoftext|>") + _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>") + + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + logger.info("Setting pad_token as eos token: {}".format(tokenizer.pad_token)) + + return tokenizer + + +def _add_or_replace_eos_token(tokenizer, eos_token: str) -> None: + is_added = tokenizer.eos_token_id is None + num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) + + if is_added: + logger.info("Add eos token: {}".format(tokenizer.eos_token)) + else: + logger.info("Replace eos token: {}".format(tokenizer.eos_token)) + + if num_added_tokens > 0: + logger.info("New tokens have been added, make sure `resize_vocab` is True.")