Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions benchmarks/blackwell-attn-measure.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/bin/bash
set -uxo pipefail

RUNID=$(echo result_* | xargs -n1 | grep -v '\*' | wc -l)
RUNDIR=$PWD/result_$RUNID
mkdir $RUNDIR
set -e
nvidia-smi > $RUNDIR/nvidia-smi.log
lscpu > $RUNDIR/lscpu.log
hostname > $RUNDIR/hostname.log
. ./venv-fb-triton/bin/activate
uv pip list > $RUNDIR/fb-pip-list.log
deactivate
. venv-stock-triton/bin/activate
uv pip list > $RUNDIR/stock-pip-list.log
deactivate
find . -type d -name ".git" | while read gitdir; do
repo_dir=$(dirname "$gitdir")
commit_hash=$(git -C "$repo_dir" rev-parse HEAD 2>/dev/null)
if [ -n "$commit_hash" ]; then
echo "$repo_dir: $commit_hash" >> $RUNDIR/git-list.log
fi
done

root=$PWD
cd helion
HIDDEN_DIM=2048
TOTAL_TOKENS=16384
export WITH_GLUON=1
export HELION_BENCHMARK_DISABLE_LOGGING=1
for DHEAD in 64 128; do
NHEADS=$(($HIDDEN_DIM / $DHEAD))
for SEQLEN in 2048 4096 8192; do
BATCH=$(($TOTAL_TOKENS / $SEQLEN))

for only in cudnn_sdpa helion_blackwell_attention_tritonbench; do
$root/venv-fb-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_fb-triton.log
done

for only in helion_blackwell_attention_tritonbench; do
$root/venv-stock-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log
done

for only in helion_blackwell_attention_tritonbench; do
WITH_ACC=1 $root/venv-fb-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_fb-triton-acc.log
done

for only in sdpa triton_tutorial_flash_v2 triton_tutorial_flash_v2_tma flex_attention; do
$root/venv-stock-triton/bin/python benchmarks/run.py --kernel flash_attention --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log
done

# for only in helion_attention; do
# $root/venv-stock-triton/bin/python benchmarks/run.py --kernel flash_attention --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log
# done

done
done
38 changes: 38 additions & 0 deletions benchmarks/blackwell-attn-parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

import csv
import glob
import sys

out = csv.writer(open(sys.argv[1] + "/data.csv", "w"))
out.writerow(["batch", "heads", "seqlen", "seqlen_kv", "dhead", "variant", "tflops"])
for f in glob.glob(sys.argv[1] + "/dhead_*"):
lines = list(reversed(list(open(f))))
i = -1
for i in range(len(lines)):
if lines[i].startswith("--------------"):
i -= 1
break
line = lines[i].replace("(", "").replace(")", ",")
line = line.split(",")
try:
if len(line) == 6:
batch, heads, seqlen, seqlen_kv, dhead, tflops = line
else:
batch, heads, heads_kv, seqlen, seqlen_kv, dhead, tflops = line
assert heads.strip() == heads_kv.strip()
except:
continue

variant = f.split("/")[-1].split(".log")[0].split("only_")[1]
out.writerow(
[
int(batch.strip()),
int(heads.strip()),
int(seqlen.strip()),
int(seqlen_kv.strip()),
int(dhead.strip()),
variant,
float(tflops.strip()),
]
)
Loading
Loading