Perf: Adding torch.compile + static cache, ~3x speed up #88
base: main
Conversation
Simple inference code to verify output:

import time
import torch
from PIL import Image
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)
from models.vision_language_model import VisionLanguageModel
from data.processors import get_tokenizer, get_image_processor
from torch.utils import benchmark
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
print(f"Using device: {device}")
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
# debugging
torch._logging.set_logs(graph_breaks=True, recompiles=True)
torch.manual_seed(666)
if __name__ == "__main__":
    model = VisionLanguageModel.from_pretrained("lusxvr/nanoVLM-222M").to(
        device, dtype=dtype
    )
    model.eval()
    # model.decoder = torch.compile(model.decoder, mode="reduce-overhead", fullgraph=True)
    tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
    image_processor = get_image_processor(model.cfg.vit_img_size)
    text = "What is this?"
    template = f"Question: {text} Answer:"
    encoded_batch = tokenizer.batch_encode_plus([template], return_tensors="pt")
    tokens = encoded_batch["input_ids"].to(device)
    image_path = "assets/image.png"
    image = Image.open(image_path)
    image = image_processor(image)
    image = image.unsqueeze(0).to(device, dtype)
    # Print table header
    print("\n" + "="*80)
    print(f"{'Configuration':<25} {'Time (s)':<12} {'Generated Text'}")
    print("="*80)
    # Without KV cache
    start = time.time()
    result = model.generate(
        tokens, image, max_new_tokens=128, use_kv_cache=False
    )
    end = time.time()
    generated_text = tokenizer.decode(result[0])
    print(f"{'Without KV cache':<25} {end - start:<12.3f} {generated_text[:50]}...")
    # Dynamic KV cache
    start = time.time()
    result = model.generate(
        tokens, 
        image, 
        max_new_tokens=128, 
        use_kv_cache=True, 
        kv_cache_implementation="dynamic"
    )
    end = time.time()
    generated_text = tokenizer.decode(result[0])
    print(f"{'Dynamic KV cache':<25} {end - start:<12.3f} {generated_text[:50]}...")
    # Static KV cache
    start = time.time()
    result = model.generate(
        tokens,
        image,
        max_new_tokens=128, 
        use_kv_cache=True, 
        kv_cache_implementation="static"
    )
    end = time.time()
    generated_text = tokenizer.decode(result[0])
    print(f"{'Static KV cache':<25} {end - start:<12.3f} {generated_text[:50]}...")
    model.decoder = torch.compile(
        model.decoder, mode="reduce-overhead", fullgraph=True
    )
    # Static KV cache (compiled) - multiple runs
    print("-"*80)
    print("Static KV cache (compiled) - Multiple runs:")
    print("-"*80)
    for i in range(3):
        start = time.time()
        result = model.generate(
            tokens, 
            image, 
            max_new_tokens=128, 
            use_kv_cache=True, 
            kv_cache_implementation="static"
        )
        end = time.time()
        generated_text = tokenizer.decode(result[0])
        print(f"{'Run ' + str(i+1):<25} {end - start:<12.3f} {generated_text[:50]}...")
    print("="*80)Output on A100:  | 
This PR is WIP, but the code above is testable. @andimarafioti @lusxvr, if you have a moment, please give it a shot!
Update:
@andimarafioti can you help me review this PR?
Hi! Sorry, I was busy with other stuff, I'll get to it soon 🙏
I was looking at it. Is this 3x faster than the current KV-cache implementation? How much of that is from torch.compile and how much from the KV-cache implementation?
The 3x speedup comes from the combination of the static KV cache + torch.compile:
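Roughly, the difference looks like this (an illustrative sketch, not the PR's actual code; DynamicKVCache and StaticKVCache are made-up names): a dynamic cache concatenates new keys/values every step, so shapes change and memory is reallocated, while a static cache pre-allocates buffers up to a maximum length and writes in place, keeping shapes fixed for the compiled decoder:

```python
import torch

# Illustrative sketch only -- not the PR's implementation.
class DynamicKVCache:
    """Grows by torch.cat every step: new allocations and changing shapes."""
    def __init__(self):
        self.k = self.v = None

    def update(self, k_new, v_new):  # k_new/v_new: [B, H, T_new, D]
        self.k = k_new if self.k is None else torch.cat([self.k, k_new], dim=2)
        self.v = v_new if self.v is None else torch.cat([self.v, v_new], dim=2)
        return self.k, self.v


class StaticKVCache:
    """Pre-allocated to max_len: in-place writes, fixed shapes and addresses."""
    def __init__(self, batch, heads, max_len, head_dim, device, dtype):
        self.k = torch.zeros(batch, heads, max_len, head_dim, device=device, dtype=dtype)
        self.v = torch.zeros_like(self.k)
        self.pos = 0

    def update(self, k_new, v_new):  # k_new/v_new: [B, H, T_new, D]
        t = k_new.shape[2]
        idx = torch.arange(self.pos, self.pos + t, device=k_new.device)
        self.k.index_copy_(2, idx, k_new)  # in-place write, no reallocation
        self.v.index_copy_(2, idx, v_new)
        self.pos += t
        # Return the full buffers; attention masks out the unused positions,
        # so the compiled decoder always sees the same tensor shapes.
        return self.k, self.v
```

Because the static cache keeps tensor shapes constant across decode steps, the compiled decoder does not have to recompile between steps, which is what makes combining it with torch.compile worthwhile.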
Updated the benchmark results on A100. @andimarafioti, please advise on the next step!
Did:
Result for 1000 tokens:
- A100:
  - ~2.6x faster for fp32
  - ~3.2x faster for fp16
- H100:
  - ~1.5x faster for fp32
  - ~2.7x faster for fp16
TODO: