Remove ChunkedHybridCache from benchmark_inference.py #2733
base: main
Conversation
@kshitij12345, @riccardofelluga could you please review the change?
kshitij12345
left a comment
With pjnl-20251113 (and transformers version 4.55.4), running

python thunder/benchmarks/benchmark_inference.py --model-name meta-llama/Llama-4-Maverick-17B-128E --mode eager --input-length 1024 --output-length 32 --batch-size 1 --num-iterations 20 --num-layers 2

leads to:
Warming up with 10 iterations...
Traceback (most recent call last):
  File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 733, in <module>
    main()
  File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 722, in main
    benchmark.run_benchmark()
  File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 458, in run_benchmark
    input_ids, past_key_values = self.generate_batch()
                                 ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 342, in generate_batch
    past_key_values = StaticCache(
                      ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py", line 1451, in __init__
    super().__init__(layer_classes=StaticLayer, *args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py", line 1110, in __init__
    self.append_new_layers(self.num_hidden_layers - 1)
  File "/usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py", line 1172, in append_new_layers
    new_layer = new_layer_class(**kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: StaticLayer.__init__() missing 1 required positional argument: 'batch_size'

On these lines of the diff:

    # Transformers deprecated HybridChunkedCache in favour of static in 4.55.x
    past_key_values = StaticCache(
        config=self.hf_config,
        max_batch_size=input_ids.shape[0],
Looking at the error here, I think max_batch_size is required.
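For what it's worth, a minimal sketch (assuming transformers 4.55.x; the tiny LlamaConfig below is purely illustrative, not the benchmark's config) of a constructor call that avoids this TypeError:

from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

# Hypothetical tiny config, just to exercise the constructor signature.
config = LlamaConfig(num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, hidden_size=64)

# Leaving out max_batch_size reproduces the TypeError above (StaticLayer needs batch_size);
# passing it explicitly sizes the preallocated per-layer buffers.
past_key_values = StaticCache(config=config, max_batch_size=1, max_cache_len=128)
print(len(past_key_values.layers))  # 2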
Thank you for running it with transformers version 4.55.4! I was running with the latest release. Need to update the requirements pin first before merging this change.
On these lines of the diff:

    max_batch_size=input_ids.shape[0],
    max_cache_len=input_ids.shape[1] + self.config.output_length,
    device=DEVICE,
    dtype=torch.bfloat16,
Also, device and dtype seem necessary -
from transformers.cache_utils import StaticCache
from transformers import AutoConfig, AutoModelForCausalLM
import torch
model_id = "meta-llama/Llama-4-Maverick-17B-128E"
config = AutoConfig.from_pretrained(model_id)
if hasattr(config, "text_config"):
config = config.text_config
config.num_hidden_layers = 2
past_key_values = StaticCache(config=config, max_batch_size=1, max_cache_len=256)
print(past_key_values.layers[0].keys.dtype) # torch.float32
print(past_key_values.layers[0].keys.device) # cpu
past_key_values = StaticCache(config=config, max_batch_size=1, max_cache_len=256, dtype=torch.bfloat16, device="cuda")
print(past_key_values.layers[0].keys.dtype) # torch.bfloat16
print(past_key_values.layers[0].keys.device) # cuda:0
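As a side note (a hypothetical continuation of the snippet above, under the same transformers 4.55.x assumption), the preallocated buffers also show what max_batch_size and max_cache_len control:

# Continuing the snippet above: StaticCache allocates the per-layer key/value buffers
# up front, so the constructor arguments fix their shapes.
print(past_key_values.layers[0].keys.shape)  # roughly (max_batch_size, num_key_value_heads, max_cache_len, head_dim)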
riccardofelluga
left a comment
Good idea to move to StaticCache. It just needs a couple of fixes on the args of the object.
Does perf improve?
On these lines of the diff:

    dummy_key_states = torch.empty(1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE)
    past_key_values.initialise_cache_layer(layer_idx, dummy_key_states)
    past_key_values = StaticCache(
        config=self.hf_config,
Suggested change:

    config=self.hf_config,
    max_batch_size=input_ids.shape[0],
On these lines of the diff:

    past_key_values.initialise_cache_layer(layer_idx, dummy_key_states)
    past_key_values = StaticCache(
        config=self.hf_config,
        max_cache_len=input_ids.shape[1] + self.config.output_length,
Also, device and dtype seem to be required:

RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_bmm)

Suggested change:

    max_cache_len=input_ids.shape[1] + self.config.output_length,
    device=DEVICE,
    dtype=torch.bfloat16,
It's not moving on, it's already used because of

    for layer_idx in range(self.hf_config.num_hidden_layers):
        # key_states.shape[1] is used to retrieve the number of key value heads, all other dimensions can be 1 and ignored
        # https://github.com/huggingface/transformers/blob/9300728665aaeb0ebf4db99f9d9fbce916b4a183/src/transformers/cache_utils.py#L1822
        dummy_key_states = torch.empty(1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE)
We also need to preserve hf_config.num_key_value_heads // WORLD_SIZE for the distributed setting.
The patch could be something like the following:
diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index 212f5f8e..13af8175 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -339,9 +339,15 @@ class InferenceBenchmark:
         input_length = self.config.input_length
         input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
+        import copy
+        hf_config = copy.copy(self.hf_config)
+        hf_config.num_key_value_heads //= WORLD_SIZE
         past_key_values = StaticCache(
-            config=self.hf_config,
+            config=hf_config,
             max_cache_len=input_ids.shape[1] + self.config.output_length,
+            max_batch_size=batch_size,
+            dtype=torch.bfloat16,
+            device=DEVICE,
         )
         return input_ids, past_key_values
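One note on the copy.copy in that patch (the FakeConfig below is a hypothetical stand-in, not HF code): a shallow copy is enough because num_key_value_heads is a plain int, so dividing it on the copy leaves the original self.hf_config untouched for its other uses.

import copy

class FakeConfig:  # hypothetical stand-in for the HF config object
    def __init__(self):
        self.num_key_value_heads = 8

cfg = FakeConfig()
sharded = copy.copy(cfg)           # shallow copy; the int attribute is rebound, not shared
sharded.num_key_value_heads //= 2
print(cfg.num_key_value_heads, sharded.num_key_value_heads)  # 8 4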
The ChunkedHybridCache used in the inference benchmark is deprecated and should be replaced with StaticCache (https://github.com/huggingface/transformers/blob/ce40ca0d4c7d2e0a3f8bd3ddc30f29c6a105efb5/src/transformers/cache_utils.py#L1356).
This PR also removes unused keyword arguments when initializing StaticCache.
cc @crcrpar
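For reference, a self-contained sketch of the construction the review converges on (illustrative only: the model id mirrors the command above and is gated, the world size is assumed to be 1, and the real benchmark derives these values from its CLI flags and WORLD_SIZE):

import copy
import torch
from transformers import AutoConfig
from transformers.cache_utils import StaticCache

model_id = "meta-llama/Llama-4-Maverick-17B-128E"  # gated model; requires HF access
batch_size, input_length, output_length = 1, 1024, 32
world_size = 1  # the benchmark shards KV heads by WORLD_SIZE when run distributed
device = "cuda" if torch.cuda.is_available() else "cpu"

config = AutoConfig.from_pretrained(model_id)
if hasattr(config, "text_config"):
    config = config.text_config
config.num_hidden_layers = 2  # mirror --num-layers 2 to keep the example small

# Copy before mutating so the config used elsewhere to build the model is untouched.
sharded_config = copy.copy(config)
sharded_config.num_key_value_heads //= world_size

input_ids = torch.randint(0, config.vocab_size, (batch_size, input_length), device=device)
past_key_values = StaticCache(
    config=sharded_config,
    max_batch_size=batch_size,
    max_cache_len=input_length + output_length,
    device=device,
    dtype=torch.bfloat16,
)
print(past_key_values.layers[0].keys.shape, past_key_values.layers[0].keys.dtype)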