
Conversation

tsdocode commented May 28, 2025

Did:

  • Add a static KVCache from gpt-fast (see the sketch below)
  • Adjust the logic to use the new KV cache
  • Fix the code to be torch.compile compatible
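
For context, the static cache in gpt-fast pre-allocates fixed-size key/value buffers and writes into them at the current position, so tensor shapes never change during decoding. A minimal sketch of that idea (illustrative only; the class and method names are assumptions, not necessarily the exact code in this PR):

import torch
from torch import nn

class KVCache(nn.Module):
    """Fixed-size KV cache in the gpt-fast style."""

    def __init__(self, batch_size, max_seq_len, n_heads, head_dim, dtype=torch.float16):
        super().__init__()
        shape = (batch_size, n_heads, max_seq_len, head_dim)
        # Buffers are allocated once with a static shape, which is what
        # torch.compile needs to avoid per-step recompilation.
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: positions being written; k_val/v_val: (B, n_heads, len, head_dim)
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache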

Result for 1000 tokens:

A100:
~2.6x faster for fp32
~3.2x faster for fp16

Branch main:
Using device: cuda
<torch.utils.benchmark.utils.common.Measurement object at 0x7f2b05192c50>
generate_tokens(tokens, image)
setup: from __main__ import generate_tokens
  23.38 s
  1 measurement, 10 runs , 24 threads



Using device: cuda with dtype: torch.float32, torch.compile: False
<torch.utils.benchmark.utils.common.Measurement object at 0x7f8015c3df90>
generate_tokens(tokens, image)
setup: from __main__ import generate_tokens
  26.81 s
  1 measurement, 10 runs , 24 threads

Using device: cuda with dtype: torch.float32,  torch.compile: True
setup: from __main__ import generate_tokens
  9.97 s
  1 measurement, 10 runs , 24 threads


Using device: cuda with dtype: torch.float16, torch.compile: False
<torch.utils.benchmark.utils.common.Measurement object at 0x7f7c70058310>
generate_tokens(tokens, image)
setup: from __main__ import generate_tokens
  26.65 s
  1 measurement, 10 runs , 24 threads


Using device: cuda with dtype: torch.float16, torch.compile: True
<torch.utils.benchmark.utils.common.Measurement object at 0x7fe84850edd0>
generate_tokens(tokens, image)
setup: from __main__ import generate_tokens
  7.84 s
  1 measurement, 10 runs , 24 threads

H100:
~1.5x faster for fp32
~2.7x faster for fp16

Using device: cuda with dtype: torch.float32, torch.compile: False
generate_tokens(tokens, image)
setup: from __main__ import generate_tokens
  11.05 s
  1 measurement, 10 runs , 112 threads

Using device: cuda with dtype: torch.float32, torch.compile: True
setup: from __main__ import generate_tokens
  7.07 s
  1 measurement, 10 runs , 112 threads


Using device: cuda with dtype: torch.float16, torch.compile: False
<torch.utils.benchmark.utils.common.Measurement object at 0x7ff124ac3510>
generate_tokens(tokens, image)
setup: from __main__ import generate_tokens
  10.81 s
  1 measurement, 10 runs , 112 threads

Using device: cuda with dtype: torch.float16, torch.compile: True
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd165ea8210>
generate_tokens(tokens, image)
setup: from __main__ import generate_tokens
  3.91 s
  1 measurement, 10 runs , 112 threads
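
For reference, the timings above are torch.utils.benchmark output; a harness of roughly this shape reproduces that format (generate_tokens is assumed to be a small wrapper around model.generate for a fixed prompt and image, with tokens/image prepared as in the verification script below):

import torch
from torch.utils import benchmark

timer = benchmark.Timer(
    stmt="generate_tokens(tokens, image)",
    setup="from __main__ import generate_tokens",
    globals={"tokens": tokens, "image": image},
    num_threads=torch.get_num_threads(),
)
# 10 runs of one measurement each, matching the logs above.
print(timer.timeit(10))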

TODO:

  • Adjust the code for compatibility with the previous KVCache
  • Check training compatibility

tsdocode (Author) commented May 28, 2025

Simple inference code to verify output:

import time
import torch
from PIL import Image

torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

from models.vision_language_model import VisionLanguageModel
from data.processors import get_tokenizer, get_image_processor

from torch.utils import benchmark

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
print(f"Using device: {device}")

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True

# debugging
torch._logging.set_logs(graph_breaks=True, recompiles=True)


torch.manual_seed(666)


if __name__ == "__main__":
    model = VisionLanguageModel.from_pretrained("lusxvr/nanoVLM-222M").to(
        device, dtype=dtype
    )
    model.eval()

    # model.decoder = torch.compile(model.decoder, mode="reduce-overhead", fullgraph=True)

    tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
    image_processor = get_image_processor(model.cfg.vit_img_size)

    text = "What is this?"
    template = f"Question: {text} Answer:"
    encoded_batch = tokenizer.batch_encode_plus([template], return_tensors="pt")
    tokens = encoded_batch["input_ids"].to(device)

    image_path = "assets/image.png"
    image = Image.open(image_path)
    image = image_processor(image)
    image = image.unsqueeze(0).to(device, dtype)

    # Print table header
    print("\n" + "="*80)
    print(f"{'Configuration':<25} {'Time (s)':<12} {'Generated Text'}")
    print("="*80)

    # Without KV cache
    start = time.time()
    result = model.generate(
        tokens, image, max_new_tokens=128, use_kv_cache=False
    )
    end = time.time()
    generated_text = tokenizer.decode(result[0])
    print(f"{'Without KV cache':<25} {end - start:<12.3f} {generated_text[:50]}...")

    # Dynamic KV cache
    start = time.time()
    result = model.generate(
        tokens, 
        image, 
        max_new_tokens=128, 
        use_kv_cache=True, 
        kv_cache_implementation="dynamic"
    )
    end = time.time()
    generated_text = tokenizer.decode(result[0])
    print(f"{'Dynamic KV cache':<25} {end - start:<12.3f} {generated_text[:50]}...")

    # Static KV cache
    start = time.time()
    result = model.generate(
        tokens,
        image,
        max_new_tokens=128, 
        use_kv_cache=True, 
        kv_cache_implementation="static"
    )
    end = time.time()
    generated_text = tokenizer.decode(result[0])
    print(f"{'Static KV cache':<25} {end - start:<12.3f} {generated_text[:50]}...")

    model.decoder = torch.compile(
        model.decoder, mode="reduce-overhead", fullgraph=True
    )

    # Static KV cache (compiled) - multiple runs
    print("-"*80)
    print("Static KV cache (compiled) - Multiple runs:")
    print("-"*80)
    for i in range(3):
        start = time.time()
        result = model.generate(
            tokens, 
            image, 
            max_new_tokens=128, 
            use_kv_cache=True, 
            kv_cache_implementation="static"
        )
        end = time.time()
        generated_text = tokenizer.decode(result[0])
        print(f"{'Run ' + str(i+1):<25} {end - start:<12.3f} {generated_text[:50]}...")

    print("="*80)

Output on A100:

================================================================================
Configuration             Time (s)     Generated Text
================================================================================
Without KV cache          5.102         This image is sitting on the path with a cat on t...
Dynamic KV cache          3.390         This is a cat sitting on the ground, which is of ...
Static KV cache           4.205         This is a cat sitting on the ground. In the backg...
--------------------------------------------------------------------------------
Static KV cache (compiled) - Multiple runs:
Run 1                     24.999        This is a cat sitting on the ground, which is gre...
Run 2                     2.085         This is a cat sitting on the ground. I think this...
Run 3                     2.188         This picture is clicked outside. In the center th...

tsdocode marked this pull request as draft May 28, 2025 18:20
tsdocode (Author):

This PR is a WIP, but the code above is testable.

@andimarafioti @lusxvr If you have a moment, please give it a try!

tsdocode marked this pull request as ready for review May 29, 2025 16:57
tsdocode (Author) commented May 29, 2025

Update:
The code in this PR is now compatible with no_kv_cache, dynamic_kv_cache (adapting the logic from #69), and static_kv_cache.

================================================================================
Configuration             Time (s)     Generated Text
================================================================================
Without KV cache          5.102         This image is sitting on the path with a cat on t...
Dynamic KV cache          3.390         This is a cat sitting on the ground, which is of ...
Static KV cache           4.205         This is a cat sitting on the ground. In the backg...
--------------------------------------------------------------------------------
Static KV cache (compiled) - Multiple runs:
Run 1                     24.999        This is a cat sitting on the ground, which is gre...
Run 2                     2.085         This is a cat sitting on the ground. I think this...
Run 3                     2.188         This picture is clicked outside. In the center th...

tsdocode (Author):

@andimarafioti can you help me review this PR?

andimarafioti (Member):

Hi! Sorry, I was busy with other stuff, I'll get to it soon 🙏

andimarafioti (Member):

I was looking at it. Is this 3x faster than the current KV-cache implementation? How much of that is from torch.compile and how much from the KV cache implementation?

tsdocode (Author) commented Jun 3, 2025

I was looking at it. Is this 3x faster than the current KV-cache implementation? How much of that is from torch.compile and how much from the KV cache implementation?

The 3x speedup comes from the combination of the static KV cache and torch.compile:

  • torch.compile needs static shapes at inference time, otherwise it recompiles at every step => that is why we need the static KV cache (see the sketch after this list)
  • without torch.compile:
    • dynamic KV cache: speed is nearly the same as before
    • static KV cache: speed is slightly lower at longer sequence lengths
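
A quick way to verify this, reusing model, tokens, and image from the verification script above (a sketch, not the benchmark itself):

# Log any graph breaks or recompilations while decoding.
torch._logging.set_logs(graph_breaks=True, recompiles=True)

# With the static cache the decoder is compiled once and the same graph is
# reused for every decode step; with the dynamic cache the growing K/V
# shapes would trigger recompilation.
model.decoder = torch.compile(model.decoder, mode="reduce-overhead", fullgraph=True)
model.generate(tokens, image, max_new_tokens=128,
               use_kv_cache=True, kv_cache_implementation="static")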

tsdocode (Author) commented Jun 8, 2025

Update: new benchmark result on A100 with torch.compile + static cache + fp16, after separating the prefill and decode methods in the LLM; it is now ~5x faster (a rough sketch of the split follows the log below):

(f5) ➜  nanoVLM git:(perf/torch-compile) ✗ python benchmark-inference.py
Using device: cuda with dtype: torch.float16
<torch.utils.benchmark.utils.common.Measurement object at 0x7ff7ae582a90>
generate_tokens(tokens, image)
setup: from __main__ import generate_tokens
  5.52 s
  1 measurement, 10 runs , 24 threads
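
A rough sketch of the prefill/decode split described above (prefill and decode_one_token are hypothetical helper names used for illustration, not necessarily the actual methods in the PR):

import torch

@torch.no_grad()
def generate(model, tokens, image, max_new_tokens):
    # Prefill: run the full prompt (and image embeddings) once, filling the
    # static KV cache and producing the first new token.
    next_token = prefill(model, tokens, image)  # hypothetical helper
    generated = [next_token]

    # Decode: one token at a time with fixed shapes, so a single compiled
    # per-token graph can be reused for every position.
    start = tokens.shape[1]
    for pos in range(start, start + max_new_tokens - 1):
        next_token = decode_one_token(model, next_token, input_pos=pos)  # hypothetical helper
        generated.append(next_token)
    return torch.cat(generated, dim=-1)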

@andimarafioti please advise on the next steps!
