
Commit 0d35d15

dong-river, molereddy, and Dornavineeth authored
Add UNDIAL (#89)
* Fix hyperlinks in README (#2)
* testing commit
* Fixes
* cleanup
* Fixed DPO command
* download idk
* Revert "Dpo fix"
* download idk data
* fix dpo experiment config
* RMU (#6)
* IdkDPO script fix in tofu_unlearn.sh (#65)
* Fix hyperlinks in README
* Download I don't know data in setup_data.py
* Fix tofu_unlearn.sh for IdkDPO
---------
Co-authored-by: Anmol Mekala <[email protected]>
* overwrite=True
* RMU added
* Fix ref model device
* ruff fix
* RMU updated
* Update rmu.py
* Update README.md: add RMU
* Added references and renamed functions
---------
Co-authored-by: Anmol Mekala <[email protected]>
* Add structure to contributions, setup leaderboard, update documentation (#8)
* docs: updates, small corrections, re-formats
* modified ruff commands
* modified ruff commands
* CI/CD minor updates
* added contributing + leaderboard
* fix minor spelling misatkes
* docs: bunch of minor updates
* docs fixes
---------
Co-authored-by: molereddy <[email protected]>
* UNDIAL
* UNDIAL2
* UNDIAL3
* Ruff quality formatting changes
* fix config
* fix docs and script
* Update readme
---------
Co-authored-by: Anmol Mekala <[email protected]>
Co-authored-by: Dornavineeth <[email protected]>
Co-authored-by: Vineeth <[email protected]>
Co-authored-by: molereddy <[email protected]>
1 parent 679dd0d commit 0d35d15

File tree

8 files changed: +178 −2 lines


README.md

Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@

 ## 📖 Overview

-We provide efficient and streamlined implementations of the TOFU, MUSE and WMDP unlearning benchmarks while supporting 6 unlearning methods, 5+ datasets, 10+ evaluation metrics, and 7+ LLM architectures. Each of these can be easily extended to incorporate more variants.
+We provide efficient and streamlined implementations of the TOFU, MUSE and WMDP unlearning benchmarks while supporting 7 unlearning methods, 5+ datasets, 10+ evaluation metrics, and 7+ LLM architectures. Each of these can be easily extended to incorporate more variants.

 We invite the LLM unlearning community to collaborate by adding new benchmarks, unlearning methods, datasets and evaluation metrics here to expand OpenUnlearning's features, gain feedback from wider usage and drive progress in the field.

@@ -62,7 +62,7 @@ We provide several variants for each of the components in the unlearning pipelin
 | **Component** | **Available Options** |
 |------------------------|----------------------|
 | **Benchmarks** | [TOFU](https://arxiv.org/abs/2401.06121), [MUSE](https://muse-bench.github.io/), [WMDP](https://www.wmdp.ai/) |
-| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU |
+| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU, UNDIAL |
 | **Evaluation Metrics** | Verbatim Probability, Verbatim ROUGE, Knowledge QA-ROUGE, Model Utility, Forget Quality, TruthRatio, Extraction Strength, Exact Memorization, 6 MIA attacks, [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) |
 | **Datasets** | MUSE-News (BBC), MUSE-Books (Harry Potter), TOFU (different splits), WMDP-Bio, WMDP-Cyber |
 | **Model Families** | TOFU: LLaMA-3.2, LLaMA-3.1, LLaMA-2; MUSE: LLaMA-2; Additional: Phi-3.5, Phi-1.5, Gemma, Zephyr |

community/methods/UNDIAL/README.md

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
# UNDIAL: Self-Distillation with Adjusted Logits for Robust Unlearning in Large Language Models (NAACL 2025)

- Authors: Yijiang River Dong, Hongzhou Lin, Mikhail Belkin, Ramón Huerta, Ivan Vulić
- Link: https://arxiv.org/pdf/2402.10052

# Setup
- Hyperparameters: The original paper tunes Llama-2 7B with LoRA (rank=8, alpha=16) and a learning rate of 1e-4. We suggest searching the learning rate over [1e-5, 1e-4, 3e-4] and using an effective batch size of 32 (batch_size * gradient_accumulation). The other important hyperparameter is beta, the strength of the penalty on memorized tokens, typically chosen from [3, 10, 30]; a toy sketch of this adjustment follows below. When switching to other models, adjust the learning rate accordingly.
- Computation Setup: All experiments are run on one A100.
- Other Details: The original paper does not use the retain set and aims to retain knowledge across all domains, not just the retain set, so alpha is set to 0. Practitioners can search over alpha or gamma to better preserve performance on the retain set.
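To make beta's role concrete, here is a minimal, hypothetical Python sketch of the logit adjustment UNDIAL applies when building its teacher soft labels (mirroring `compute_undial_loss` added to `src/trainer/utils.py` in this commit; the vocabulary size and logit values are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Toy setup: a 3-token vocabulary where the teacher strongly favors
# the memorized target token (index 0).
beta = 10.0
teacher_logits = torch.tensor([[4.0, 1.0, 0.5]])
target = torch.tensor([0])

adjusted = teacher_logits.clone()
adjusted[0, target] -= beta  # subtract the penalty only on the target token

print(F.softmax(teacher_logits, dim=-1))  # ~[0.93, 0.05, 0.03]
print(F.softmax(adjusted, dim=-1))        # ~[0.00, 0.62, 0.38]
```

Subtracting beta from only the memorized token's logit redistributes the teacher's probability mass to the remaining tokens; the student is then distilled toward that adjusted distribution.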
# Results
Run the `run.sh` script.

# Citation
@misc{dong2024undial,
  title={UNDIAL: Self-Distillation with Adjusted Logits for Robust Unlearning in Large Language Models},
  author={Yijiang River Dong and Hongzhou Lin and Mikhail Belkin and Ramon Huerta and Ivan Vulić},
  year={2024},
  eprint={2402.10052},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2402.10052},
}

community/methods/UNDIAL/run.sh

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
#!/bin/bash

export MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
echo "Master Port: $MASTER_PORT"

########################################################################################################################
########################################### Unlearn TOFU models ########################################################
########################################################################################################################

models=(
    "Llama-3.2-1B-Instruct"
)
trainers_experiments=(
    "UNDIAL unlearn/tofu/default.yaml"
)
forget_retain_splits=(
    "forget10 retain90"
    "forget05 retain95"
    "forget01 retain99"
)

per_device_train_batch_size=16
gradient_accumulation_steps=2

lrs=(1e-5 1e-4 3e-4)
alphas=(1 2 5)
betas=(3 10 30)

for split in "${forget_retain_splits[@]}"; do
    forget_split=$(echo $split | cut -d' ' -f1)
    retain_split=$(echo $split | cut -d' ' -f2)
    for model in "${models[@]}"; do
        for trainer_experiment in "${trainers_experiments[@]}"; do
            trainer=$(echo $trainer_experiment | cut -d' ' -f1)
            experiment=$(echo $trainer_experiment | cut -d' ' -f2)
            for lr in "${lrs[@]}"; do
                for beta in "${betas[@]}"; do
                    for alpha in "${alphas[@]}"; do
                        task_name=tofu_${model}_${forget_split}_${trainer}_lr${lr}_beta${beta}_alpha${alpha}
                        model_path=open-unlearning/tofu_${model}_full
                        echo ${task_name}: Unlearning ${model_path} using ${trainer}

                        # Unlearn
                        CUDA_VISIBLE_DEVICES=0 \
                        python src/train.py --config-name=unlearn.yaml \
                            experiment=${experiment} \
                            trainer=${trainer} \
                            task_name=${task_name} \
                            model=${model} \
                            forget_split=${forget_split} \
                            retain_split=${retain_split} \
                            model.model_args.pretrained_model_name_or_path=${model_path} \
                            retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json \
                            trainer.args.per_device_train_batch_size=$per_device_train_batch_size \
                            trainer.args.gradient_accumulation_steps=$gradient_accumulation_steps \
                            trainer.args.eval_strategy=no \
                            trainer.args.eval_on_start=False \
                            trainer.args.learning_rate=$lr \
                            trainer.method_args.beta=$beta \
                            trainer.method_args.alpha=$alpha

                        # Eval
                        CUDA_VISIBLE_DEVICES=0 python src/eval.py \
                            experiment=eval/tofu/default.yaml \
                            forget_split=${forget_split} \
                            model=${model} \
                            task_name=${task_name} \
                            model.model_args.pretrained_model_name_or_path=saves/unlearn/${task_name} \
                            paths.output_dir=saves/unlearn/${task_name}/evals \
                            retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json
                    done
                done
            done
        done
    done
done

configs/trainer/UNDIAL.yaml

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
defaults:
  - finetune

handler: UNDIAL # corresponds to the class defined in src/trainer/unlearn/undial.py
args: # HuggingFace TrainingArguments
  learning_rate: 1e-4
  num_train_epochs: 10
method_args: # method-specific arguments
  gamma: 1.0
  alpha: 0.0
  beta: 10.0 # strength of the penalty on memorized tokens
  retain_loss_type: NLL
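For orientation, a minimal sketch (with hypothetical placeholder loss values) of how `gamma` and `alpha` above combine the forget and retain terms in the trainer added by this commit (`src/trainer/unlearn/undial.py`):

```python
# Hypothetical scalars; the real losses come from model forward passes.
gamma, alpha = 1.0, 0.0  # defaults from this config
forget_loss, retain_loss = 0.42, 1.37

loss = gamma * forget_loss + alpha * retain_loss
# With alpha = 0.0 (the paper's setting), the retain term drops out entirely,
# so training is pure self-distillation on the forget set; beta acts only
# inside forget_loss by shaping the teacher's soft labels.
assert loss == forget_loss
```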

docs/links.md

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ Links to research papers and resources corresponding to implemented features in
 | SimNPO | Paper [📄](https://arxiv.org/abs/2410.07163), Code [🐙](https://github.com/OPTML-Group/Unlearn-Simple) |
 | IdkDPO | TOFU ([📄](https://arxiv.org/abs/2401.06121)) |
 | RMU | WMDP paper ([🐙](https://github.com/centerforaisafety/wmdp/tree/main/rmu), [🌐](https://www.wmdp.ai/)), later used in G-effect ([🐙](https://github.com/tmlr-group/G-effect/blob/main/dataloader.py)) |
+| UNDIAL | Paper [📄](https://arxiv.org/pdf/2402.10052), Code [🐙](https://github.com/dong-river/LLM_unlearning/tree/main) |

 ---

src/trainer/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,7 @@
 from trainer.unlearn.dpo import DPO
 from trainer.unlearn.simnpo import SimNPO
 from trainer.unlearn.rmu import RMU
+from trainer.unlearn.undial import UNDIAL

 import logging

@@ -88,3 +89,4 @@ def load_trainer(
 _register_trainer(DPO)
 _register_trainer(SimNPO)
 _register_trainer(RMU)
+_register_trainer(UNDIAL)

src/trainer/unlearn/undial.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
from trainer.utils import compute_undial_loss
from trainer.unlearn.grad_diff import GradDiff


class UNDIAL(GradDiff):
    def __init__(self, beta=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.beta = beta
        if self.ref_model is None:
            self.ref_model = self._prepare_ref_model(self.model)

    def compute_loss(self, model, inputs, return_outputs=False):
        forget_inputs = inputs["forget"]
        forget_loss, forget_outputs = compute_undial_loss(
            model, self.ref_model, forget_inputs, self.beta
        )

        retain_inputs = inputs["retain"]
        retain_inputs = {
            "input_ids": retain_inputs["input_ids"],
            "attention_mask": retain_inputs["attention_mask"],
            "labels": retain_inputs["labels"],
        }
        retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)

        loss = self.gamma * forget_loss + self.alpha * retain_loss
        return (loss, forget_outputs) if return_outputs else loss

src/trainer/utils.py

Lines changed: 32 additions & 0 deletions

@@ -66,3 +66,35 @@ def compute_dpo_loss(model, ref_model, win_inputs=None, lose_inputs=None, beta=1

     loss = -2 / beta * F.logsigmoid(beta * (win_log_ratio - lose_log_ratio)).mean()
     return loss, (win_outputs, lose_outputs)
+
+
+def compute_undial_loss(model, ref_model, inputs, beta):
+    # Forward pass on the student (trainable) model
+    outputs = model(**inputs)
+    logits = outputs.logits
+    labels = inputs["labels"]
+
+    shift_labels = labels[..., 1:].contiguous()
+    shift_logits = logits[..., :-1, :].contiguous()
+
+    # Forward pass on the teacher model (no grad)
+    with torch.no_grad():
+        teacher_logits = ref_model(**inputs).logits
+    shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()
+
+    # Build the mask that identifies the tokens that need to be unlearned
+    mask = torch.zeros_like(shift_teacher_logits)
+    batch_idx = torch.arange(mask.shape[0]).view(-1, 1, 1)
+    seq_idx = torch.arange(mask.shape[1]).view(1, -1, 1)
+    mask[batch_idx, seq_idx, shift_labels.unsqueeze(-1)] = 1.0
+
+    # Adjust teacher logits: subtract beta (the penalty) from the correct token
+    pre_softmax = shift_teacher_logits - mask * beta
+    soft_label = F.softmax(pre_softmax, dim=-1)
+
+    loss_fct = nn.CrossEntropyLoss(reduction="none")
+    loss = loss_fct(
+        shift_logits.view(-1, shift_logits.size(-1)),
+        soft_label.view(-1, soft_label.size(-1)),
+    )
+    return loss.mean(), outputs
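The advanced-indexing block above builds a one-hot mask over the shifted labels. A small self-contained check (shapes chosen arbitrarily for illustration) confirming the equivalence with `F.one_hot`:

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 3, 5  # illustrative batch, sequence, and vocab sizes
shift_labels = torch.randint(0, V, (B, T))

mask = torch.zeros(B, T, V)
batch_idx = torch.arange(B).view(-1, 1, 1)
seq_idx = torch.arange(T).view(1, -1, 1)
mask[batch_idx, seq_idx, shift_labels.unsqueeze(-1)] = 1.0

# Same one-hot mask via the library helper:
assert torch.equal(mask, F.one_hot(shift_labels, num_classes=V).float())
```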
