From 5d813b5127affefe9fa71961f3cde87e523a2b8e Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Thu, 16 Oct 2025 15:56:29 -0700 Subject: [PATCH 01/10] scripts for blackwell attn measurement --- benchmarks/blackwell-attn-measure.sh | 38 +++++++ benchmarks/blackwell-attn-parse.py | 36 +++++++ benchmarks/blackwell-attn-setup-env.sh | 139 +++++++++++++++++++++++++ 3 files changed, 213 insertions(+) create mode 100755 benchmarks/blackwell-attn-measure.sh create mode 100644 benchmarks/blackwell-attn-parse.py create mode 100755 benchmarks/blackwell-attn-setup-env.sh diff --git a/benchmarks/blackwell-attn-measure.sh b/benchmarks/blackwell-attn-measure.sh new file mode 100755 index 000000000..4bf39cef2 --- /dev/null +++ b/benchmarks/blackwell-attn-measure.sh @@ -0,0 +1,38 @@ +#!/bin/bash +set -euxo pipefail + +RUNID=$(echo result_* | xargs -n1 | wc -l) +RUNDIR=$PWD/result_$RUNID +mkdir $RUNDIR +nvidia-smi > $RUNDIR/nvidia-smi.log +lscpu > $RUNDIR/lscpu.log +hostname > $RUNDIR/hostname.log +uv pip list > $RUNDIR/pip-list.log +find . -type d -name ".git" | while read gitdir; do + repo_dir=$(dirname "$gitdir") + commit_hash=$(git -C "$repo_dir" rev-parse HEAD 2>/dev/null) + if [ -n "$commit_hash" ]; then + echo "$repo_dir: $commit_hash" >> $RUNDIR/git-list.log + fi +done + +cd helion +HIDDEN_DIM=2048 +TOTAL_TOKENS=16384 +export WITH_GLUON=1 +export HELION_BENCHMARK_DISABLE_LOGGING=1 +for DHEAD in 64 128; do + NHEADS=$(($HIDDEN_DIM / $DHEAD)) + for SEQLEN in 2048 4096 8192; do + BATCH=$(($TOTAL_TOKENS / $SEQLEN)) + for only in triton_tutorial_flash_dp_persistent_blackwell gluon_blackwell_tutorial_persistent_fwd cudnn_sdpa helion_blackwell_attention_tritonbench; do + python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only.log + done + + pushd benchmarks/tritonbench + for only in aten sdpa triton_tutorial_flash_v2 triton_tutorial_flash_v2_tma flex_attention; do + python run.py --op flash_attention --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only.log + done + popd + done +done diff --git a/benchmarks/blackwell-attn-parse.py b/benchmarks/blackwell-attn-parse.py new file mode 100644 index 000000000..47e1b9bdf --- /dev/null +++ b/benchmarks/blackwell-attn-parse.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import csv +import glob +import sys + +out = csv.writer(open(sys.argv[1] + "/data.csv", "w")) +out.writerow(["batch", "heads", "seqlen", "seqlen_kv", "dhead", "variant", "tflops"]) +for f in glob.glob(sys.argv[1] + "/dhead_*"): + lines = list(reversed(list(open(f)))) + i = -1 + for i in range(len(lines)): + if lines[i].startswith("--------------"): + i -= 1 + break + line = lines[i].replace("(", "").replace(")", ",") + line = line.split(",") + if len(line) == 6: + batch, heads, seqlen, seqlen_kv, dhead, tflops = line + else: + batch, heads, heads_kv, seqlen, seqlen_kv, dhead, tflops = line + assert heads.strip() == heads_kv.strip() + print(lines) + + variant = f.split("/")[-1].split(".log")[0].split("only_")[1] + out.writerow( + [ + int(batch.strip()), + int(heads.strip()), + int(seqlen.strip()), + int(seqlen_kv.strip()), + int(dhead.strip()), + variant, + float(tflops.strip()), + ] + ) diff 
--git a/benchmarks/blackwell-attn-setup-env.sh b/benchmarks/blackwell-attn-setup-env.sh new file mode 100755 index 000000000..4235d14aa --- /dev/null +++ b/benchmarks/blackwell-attn-setup-env.sh @@ -0,0 +1,139 @@ +#!/bin/bash +set -x +uv venv -p 3.12 --managed-python +. .venv/bin/activate +uv pip install --no-deps -r <(cat << EOF +arpeggio==2.0.3 +asttokens==3.0.0 +autopep8==2.3.2 +caliper-reader==0.4.1 +certifi==2025.10.5 +cfgv==3.4.0 +charset-normalizer==3.4.3 +cmake==3.31.6 +contourpy==1.3.3 +cuda-bindings==12.9.2 +cuda-pathfinder==1.3.0 +cuda-python==12.9.2 +cycler==0.12.1 +decorator==5.2.1 +dill==0.4.0 +distlib==0.4.0 +einops==0.8.1 +execnet==2.1.1 +executing==2.2.1 +expecttest==0.3.0 +filecheck==1.0.3 +filelock==3.19.1 +fonttools==4.60.1 +fsspec==2025.9.0 +hf-xet==1.1.10 +huggingface-hub==0.35.3 +identify==2.6.15 +idna==3.10 +iniconfig==2.1.0 +ipdb==0.13.13 +ipython==9.6.0 +ipython-pygments-lexers==1.1.1 +isort==6.1.0 +jedi==0.19.2 +jinja2==3.1.6 +kiwisolver==1.4.9 +lit==18.1.8 +llnl-hatchet==2025.1.0 +markdown-it-py==4.0.0 +markupsafe==3.0.2 +matplotlib==3.10.7 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +mpmath==1.3.0 +multiprocess==0.70.18 +networkx==3.5 +ninja==1.13.0 +nodeenv==1.9.1 +numpy==2.3.3 +nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cudnn-cu12==9.10.2.21 +nvidia-cufft-cu12==11.3.3.83 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparselt-cu12==0.7.1 +nvidia-cutlass-dsl==4.1.0.dev0 +nvidia-ml-py==13.580.82 +nvidia-nccl-cu12==2.27.3 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvshmem-cu12==3.3.24 +nvidia-nvtx-cu12==12.8.90 +packaging==25.0 +pandas==2.3.3 +parso==0.8.5 +pexpect==4.9.0 +pillow==11.3.0 +pip==25.2 +platformdirs==4.5.0 +pluggy==1.6.0 +pre-commit==4.3.0 +prompt-toolkit==3.0.52 +psutil==7.1.0 +ptyprocess==0.7.0 +pure-eval==0.2.3 +py==1.11.0 +pybind11==3.0.1 +pycodestyle==2.14.0 +pydot==4.0.1 +pygments==2.19.2 +pyparsing==3.2.5 +pyright==1.1.406 +pytest==8.4.2 +pytest-forked==1.6.0 +pytest-xdist==3.8.0 +python-dateutil==2.9.0.post0 +pytz==2025.2 +pyyaml==6.0.3 +regex==2025.9.18 +requests==2.32.5 +rich==14.2.0 +ruff==0.14.0 +safetensors==0.6.2 +scipy==1.16.2 +setuptools==78.1.0 +six==1.17.0 +stack-data==0.6.3 +sympy==1.14.0 +tabulate==0.9.0 +textx==4.2.3 +tokenizers==0.20.3 +tqdm==4.67.1 +traitlets==5.14.3 +transformers==4.46.1 +typing-extensions==4.15.0 +tzdata==2025.2 +urllib3==2.5.0 +virtualenv==20.35.2 +wcwidth==0.2.14 +wheel==0.45.1 +EOF +) +git clone https://github.com/facebookexperimental/triton.git +pushd triton + git checkout 2f987ec37f7856f02b11de1c4a742975bdb77739 + make dev-install-llvm +popd +uv pip install --pre torch==2.10.0.dev20251008+cu128 torchvision==0.25.0.dev20251009+cu128 --index-url https://download.pytorch.org/whl/nightly/cu128 --no-deps +git clone https://github.com/pytorch/helion.git +pushd helion + git checkout f5ba06da5811f295d8c7373a47c7ee3c90d76a13 + uv pip install -e --no-deps . + pushd benchmarks + git clone https://github.com/meta-pytorch/tritonbench.git + pushd tritonbench + git checkout 9a4bbc7070b134fb274114018ac02b38fcfd4ba7 + uv pip install -e --no-deps . 
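+    # Note: tritonbench is cloned under helion/benchmarks on purpose;
+    # blackwell-attn-measure.sh cd's into helion and then pushd's
+    # benchmarks/tritonbench, so the two checkouts must stay nested this way.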
+ popd + popd +popd From b465ccd58fabb07906911ac7efa520d44bdbd7e5 Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Fri, 17 Oct 2025 10:25:48 -0700 Subject: [PATCH 02/10] fix --- benchmarks/blackwell-attn-setup-env.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/benchmarks/blackwell-attn-setup-env.sh b/benchmarks/blackwell-attn-setup-env.sh index 4235d14aa..f713e45a9 100755 --- a/benchmarks/blackwell-attn-setup-env.sh +++ b/benchmarks/blackwell-attn-setup-env.sh @@ -1,5 +1,9 @@ #!/bin/bash -set -x +set -ex +# prereqs +command -v uv +command -v clang +command -v lld uv venv -p 3.12 --managed-python . .venv/bin/activate uv pip install --no-deps -r <(cat << EOF @@ -128,12 +132,12 @@ uv pip install --pre torch==2.10.0.dev20251008+cu128 torchvision==0.25.0.dev2025 git clone https://github.com/pytorch/helion.git pushd helion git checkout f5ba06da5811f295d8c7373a47c7ee3c90d76a13 - uv pip install -e --no-deps . + uv pip install --no-deps -e . pushd benchmarks git clone https://github.com/meta-pytorch/tritonbench.git pushd tritonbench git checkout 9a4bbc7070b134fb274114018ac02b38fcfd4ba7 - uv pip install -e --no-deps . + uv pip install --no-deps -e . popd popd popd From cce3c01e0b9625a3885df61bea774b8feadce8d8 Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Fri, 17 Oct 2025 13:49:18 -0700 Subject: [PATCH 03/10] fix --- benchmarks/blackwell-attn-measure.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/blackwell-attn-measure.sh b/benchmarks/blackwell-attn-measure.sh index 4bf39cef2..d3c1644af 100755 --- a/benchmarks/blackwell-attn-measure.sh +++ b/benchmarks/blackwell-attn-measure.sh @@ -2,6 +2,7 @@ set -euxo pipefail RUNID=$(echo result_* | xargs -n1 | wc -l) +RUNID=$(($RUNID + 1)) RUNDIR=$PWD/result_$RUNID mkdir $RUNDIR nvidia-smi > $RUNDIR/nvidia-smi.log From 1580460cc0e989c425778b870f1efff809827f7c Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Mon, 20 Oct 2025 10:15:22 -0700 Subject: [PATCH 04/10] update measurements --- benchmarks/blackwell-attn-measure.sh | 34 +++++++++++++++++++------- benchmarks/blackwell-attn-setup-env.sh | 32 ++++++++++++++++-------- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/benchmarks/blackwell-attn-measure.sh b/benchmarks/blackwell-attn-measure.sh index d3c1644af..584166fc6 100755 --- a/benchmarks/blackwell-attn-measure.sh +++ b/benchmarks/blackwell-attn-measure.sh @@ -1,14 +1,19 @@ #!/bin/bash -set -euxo pipefail +set -uxo pipefail -RUNID=$(echo result_* | xargs -n1 | wc -l) -RUNID=$(($RUNID + 1)) +RUNID=$(echo result_* | xargs -n1 | grep -v '\*' | wc -l) RUNDIR=$PWD/result_$RUNID mkdir $RUNDIR +set -e nvidia-smi > $RUNDIR/nvidia-smi.log lscpu > $RUNDIR/lscpu.log hostname > $RUNDIR/hostname.log -uv pip list > $RUNDIR/pip-list.log +. ./venv-fb-triton/bin/activate +uv pip list > $RUNDIR/fb-pip-list.log +deactivate +. venv-stock-triton/bin/activate +uv pip list > $RUNDIR/stock-pip-list.log +deactivate find . -type d -name ".git" | while read gitdir; do repo_dir=$(dirname "$gitdir") commit_hash=$(git -C "$repo_dir" rev-parse HEAD 2>/dev/null) @@ -17,6 +22,7 @@ find . 
-type d -name ".git" | while read gitdir; do fi done +root=$PWD cd helion HIDDEN_DIM=2048 TOTAL_TOKENS=16384 @@ -26,13 +32,23 @@ for DHEAD in 64 128; do NHEADS=$(($HIDDEN_DIM / $DHEAD)) for SEQLEN in 2048 4096 8192; do BATCH=$(($TOTAL_TOKENS / $SEQLEN)) - for only in triton_tutorial_flash_dp_persistent_blackwell gluon_blackwell_tutorial_persistent_fwd cudnn_sdpa helion_blackwell_attention_tritonbench; do - python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only.log + for only in cudnn_sdpa; do + $root/venv-stock-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log done - pushd benchmarks/tritonbench - for only in aten sdpa triton_tutorial_flash_v2 triton_tutorial_flash_v2_tma flex_attention; do - python run.py --op flash_attention --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only.log + for venv in stock-triton fb-triton; do + for only in helion_blackwell_attention_tritonbench; do + $root/venv-$venv/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_$venv.log + done + done + + for only in helion_blackwell_attention_tritonbench; do + WITH_ACC=1 $root/venv-fb-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_fb-triton-acc.log + done + + pushd benchmarks/tritonbench + for only in sdpa triton_tutorial_flash_v2 triton_tutorial_flash_v2_tma flex_attention; do + $root/venv-stock-triton/bin/python run.py --op flash_attention --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log done popd done diff --git a/benchmarks/blackwell-attn-setup-env.sh b/benchmarks/blackwell-attn-setup-env.sh index f713e45a9..f8362ca6b 100755 --- a/benchmarks/blackwell-attn-setup-env.sh +++ b/benchmarks/blackwell-attn-setup-env.sh @@ -4,8 +4,10 @@ set -ex command -v uv command -v clang command -v lld -uv venv -p 3.12 --managed-python -. .venv/bin/activate +command -v nvcc +for triton_kind in fb-triton stock-triton; do +uv venv -p 3.12 --managed-python venv-$triton_kind +. 
venv-$triton_kind/bin/activate uv pip install --no-deps -r <(cat << EOF arpeggio==2.0.3 asttokens==3.0.0 @@ -123,21 +125,31 @@ wcwidth==0.2.14 wheel==0.45.1 EOF ) -git clone https://github.com/facebookexperimental/triton.git -pushd triton - git checkout 2f987ec37f7856f02b11de1c4a742975bdb77739 - make dev-install-llvm -popd +if [ $triton_kind == fb-triton ]; then + git clone https://github.com/facebookexperimental/triton.git + pushd triton + git checkout a0fed580f88c02a30fcf22349fc242ff233b99c3 + make dev-install-llvm + popd +else + uv pip install --pre pytorch-triton==3.5.0+git27664085 --index-url https://download.pytorch.org/whl/nightly/cu128 --no-deps +fi uv pip install --pre torch==2.10.0.dev20251008+cu128 torchvision==0.25.0.dev20251009+cu128 --index-url https://download.pytorch.org/whl/nightly/cu128 --no-deps -git clone https://github.com/pytorch/helion.git +if [ $triton_kind == stock-triton ]; then + # only install in one venv since it takes a while to build + uv pip install --no-deps --no-cache --no-build-isolation -v flash-attn +fi +git clone https://github.com/pytorch/helion.git || true pushd helion - git checkout f5ba06da5811f295d8c7373a47c7ee3c90d76a13 + git checkout d50227d909706f8e37de02d25c5787e72355e9f3 uv pip install --no-deps -e . pushd benchmarks - git clone https://github.com/meta-pytorch/tritonbench.git + git clone https://github.com/meta-pytorch/tritonbench.git || true pushd tritonbench git checkout 9a4bbc7070b134fb274114018ac02b38fcfd4ba7 uv pip install --no-deps -e . popd popd popd +deactivate +done From 5dfd2b0000511462ac741b13de1cc3ad71dd385e Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Mon, 20 Oct 2025 11:17:56 -0700 Subject: [PATCH 05/10] measure with acc, downselect measurements --- benchmarks/blackwell-attn-measure.sh | 22 ++++++++++------------ benchmarks/blackwell-attn-setup-env.sh | 2 +- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/benchmarks/blackwell-attn-measure.sh b/benchmarks/blackwell-attn-measure.sh index 584166fc6..3b8c1af83 100755 --- a/benchmarks/blackwell-attn-measure.sh +++ b/benchmarks/blackwell-attn-measure.sh @@ -32,24 +32,22 @@ for DHEAD in 64 128; do NHEADS=$(($HIDDEN_DIM / $DHEAD)) for SEQLEN in 2048 4096 8192; do BATCH=$(($TOTAL_TOKENS / $SEQLEN)) - for only in cudnn_sdpa; do - $root/venv-stock-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log - done - for venv in stock-triton fb-triton; do - for only in helion_blackwell_attention_tritonbench; do - $root/venv-$venv/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_$venv.log - done + for only in cudnn_sdpa helion_blackwell_attention_tritonbench; do + $root/venv-fb-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_fb-triton.log done for only in helion_blackwell_attention_tritonbench; do WITH_ACC=1 
$root/venv-fb-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_fb-triton-acc.log done - pushd benchmarks/tritonbench - for only in sdpa triton_tutorial_flash_v2 triton_tutorial_flash_v2_tma flex_attention; do - $root/venv-stock-triton/bin/python run.py --op flash_attention --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log - done - popd + for only in helion_attention sdpa triton_tutorial_flash_v2 triton_tutorial_flash_v2_tma flex_attention; do + $root/venv-stock-triton/bin/python benchmarks/run.py --kernel flash_attention --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log + done + + # for only in helion_attention; do + # $root/venv-stock-triton/bin/python benchmarks/run.py --kernel flash_attention --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log + # done + done done diff --git a/benchmarks/blackwell-attn-setup-env.sh b/benchmarks/blackwell-attn-setup-env.sh index f8362ca6b..8f2ee9bdb 100755 --- a/benchmarks/blackwell-attn-setup-env.sh +++ b/benchmarks/blackwell-attn-setup-env.sh @@ -141,7 +141,7 @@ if [ $triton_kind == stock-triton ]; then fi git clone https://github.com/pytorch/helion.git || true pushd helion - git checkout d50227d909706f8e37de02d25c5787e72355e9f3 + git checkout 294b3db5dda69874b69a769baf6a903e67dc575f uv pip install --no-deps -e . 
pushd benchmarks git clone https://github.com/meta-pytorch/tritonbench.git || true From 576c8f5b273470f785b5066aa540e5330538d2aa Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Mon, 20 Oct 2025 14:27:16 -0700 Subject: [PATCH 06/10] update --- benchmarks/blackwell-attn-measure.sh | 2 +- benchmarks/blackwell-attn-parse.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/benchmarks/blackwell-attn-measure.sh b/benchmarks/blackwell-attn-measure.sh index 3b8c1af83..180879b66 100755 --- a/benchmarks/blackwell-attn-measure.sh +++ b/benchmarks/blackwell-attn-measure.sh @@ -41,7 +41,7 @@ for DHEAD in 64 128; do WITH_ACC=1 $root/venv-fb-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_fb-triton-acc.log done - for only in helion_attention sdpa triton_tutorial_flash_v2 triton_tutorial_flash_v2_tma flex_attention; do + for only in sdpa triton_tutorial_flash_v2 triton_tutorial_flash_v2_tma flex_attention; do $root/venv-stock-triton/bin/python benchmarks/run.py --kernel flash_attention --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log done diff --git a/benchmarks/blackwell-attn-parse.py b/benchmarks/blackwell-attn-parse.py index 47e1b9bdf..bc3304cda 100644 --- a/benchmarks/blackwell-attn-parse.py +++ b/benchmarks/blackwell-attn-parse.py @@ -15,12 +15,14 @@ break line = lines[i].replace("(", "").replace(")", ",") line = line.split(",") - if len(line) == 6: - batch, heads, seqlen, seqlen_kv, dhead, tflops = line - else: - batch, heads, heads_kv, seqlen, seqlen_kv, dhead, tflops = line - assert heads.strip() == heads_kv.strip() - print(lines) + try: + if len(line) == 6: + batch, heads, seqlen, seqlen_kv, dhead, tflops = line + else: + batch, heads, heads_kv, seqlen, seqlen_kv, dhead, tflops = line + assert heads.strip() == heads_kv.strip() + except: + continue variant = f.split("/")[-1].split(".log")[0].split("only_")[1] out.writerow( From 43e85f53fd0b6163fa7781fc9d9c5ea11e7d2803 Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Mon, 20 Oct 2025 15:15:58 -0700 Subject: [PATCH 07/10] move pin to war --- benchmarks/blackwell-attn-setup-env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/blackwell-attn-setup-env.sh b/benchmarks/blackwell-attn-setup-env.sh index 8f2ee9bdb..a7c18c96c 100755 --- a/benchmarks/blackwell-attn-setup-env.sh +++ b/benchmarks/blackwell-attn-setup-env.sh @@ -141,7 +141,7 @@ if [ $triton_kind == stock-triton ]; then fi git clone https://github.com/pytorch/helion.git || true pushd helion - git checkout 294b3db5dda69874b69a769baf6a903e67dc575f + git checkout 736d02b uv pip install --no-deps -e . 
pushd benchmarks git clone https://github.com/meta-pytorch/tritonbench.git || true From f98d1f45e4c7be0ac74fe67084e4ef142af70391 Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Mon, 20 Oct 2025 15:36:18 -0700 Subject: [PATCH 08/10] update to include stock triton --- benchmarks/blackwell-attn-measure.sh | 4 ++++ benchmarks/blackwell-attn-setup-env.sh | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/benchmarks/blackwell-attn-measure.sh b/benchmarks/blackwell-attn-measure.sh index 180879b66..3090362f3 100755 --- a/benchmarks/blackwell-attn-measure.sh +++ b/benchmarks/blackwell-attn-measure.sh @@ -37,6 +37,10 @@ for DHEAD in 64 128; do $root/venv-fb-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_fb-triton.log done + for only in cudnn_sdpa helion_blackwell_attention_tritonbench; do + $root/venv-stock-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log + done + for only in helion_blackwell_attention_tritonbench; do WITH_ACC=1 $root/venv-fb-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_fb-triton-acc.log done diff --git a/benchmarks/blackwell-attn-setup-env.sh b/benchmarks/blackwell-attn-setup-env.sh index a7c18c96c..fca38a2f9 100755 --- a/benchmarks/blackwell-attn-setup-env.sh +++ b/benchmarks/blackwell-attn-setup-env.sh @@ -141,7 +141,7 @@ if [ $triton_kind == stock-triton ]; then fi git clone https://github.com/pytorch/helion.git || true pushd helion - git checkout 736d02b + git checkout 1990fb8 uv pip install --no-deps -e . 
pushd benchmarks git clone https://github.com/meta-pytorch/tritonbench.git || true From eb390f7e3d6cfb9bdc35aa617de3a3e529ca5071 Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Mon, 20 Oct 2025 17:10:45 -0700 Subject: [PATCH 09/10] remove redundant sdpa --- benchmarks/blackwell-attn-measure.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/blackwell-attn-measure.sh b/benchmarks/blackwell-attn-measure.sh index 3090362f3..3b3ee30ef 100755 --- a/benchmarks/blackwell-attn-measure.sh +++ b/benchmarks/blackwell-attn-measure.sh @@ -37,7 +37,7 @@ for DHEAD in 64 128; do $root/venv-fb-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_fb-triton.log done - for only in cudnn_sdpa helion_blackwell_attention_tritonbench; do + for only in helion_blackwell_attention_tritonbench; do $root/venv-stock-triton/bin/python benchmarks/run.py --kernel blackwell_attentions --d-head $DHEAD --seq-len $SEQLEN --batch $BATCH --n-heads $NHEADS --metrics tflops --simple-output --rep 3000 --sleep 1.0 --num-inputs 1 --only $only --force --input-id 0 |& tee $RUNDIR/dhead_$DHEAD-seqlen_$SEQLEN-only_$only-venv_stock-triton.log done From d1d765288b714f8d2311a29b931ed5d969b55e57 Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Mon, 20 Oct 2025 20:19:50 -0700 Subject: [PATCH 10/10] add plot --- benchmarks/blackwell-attn-plot.py | 840 ++++++++++++++++++++++++++++++ 1 file changed, 840 insertions(+) create mode 100644 benchmarks/blackwell-attn-plot.py diff --git a/benchmarks/blackwell-attn-plot.py b/benchmarks/blackwell-attn-plot.py new file mode 100644 index 000000000..3b6b7056e --- /dev/null +++ b/benchmarks/blackwell-attn-plot.py @@ -0,0 +1,840 @@ +#!/usr/bin/env python3 +""" +Terminal bar plot generator from CSV files. +No dependencies required - uses only Python standard library. 
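+
+In this patch series it is meant to sit downstream of blackwell-attn-parse.py:
+point it at the data.csv that script writes (columns: batch, heads, seqlen,
+seqlen_kv, dhead, variant, tflops), e.g. --x seqlen --legend variant --value tflops.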
+ +Usage: + python plot_bars.py data.csv --x col1 --legend col2 --value col3 + python plot_bars.py data.csv --x col1,col2 --legend col3 --value col4 --colors key1=red,key2=blue + python plot_bars.py data.csv --x col1 --legend col2 --value col3 --vertical + python plot_bars.py data.csv --x col1 --legend col2 --value col3 --patterns key1=/,key2=\\ + python plot_bars.py data.csv --x col1 --legend col2 --value col3 --rename old_name=new_name +""" + +from __future__ import annotations + +import argparse +from collections import defaultdict +import csv +import sys + +# Unicode block characters for drawing bars +BLOCKS = [" ", "▏", "▎", "▍", "▌", "▋", "▊", "▉", "█"] + +# Vertical bar characters +VERTICAL_BLOCKS = [" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"] + +# Pattern characters for bars +PATTERNS = { + "/": "/", + "\\": "\\", + "x": "x", + "+": "+", + "-": "-", + "|": "|", + ".": ".", + "o": "o", + "*": "*", + "#": "#", + "=": "=", +} + +# ANSI color codes +COLORS = { + "red": "\033[91m", + "green": "\033[92m", + "yellow": "\033[93m", + "blue": "\033[94m", + "magenta": "\033[95m", + "cyan": "\033[96m", + "white": "\033[97m", + "gray": "\033[90m", + "reset": "\033[0m", +} + +# Default color palette +DEFAULT_PALETTE = ["blue", "green", "red", "yellow", "magenta", "cyan", "white", "gray"] + + +def parse_csv(filename: str) -> tuple[list[str], list[dict[str, str]]]: + """Parse CSV file and return headers and rows.""" + with open(filename) as f: + reader = csv.DictReader(f) + headers = reader.fieldnames + rows = list(reader) + return headers, rows + + +def make_tuple(row: dict[str, str], columns: list[str]) -> tuple[str, ...]: + """Create a tuple from specified columns.""" + return tuple(row[col] for col in columns) + + +def aggregate_data( + rows: list[dict[str, str]], + x_cols: list[str], + legend_cols: list[str] | None, + value_col: str, +) -> dict[tuple, dict[tuple, list[float]]]: + """ + Aggregate data by x-axis and legend keys. 
+ Returns: {x_key: {legend_key: [values]}} + """ + data = defaultdict(lambda: defaultdict(list)) + + for row in rows: + x_key = make_tuple(row, x_cols) + + if legend_cols: + legend_key = make_tuple(row, legend_cols) + else: + legend_key = ("",) + + try: + value = float(row[value_col]) + except (ValueError, KeyError): + continue + + data[x_key][legend_key].append(value) + + return data + + +def format_tuple(t: tuple[str, ...]) -> str: + """Format tuple for display.""" + if len(t) == 1: + return t[0] + return "(" + ", ".join(t) + ")" + + +def draw_bar( + value: float, max_value: float, width: int, color: str, pattern: str | None = None +) -> str: + """Draw a single bar using Unicode block characters.""" + if max_value == 0: + return "" + + # Calculate bar length in characters + ratio = value / max_value + full_blocks = int(ratio * width) + remainder = (ratio * width - full_blocks) * 8 + partial_block_idx = int(remainder) + + # Build the bar + if pattern and pattern in PATTERNS: + # Use pattern character instead of block + bar = PATTERNS[pattern] * full_blocks + if full_blocks < width and partial_block_idx > 0: + bar += BLOCKS[partial_block_idx] + else: + bar = BLOCKS[-1] * full_blocks + if full_blocks < width and partial_block_idx > 0: + bar += BLOCKS[partial_block_idx] + + # Apply color + color_code = COLORS.get(color, "") + reset_code = COLORS["reset"] if color_code else "" + + return f"{color_code}{bar}{reset_code}" + + +def apply_exclusions( + data: dict[tuple, dict[tuple, list[float]]], exclude_list: list[str] +) -> dict[tuple, dict[tuple, list[float]]]: + """Filter out excluded series from data.""" + if not exclude_list: + return data + + filtered_data = {} + for x_key, legend_dict in data.items(): + filtered_dict = {} + for legend_key, values in legend_dict.items(): + key_str = format_tuple(legend_key) + if key_str not in exclude_list: + filtered_dict[legend_key] = values + if filtered_dict: # Only keep x_key if it has data after filtering + filtered_data[x_key] = filtered_dict + + return filtered_data + + +def apply_combinations( + data: dict[tuple, dict[tuple, list[float]]], + combinations: list[tuple[str, list[str], str]], +) -> tuple[dict[tuple, dict[tuple, list[float]]], dict[tuple, str]]: + """ + Combine multiple series into one using specified aggregation. 
+ combinations: List of (new_name, [series_to_combine], aggregation_method) + Returns: (modified_data, aggregation_map) where aggregation_map maps combined series to their agg method + """ + if not combinations: + return data, {} + + import statistics + + new_data = {} + agg_map = {} # Maps combined series keys to their aggregation method + + for x_key, legend_dict in data.items(): + new_dict = dict(legend_dict) # Copy existing data + + for new_name, series_names, agg_method in combinations: + # Collect values from each series separately for proper aggregation + series_values = [] # List of lists, one per series + keys_to_remove = [] + + for legend_key, values in legend_dict.items(): + key_str = format_tuple(legend_key) + if key_str in series_names: + series_values.append(values) + keys_to_remove.append(legend_key) + + if series_values: + # Remove the original series + for key in keys_to_remove: + new_dict.pop(key, None) + + # Combine based on aggregation method + # For each x position, we need to aggregate across the series + new_key = (new_name,) + + # Get max length (some series might have different numbers of values) + max_len = max(len(vals) for vals in series_values) + + combined_values = [] + for i in range(max_len): + # Get all values at position i across all series + point_values = [] + for vals in series_values: + if i < len(vals): + point_values.append(vals[i]) + + if point_values: + if agg_method == "min": + combined_values.append(min(point_values)) + elif agg_method == "max": + combined_values.append(max(point_values)) + elif agg_method == "mean": + combined_values.append(statistics.mean(point_values)) + elif agg_method == "sum": + combined_values.append(sum(point_values)) + else: + combined_values.append(statistics.mean(point_values)) + + new_dict[new_key] = combined_values + agg_map[new_key] = agg_method + + new_data[x_key] = new_dict + + return new_data, agg_map + + +def aggregate_values( + data: dict[tuple, dict[tuple, list[float]]], agg_method: str = "sum" +) -> dict[tuple, dict[tuple, float]]: + """ + Aggregate lists of values to single values using specified method. 
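+    Supported methods: sum (the default), mean, min, max; unknown names fall back to sum.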
+ """ + import statistics + + aggregated = {} + for x_key, legend_dict in data.items(): + agg_dict = {} + for legend_key, values in legend_dict.items(): + if not values: + agg_dict[legend_key] = 0.0 + elif agg_method == "sum": + agg_dict[legend_key] = sum(values) + elif agg_method == "mean": + agg_dict[legend_key] = statistics.mean(values) + elif agg_method == "min": + agg_dict[legend_key] = min(values) + elif agg_method == "max": + agg_dict[legend_key] = max(values) + else: + agg_dict[legend_key] = sum(values) # Default to sum + aggregated[x_key] = agg_dict + + return aggregated + + +def apply_renames( + legend_keys: list[tuple], renames: dict[str, str] +) -> dict[tuple, str]: + """Create a mapping of original keys to renamed display strings.""" + rename_map = {} + for key in legend_keys: + key_str = format_tuple(key) + rename_map[key] = renames.get(key_str, key_str) + return rename_map + + +def assign_colors( + legend_keys: list[tuple], custom_colors: dict[str, str] +) -> dict[tuple, str]: + """Assign colors to legend keys.""" + color_map = {} + palette_idx = 0 + + for key in legend_keys: + key_str = format_tuple(key) + + # Check if custom color is specified + if key_str in custom_colors: + color_map[key] = custom_colors[key_str] + else: + # Auto-assign from palette + color_map[key] = DEFAULT_PALETTE[palette_idx % len(DEFAULT_PALETTE)] + palette_idx += 1 + + return color_map + + +def assign_patterns( + legend_keys: list[tuple], custom_patterns: dict[str, str] +) -> dict[tuple, str | None]: + """Assign patterns to legend keys.""" + pattern_map = {} + + for key in legend_keys: + key_str = format_tuple(key) + + # Check if custom pattern is specified + if key_str in custom_patterns: + pattern_map[key] = custom_patterns[key_str] + else: + pattern_map[key] = None + + return pattern_map + + +def plot_bars( + data: dict[tuple, dict[tuple, float]], + x_cols: list[str], + legend_cols: list[str] | None, + value_col: str, + custom_colors: dict[str, str], + custom_patterns: dict[str, str], + renames: dict[str, str], + bar_width: int = 50, + show_values: bool = True, +): + """Generate and print terminal bar plot.""" + + # Get all legend keys and assign colors, patterns, and renames + all_legend_keys = set() + for legend_dict in data.values(): + all_legend_keys.update(legend_dict.keys()) + legend_keys = sorted(all_legend_keys) + color_map = assign_colors(legend_keys, custom_colors) + pattern_map = assign_patterns(legend_keys, custom_patterns) + rename_map = apply_renames(legend_keys, renames) + + # Find max value for scaling + max_value = 0 + for legend_dict in data.values(): + for value in legend_dict.values(): + max_value = max(max_value, value) + + # Calculate label width + x_label = "+".join(x_cols) + max_x_label_len = max(len(format_tuple(k)) for k in data) + label_width = max(len(x_label), max_x_label_len) + + # Print title + print( + f"\n{value_col} by {x_label}" + + (f" (grouped by {'+'.join(legend_cols)})" if legend_cols else "") + ) + print("=" * (label_width + bar_width + 20)) + print() + + # Print legend if there are multiple series + if len(legend_keys) > 1 or (len(legend_keys) == 1 and legend_keys[0] != ("",)): + print("Legend:") + for key in legend_keys: + display_name = rename_map[key] + color = color_map[key] + pattern = pattern_map[key] + color_code = COLORS.get(color, "") + reset_code = COLORS["reset"] + # Show pattern in legend if applicable + legend_char = PATTERNS.get(pattern, "█") if pattern else "█" + print(f" {color_code}{legend_char}{reset_code} {display_name}") + print() + + 
# Sort x-axis keys + sorted_x_keys = sorted(data.keys()) + + # Print bars + for x_key in sorted_x_keys: + x_label = format_tuple(x_key) + print(f"{x_label:<{label_width}} │", end="") + + legend_dict = data[x_key] + + # If multiple legend keys, stack them horizontally with separators + if len(legend_keys) > 1: + print() + for legend_key in legend_keys: + value = legend_dict.get(legend_key, 0) + if ( + value > 0 or len(legend_keys) <= 3 + ): # Show empty bars if few categories + color = color_map[legend_key] + pattern = pattern_map[legend_key] + bar = draw_bar(value, max_value, bar_width, color, pattern) + legend_label = rename_map[legend_key] + value_str = f" {value:.2f}" if show_values else "" + print(f"{' ' * label_width} │ {bar}{value_str}") + else: + # Single series - show inline + legend_key = legend_keys[0] if legend_keys else ("",) + value = legend_dict.get(legend_key, 0) + color = color_map.get(legend_key, "white") + pattern = pattern_map.get(legend_key) + bar = draw_bar(value, max_value, bar_width, color, pattern) + value_str = f" {value:.2f}" if show_values else "" + print(f" {bar}{value_str}") + + print() + print(f"Max value: {max_value:.2f}") + print() + + +def plot_bars_vertical( + data: dict[tuple, dict[tuple, float]], + x_cols: list[str], + legend_cols: list[str] | None, + value_col: str, + custom_colors: dict[str, str], + custom_patterns: dict[str, str], + renames: dict[str, str], + bar_height: int = 20, + show_values: bool = True, +): + """Generate and print vertical terminal bar plot.""" + + # Get all legend keys and assign colors, patterns, and renames + all_legend_keys = set() + for legend_dict in data.values(): + all_legend_keys.update(legend_dict.keys()) + legend_keys = sorted(all_legend_keys) + color_map = assign_colors(legend_keys, custom_colors) + pattern_map = assign_patterns(legend_keys, custom_patterns) + rename_map = apply_renames(legend_keys, renames) + + # Find max value for scaling + max_value = 0 + for legend_dict in data.values(): + for value in legend_dict.values(): + max_value = max(max_value, value) + + # Print title + x_label = "+".join(x_cols) + print( + f"\n{value_col} by {x_label}" + + (f" (grouped by {'+'.join(legend_cols)})" if legend_cols else "") + ) + print("=" * 80) + print() + + # Print legend if there are multiple series + if len(legend_keys) > 1 or (len(legend_keys) == 1 and legend_keys[0] != ("",)): + print("Legend:") + for key in legend_keys: + display_name = rename_map[key] + color = color_map[key] + pattern = pattern_map[key] + color_code = COLORS.get(color, "") + reset_code = COLORS["reset"] + legend_char = PATTERNS.get(pattern, "█") if pattern else "█" + print(f" {color_code}{legend_char}{reset_code} {display_name}") + print() + + # Sort x-axis keys + sorted_x_keys = sorted(data.keys()) + + # Calculate column width for labels + max_label_width = max(len(format_tuple(k)) for k in sorted_x_keys) + + # For each legend key, we need a column (or group of columns if showing values) + num_x_positions = len(sorted_x_keys) + num_series = len(legend_keys) + + # Column width should accommodate both the label and the bars + # When multiple series, each series gets 1 character, plus extra spacing between groups + min_width_for_bars = num_series * 2 if num_series > 1 else 3 + col_width = max( + max_label_width, min_width_for_bars, 10 + ) # At least 10 chars for spacing + + # Print the bars from top to bottom + for row in range(bar_height, -1, -1): + line_parts = [] + for x_key in sorted_x_keys: + legend_dict = data[x_key] + + # Determine which 
legend keys to show + if len(legend_keys) == 1: + # Single series + legend_key = legend_keys[0] + value = legend_dict.get(legend_key, 0) + ratio = value / max_value if max_value > 0 else 0 + height = ratio * bar_height + + color = color_map[legend_key] + pattern = pattern_map[legend_key] + color_code = COLORS.get(color, "") + reset_code = COLORS["reset"] + + if row == 0: + # Bottom row - show x-axis label + label = format_tuple(x_key) + line_parts.append(f"{label:^{col_width}}") + elif row <= height: + # Show bar character + # Calculate which block character to use + if abs(row - height) < 1 and height % 1 > 0: + # Partial block at the top + block_idx = int((height % 1) * 8) + char = ( + VERTICAL_BLOCKS[block_idx] + if not pattern + else PATTERNS.get(pattern, "█") + ) + else: + # Full block + char = ( + VERTICAL_BLOCKS[-1] + if not pattern + else PATTERNS.get(pattern, "█") + ) + # Center the character manually to avoid color code length issues + padding = (col_width - 1) // 2 + line_parts.append( + " " * padding + + f"{color_code}{char}{reset_code}" + + " " * (col_width - 1 - padding) + ) + else: + # Empty space + line_parts.append(" " * col_width) + else: + # Multiple series - show them side by side + chars = [] + for legend_key in legend_keys: + value = legend_dict.get(legend_key, 0) + ratio = value / max_value if max_value > 0 else 0 + height = ratio * bar_height + + color = color_map[legend_key] + pattern = pattern_map[legend_key] + color_code = COLORS.get(color, "") + reset_code = COLORS["reset"] + + if row == 0: + continue # Handle labels separately + if row <= height: + if abs(row - height) < 1 and height % 1 > 0: + block_idx = int((height % 1) * 8) + char = ( + VERTICAL_BLOCKS[block_idx] + if not pattern + else PATTERNS.get(pattern, "█") + ) + else: + char = ( + VERTICAL_BLOCKS[-1] + if not pattern + else PATTERNS.get(pattern, "█") + ) + chars.append(f"{color_code}{char}{reset_code}") + else: + chars.append(" ") + + if row == 0: + label = format_tuple(x_key) + line_parts.append(f"{label:^{col_width}}") + else: + # Add spacing between bars within a group + bars_with_spacing = " ".join(chars) + # Calculate the visual width (number of actual characters, not counting color codes) + visual_width = len(legend_keys) + ( + len(legend_keys) - 1 + ) # chars + spaces between + padding = (col_width - visual_width) // 2 + line_parts.append( + " " * padding + + bars_with_spacing + + " " * (col_width - visual_width - padding) + ) + + print(" ".join(line_parts)) + + # Print values if requested + if show_values: + print() + for legend_key in legend_keys: + display_name = rename_map[legend_key] + values_line = [display_name[:col_width].ljust(col_width)] + for x_key in sorted_x_keys: + value = data[x_key].get(legend_key, 0) + values_line.append(f"{value:.2f}".center(col_width)) + print(" ".join(values_line)) + + print() + print(f"Max value: {max_value:.2f}") + print() + + +def parse_color_spec(color_spec: str) -> dict[str, str]: + """Parse color specification like 'key1=red,key2=blue'.""" + colors = {} + if not color_spec: + return colors + + for pair in color_spec.split(","): + if "=" not in pair: + continue + key, color = pair.split("=", 1) + colors[key.strip()] = color.strip() + + return colors + + +def parse_pattern_spec(pattern_spec: str) -> dict[str, str]: + """Parse pattern specification like 'key1=/,key2=\\'.""" + patterns = {} + if not pattern_spec: + return patterns + + for pair in pattern_spec.split(","): + if "=" not in pair: + continue + key, pattern = pair.split("=", 1) + 
patterns[key.strip()] = pattern.strip() + + return patterns + + +def parse_rename_spec(rename_spec: str) -> dict[str, str]: + """Parse rename specification like 'old1=new1,old2=new2'.""" + renames = {} + if not rename_spec: + return renames + + for pair in rename_spec.split(","): + if "=" not in pair: + continue + old_name, new_name = pair.split("=", 1) + renames[old_name.strip()] = new_name.strip() + + return renames + + +def parse_exclude_spec(exclude_spec: str) -> list[str]: + """Parse exclude specification like 'series1,series2,series3'.""" + if not exclude_spec: + return [] + return [name.strip() for name in exclude_spec.split(",")] + + +def parse_combine_spec(combine_spec: str) -> list[tuple[str, list[str], str]]: + """ + Parse combine specification like 'new_name=series1+series2:min,another=s3+s4:max'. + Returns: List of (new_name, [series_to_combine], aggregation_method) + """ + combinations = [] + if not combine_spec: + return combinations + + for combo in combine_spec.split(","): + if "=" not in combo: + continue + + new_name, rest = combo.split("=", 1) + new_name = new_name.strip() + + # Check if aggregation method is specified + if ":" in rest: + series_part, agg_method = rest.rsplit(":", 1) + agg_method = agg_method.strip().lower() + if agg_method not in ["min", "max", "mean", "sum"]: + print( + f"Warning: Unknown aggregation method '{agg_method}', using 'mean'", + file=sys.stderr, + ) + agg_method = "mean" + else: + series_part = rest + agg_method = "mean" # Default to mean + + # Parse series names (split by +) + series_names = [s.strip() for s in series_part.split("+")] + + if len(series_names) < 2: + print( + f"Warning: Combination '{combo}' needs at least 2 series to combine", + file=sys.stderr, + ) + continue + + combinations.append((new_name, series_names, agg_method)) + + return combinations + + +def main(): + parser = argparse.ArgumentParser( + description="Generate terminal bar plots from CSV files", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s data.csv --x model --value accuracy + %(prog)s data.csv --x model --legend dataset --value score + %(prog)s data.csv --x model,size --legend dataset --value time --colors gpt=blue,llama=red + %(prog)s data.csv --x model --value throughput --width 60 --no-values + %(prog)s data.csv --x model --legend dataset --value score --vertical + %(prog)s data.csv --x model --legend dataset --value score --patterns dataset1=/,dataset2=\\ + %(prog)s data.csv --x model --legend dataset --value score --rename old_name=New Name + %(prog)s data.csv --x model --legend dataset --value score --exclude unwanted_series + %(prog)s data.csv --x model --legend dataset --value score --combine "Best=A+B:min,Worst=C+D:max" + +Available colors: red, green, yellow, blue, magenta, cyan, white, gray +Available patterns: /, \\, x, +, -, |, ., o, *, #, = +Aggregation methods for --combine: min, max, mean, sum (default: mean) + """, + ) + + parser.add_argument("csv_file", help="CSV file to read") + parser.add_argument( + "--x", + required=True, + help="Comma-separated column(s) for x-axis (creates tuple if multiple)", + ) + parser.add_argument( + "--legend", + help="Comma-separated column(s) for legend/color grouping (creates tuple if multiple)", + ) + parser.add_argument( + "--value", required=True, help="Column containing the numeric values to plot" + ) + parser.add_argument( + "--colors", help="Custom color mapping: key1=color1,key2=color2,..." 
+ ) + parser.add_argument( + "--patterns", + help="Custom pattern mapping: key1=pattern1,key2=pattern2,... (e.g., key1=/,key2=\\)", + ) + parser.add_argument( + "--rename", help="Rename legend entries: old1=new1,old2=new2,..." + ) + parser.add_argument( + "--exclude", help="Exclude series from plot: series1,series2,..." + ) + parser.add_argument( + "--combine", + help="Combine series: new_name=series1+series2:agg_method (agg_method: min, max, mean, sum). Example: Combined=A+B:min,Average=C+D:mean", + ) + parser.add_argument( + "--width", + type=int, + default=50, + help="Width of bars in characters for horizontal bars (default: 50)", + ) + parser.add_argument( + "--height", + type=int, + default=20, + help="Height of bars in characters for vertical bars (default: 20)", + ) + parser.add_argument( + "--vertical", + action="store_true", + help="Create vertical bars instead of horizontal", + ) + parser.add_argument( + "--no-values", action="store_true", help="Hide numeric values next to bars" + ) + + args = parser.parse_args() + + # Parse column specifications + x_cols = [col.strip() for col in args.x.split(",")] + legend_cols = ( + [col.strip() for col in args.legend.split(",")] if args.legend else None + ) + value_col = args.value + custom_colors = parse_color_spec(args.colors) + custom_patterns = parse_pattern_spec(args.patterns) + renames = parse_rename_spec(args.rename) + exclude_list = parse_exclude_spec(args.exclude) + combinations = parse_combine_spec(args.combine) + + # Load and process data + try: + headers, rows = parse_csv(args.csv_file) + except FileNotFoundError: + print(f"Error: File '{args.csv_file}' not found", file=sys.stderr) + sys.exit(1) + except Exception as e: + print(f"Error reading CSV: {e}", file=sys.stderr) + sys.exit(1) + + # Validate columns + all_cols = set(x_cols + (legend_cols or []) + [value_col]) + missing_cols = all_cols - set(headers) + if missing_cols: + print( + f"Error: Columns not found in CSV: {', '.join(missing_cols)}", + file=sys.stderr, + ) + print(f"Available columns: {', '.join(headers)}", file=sys.stderr) + sys.exit(1) + + # Aggregate and process data + data = aggregate_data(rows, x_cols, legend_cols, value_col) + + # Apply exclusions + data = apply_exclusions(data, exclude_list) + + # Apply combinations (returns data and map of which series used special aggregation) + data, combination_agg_map = apply_combinations(data, combinations) + + # Aggregate values (sum is default for backward compatibility) + # Note: combined series have already been aggregated by their specified method + data = aggregate_values(data, agg_method="sum") + + if not data: + print("No data to plot", file=sys.stderr) + sys.exit(1) + + if args.vertical: + plot_bars_vertical( + data, + x_cols, + legend_cols, + value_col, + custom_colors, + custom_patterns, + renames, + bar_height=args.height, + show_values=not args.no_values, + ) + else: + plot_bars( + data, + x_cols, + legend_cols, + value_col, + custom_colors, + custom_patterns, + renames, + bar_width=args.width, + show_values=not args.no_values, + ) + + +if __name__ == "__main__": + main()