Commit d980e6d

add experimental llama example
Signed-off-by: Kyle Sayers <[email protected]>
1 parent: 0229b79

File tree: 1 file changed (+87, -0 lines)


experimental/llama3_attention.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
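# NOTE: the "attn_head" strategy is experimental. As configured here it is
# expected to quantize the input activations of each LlamaAttention module to
# FP8 using per-attention-head scales rather than a single per-tensor scale.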
recipe = QuantizationModifier(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"],
            input_activations=QuantizationArgs(
                num_bits=8, type="float", strategy="attn_head"
            ),
        )
    }
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-attention-fp8-head"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
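
For reference, a minimal sketch of reloading the saved checkpoint, assuming a recent transformers release with the compressed-tensors package installed; whether this experimental attention scheme round-trips through save_compressed is not verified here:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Directory name produced by the script above:
# model_id.split("/")[-1] + "-attention-fp8-head"
SAVE_DIR = "Meta-Llama-3-8B-Instruct-attention-fp8-head"

# Assumes transformers can decompress the quantization config written by
# save_compressed=True when compressed-tensors is installed.
model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(SAVE_DIR)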
