
Conversation

@IvanYashchuk (Collaborator) commented Nov 13, 2025

The HybridChunkedCache used in the inference benchmark is deprecated and should be replaced with StaticCache (https://github.com/huggingface/transformers/blob/ce40ca0d4c7d2e0a3f8bd3ddc30f29c6a105efb5/src/transformers/cache_utils.py#L1356).

This PR also removes unused keyword arguments when initializing StaticCache.

cc @crcrpar

@IvanYashchuk (Collaborator, Author) commented:

@kshitij12345, @riccardofelluga could you please review the change?

@riccardofelluga self-requested a review on November 13, 2025, 14:06
@kshitij12345 (Collaborator) left a comment:

With pjnl-20251113 (and transformers version 4.55.4), running python thunder/benchmarks/benchmark_inference.py --model-name meta-llama/Llama-4-Maverick-17B-128E --mode eager --input-length 1024 --output-length 32 --batch-size 1 --num-iterations 20 --num-layers 2

leads to

Warming up with 10 iterations...
Traceback (most recent call last):
  File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 733, in <module>
    main()
  File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 722, in main
    benchmark.run_benchmark()
  File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 458, in run_benchmark
    input_ids, past_key_values = self.generate_batch()
                                 ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_inference.py", line 342, in generate_batch
    past_key_values = StaticCache(
                      ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py", line 1451, in __init__
    super().__init__(layer_classes=StaticLayer, *args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py", line 1110, in __init__
    self.append_new_layers(self.num_hidden_layers - 1)
  File "/usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py", line 1172, in append_new_layers
    new_layer = new_layer_class(**kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: StaticLayer.__init__() missing 1 required positional argument: 'batch_size'

# Transformers deprecated HybridChunkedCache in favour of static in 4.55.x
past_key_values = StaticCache(
    config=self.hf_config,
    max_batch_size=input_ids.shape[0],
Collaborator:

Looking at the error here, I think max_batch_size is required.

@IvanYashchuk (Collaborator, Author) replied:

Thank you for running it with transformers version 4.55.4! I was running with the latest release. I need to update the requirements pin before merging this change.
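
For illustration only, the pin bump would be a one-line change to the transformers requirement, something like the entry below; the file location and exact lower bound are placeholders, not the actual change:

# hypothetical requirements entry; the real file and version bound may differ
transformers>=4.56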

    max_batch_size=input_ids.shape[0],
    max_cache_len=input_ids.shape[1] + self.config.output_length,
    device=DEVICE,
    dtype=torch.bfloat16,
Collaborator:

Also, device and dtype seem necessary -

from transformers.cache_utils import StaticCache
from transformers import AutoConfig
import torch

model_id = "meta-llama/Llama-4-Maverick-17B-128E"
config = AutoConfig.from_pretrained(model_id)

if hasattr(config, "text_config"):
    config = config.text_config

config.num_hidden_layers = 2

past_key_values = StaticCache(config=config, max_batch_size=1, max_cache_len=256)

print(past_key_values.layers[0].keys.dtype)  # torch.float32
print(past_key_values.layers[0].keys.device)  # cpu

past_key_values = StaticCache(config=config, max_batch_size=1, max_cache_len=256, dtype=torch.bfloat16, device="cuda")

print(past_key_values.layers[0].keys.dtype)  # torch.bfloat16
print(past_key_values.layers[0].keys.device)  # cuda:0

@riccardofelluga (Collaborator) left a comment:

Good idea to move on to the StaticCache. Just need a couple of fixes on the arguments of the object.

Does perf improve?

dummy_key_states = torch.empty(1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE)
past_key_values.initialise_cache_layer(layer_idx, dummy_key_states)
past_key_values = StaticCache(
    config=self.hf_config,
Collaborator:

Suggested change
-    config=self.hf_config,
+    config=self.hf_config,
+    max_batch_size=input_ids.shape[0],

past_key_values.initialise_cache_layer(layer_idx, dummy_key_states)
past_key_values = StaticCache(
    config=self.hf_config,
    max_cache_len=input_ids.shape[1] + self.config.output_length,
Collaborator:

Also device and dtype seem to be required:

RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_bmm)
Suggested change
-    max_cache_len=input_ids.shape[1] + self.config.output_length,
+    max_cache_len=input_ids.shape[1] + self.config.output_length,
+    device=DEVICE,
+    dtype=torch.bfloat16,
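
That RuntimeError arises because the cache tensors default to CPU float32 while the model and inputs live on CUDA. A minimal standalone illustration of the same failure mode (not the benchmark code):

import torch

a = torch.randn(1, 2, 3, device="cuda")  # stands in for activations on the GPU
b = torch.randn(1, 3, 4)                 # defaults to CPU, like a cache allocated without device=
torch.bmm(a, b)                          # raises a device-mismatch RuntimeError like the one quoted above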

@IvanYashchuk (Collaborator, Author) commented:

Good idea to move on to the StaticCache.

It's not moving on; StaticCache is already used because of the if LooseVersion(transformers.__version__) >= LooseVersion("4.55"): line.
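
Roughly, that gate selects the cache class as sketched below; the LooseVersion import source and the pre-4.55 fallback shown here are assumptions, not the benchmark's exact code:

import transformers
from looseversion import LooseVersion  # assumed import source; the benchmark may import it differently

if LooseVersion(transformers.__version__) >= LooseVersion("4.55"):
    from transformers.cache_utils import StaticCache as CacheClass
else:
    # older transformers releases, where HybridChunkedCache was still the choice (assumed fallback)
    from transformers.cache_utils import HybridChunkedCache as CacheClass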

for layer_idx in range(self.hf_config.num_hidden_layers):
    # key_states.shape[1] is used to retrieve the number of key value heads, all other dimensions can be 1 and ignored
    # https://github.com/huggingface/transformers/blob/9300728665aaeb0ebf4db99f9d9fbce916b4a183/src/transformers/cache_utils.py#L1822
    dummy_key_states = torch.empty(1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE)
Collaborator:

We also need to preserve hf_config.num_key_value_heads // WORLD_SIZE for the distributed setting.

The patch can be something like the following:

diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index 212f5f8e..13af8175 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -339,9 +339,15 @@ class InferenceBenchmark:
         input_length = self.config.input_length
 
         input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
+        import copy
+        hf_config = copy.copy(self.hf_config)
+        hf_config.num_key_value_heads //= WORLD_SIZE
         past_key_values = StaticCache(
-            config=self.hf_config,
+            config=hf_config,
             max_cache_len=input_ids.shape[1] + self.config.output_length,
+            max_batch_size=batch_size,
+            dtype=torch.bfloat16,
+            device=DEVICE,
         )
 
         return input_ids, past_key_values
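
Spelled out without the diff markup, and with the rationale as comments (same content as the patch above, inside generate_batch; the comments are my reading of it):

import copy

# Shallow-copy the config so self.hf_config itself stays untouched, then fold the
# per-rank split into the copy so StaticCache allocates num_key_value_heads // WORLD_SIZE
# heads on each rank in the distributed setting.
hf_config = copy.copy(self.hf_config)
hf_config.num_key_value_heads //= WORLD_SIZE

past_key_values = StaticCache(
    config=hf_config,
    max_cache_len=input_ids.shape[1] + self.config.output_length,
    max_batch_size=batch_size,
    dtype=torch.bfloat16,
    device=DEVICE,
)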
