@vickiw973 (PR author):

The tests pass with the latest CuTe DSL release:

pip3 install --upgrade nvidia-cutlass-dsl --pre

For NVFP4 GEMV (using FFMA to simulate the computation logic)

reference-kernels/problems/nvidia/nvfp4_gemv> python3 eval.py test task.yml 
compile: start
compile: pass
test-count: 10
test.0.spec: m: 128; k: 256; l: 1; seed: 1111
test.0.status: pass
test.1.spec: m: 128; k: 1536; l: 1; seed: 1111
test.1.status: pass
test.2.spec: m: 128; k: 3072; l: 1; seed: 1111
test.2.status: pass
test.3.spec: m: 256; k: 7168; l: 1; seed: 1111
test.3.status: pass
test.4.spec: m: 256; k: 7168; l: 1; seed: 1111
test.4.status: pass
test.5.spec: m: 2432; k: 4608; l: 2; seed: 1111
test.5.status: pass
test.6.spec: m: 384; k: 7168; l: 2; seed: 1111
test.6.status: pass
test.7.spec: m: 512; k: 512; l: 2; seed: 1111
test.7.status: pass
test.8.spec: m: 512; k: 4096; l: 2; seed: 1111
test.8.status: pass
test.9.spec: m: 512; k: 1536; l: 2; seed: 1111
test.9.status: pass
check: pass
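
For context, "using FFMA to simulate the computation logic" means the reference math runs as plain FP32 fused multiply-adds on dequantized values rather than on the block-scaled tensor-core path. A minimal sketch of that idea, assuming the FP4 operands have already been unpacked to float32 and that scales apply per 16-element block along K (the actual packing and scale layout are defined by the problem's task.yml, so treat the shapes and names below as illustrative):

import torch

def gemv_ffma_reference(a, x, sf_a, sf_x):
    # a:    (L, M, K)       FP4 values of A, already unpacked to float32
    # x:    (L, K)          FP4 values of x, already unpacked to float32
    # sf_a: (L, M, K // 16) per-block scale factors for A
    # sf_x: (L, K // 16)    per-block scale factors for x
    a_deq = a * sf_a.repeat_interleave(16, dim=-1)  # apply block scales
    x_deq = x * sf_x.repeat_interleave(16, dim=-1)
    # Plain float32 multiply-accumulate, which is what FFMA performs per element
    return torch.einsum("lmk,lk->lm", a_deq, x_deq)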

For NVFP4 GEMM (using tensor-core)

reference-kernels/problems/nvidia/nvfp4_gemm> python3 eval.py test task.yml
compile: start
compile: pass
test-count: 10
test.0.status: pass
test.1.spec: m: 128; n: 1536; k: 7168; l: 1; seed: 1111
test.1.status: pass
test.2.spec: m: 128; n: 3072; k: 1536; l: 1; seed: 1111
test.2.status: pass
test.3.spec: m: 256; n: 7168; k: 256; l: 1; seed: 1111
test.3.status: pass
test.4.spec: m: 256; n: 7168; k: 2048; l: 1; seed: 1111
test.4.status: pass
test.5.spec: m: 2304; n: 4608; k: 7168; l: 1; seed: 1111
test.5.status: pass
test.6.spec: m: 384; n: 7168; k: 2304; l: 1; seed: 1111
test.6.status: pass
test.7.spec: m: 512; n: 512; k: 7168; l: 1; seed: 1111
test.7.status: pass
test.8.spec: m: 512; n: 4096; k: 512; l: 1; seed: 1111
test.8.status: pass
test.9.spec: m: 512; n: 1536; k: 7168; l: 1; seed: 1111
test.9.status: pass
check: pass
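
The tensor-core kernel computes the same block-scaled math directly on the FP4 operands; a hedged FP32 reference of what the output amounts to, under the same illustrative assumptions as the GEMV sketch above (pre-unpacked FP4 values, 16-wide scale blocks along K):

import torch

def nvfp4_gemm_reference(a, b, sf_a, sf_b):
    # a: (M, K), b: (N, K)                    unpacked FP4 values as float32
    # sf_a: (M, K // 16), sf_b: (N, K // 16)  block scale factors
    a_deq = a * sf_a.repeat_interleave(16, dim=1)
    b_deq = b * sf_b.repeat_interleave(16, dim=1)
    return a_deq @ b_deq.T  # (M, N), accumulated in float32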

For NVFP4 dual GEMM (using tensor-core)

reference-kernels/problems/nvidia/nvfp4_dual_gemm>  python3 eval.py test task.yml
compile: start
compile: pass
test-count: 10
test.0.spec: m: 128; n: 256; k: 256; l: 1; seed: 1111
test.0.status: pass
test.1.spec: m: 128; n: 1536; k: 7168; l: 1; seed: 1111
test.1.status: pass
test.2.spec: m: 128; n: 3072; k: 1536; l: 1; seed: 1111
test.2.status: pass
test.3.spec: m: 256; n: 7168; k: 256; l: 1; seed: 1111
test.3.status: pass
test.4.spec: m: 256; n: 7168; k: 2048; l: 1; seed: 1111
test.4.status: pass
test.5.spec: m: 2304; n: 4608; k: 7168; l: 1; seed: 1111
test.5.status: pass
test.6.spec: m: 384; n: 7168; k: 2304; l: 1; seed: 1111
test.6.status: pass
test.7.spec: m: 512; n: 512; k: 7168; l: 1; seed: 1111
test.7.status: pass
test.8.spec: m: 512; n: 4096; k: 512; l: 1; seed: 1111
test.8.status: pass
test.9.spec: m: 512; n: 1536; k: 7168; l: 1; seed: 1111
test.9.status: pass
check: pass
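
Dual GEMM shares one A operand across two B operands (B1/B2 with their own scale factors, consistent with the "TMALDG A/B1/B2/SFA/SFB1/SFB2" snippet quoted further down). A hedged reference sketch; the elementwise combination of the two products is left as a placeholder because the real epilogue is defined by the problem spec:

import torch

def nvfp4_dual_gemm_reference(a_deq, b1_deq, b2_deq, combine=torch.mul):
    # a_deq:  (M, K)  A after FP4 unpacking and block-scale application
    # b1_deq: (N, K)  first B operand, dequantized the same way
    # b2_deq: (N, K)  second B operand, dequantized the same way
    c1 = a_deq @ b1_deq.T  # (M, N)
    c2 = a_deq @ b2_deq.T  # (M, N)
    # Placeholder combine; the actual epilogue comes from task.yml
    return combine(c1, c2)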

For NVFP4 group GEMM (using tensor-core)

reference-kernels/problems/nvidia/nvfp4_group_gemm> python3 eval.py test task.yml
compile: start
compile: pass
test-count: 10
test.0.spec: m: 128; n: 256; k: 512; g: 8; seed: 1111
test.0.status: pass
test.1.spec: m: 128; n: 256; k: 512; g: 2; seed: 1111
test.1.status: pass
test.2.spec: m: 128; n: 384; k: 640; g: 3; seed: 1111
test.2.status: pass
test.3.spec: m: 256; n: 384; k: 640; g: 4; seed: 1111
test.3.status: pass
test.4.spec: m: 256; n: 512; k: 384; g: 2; seed: 1111
test.4.status: pass
test.5.spec: m: 384; n: 512; k: 384; g: 2; seed: 1111
test.5.status: pass
test.6.spec: m: 384; n: 640; k: 512; g: 2; seed: 1111
test.6.status: pass
test.7.spec: m: 256; n: 640; k: 128; g: 8; seed: 1111
test.7.status: pass
test.8.spec: m: 512; n: 768; k: 256; g: 5; seed: 1111
test.8.status: pass
test.9.spec: m: 512; n: 768; k: 768; g: 3; seed: 1111
test.9.status: pass
check: pass
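
Group GEMM launches g independent GEMMs from one kernel; in general each group can carry its own problem shape, which is what the last review comment below asks the test generator to randomize. A minimal reference sketch over per-group operands (illustrative names, dequantized inputs assumed):

def nvfp4_group_gemm_reference(groups):
    # groups: list of (a_deq, b_deq) pairs, one per group, each already
    #         unpacked from FP4 and block-scaled; shapes (M_g, K_g) and (N_g, K_g)
    return [a @ b.T for a, b in groups]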

Inline self-review comments on the diff:

# Slice to per mma tile index
#
# ((atom_v, rest_v), RestK)
tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])]

Rename this to tAgA_mkl.
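
One possible reading of this suggestion, following the usual CUTLASS convention of keeping the full partition under an _mkl-suffixed name and assigning the per-tile slice to the short name (illustrative, not necessarily the intended rename):

# ((atom_v, rest_v), RestK): slice to the per-MMA-tile index
tAgA = tAgA_mkl[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])]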

# Wait for AB buffer empty
ab_empty = ab_producer.acquire_and_advance()

# TMALDG A/B1/B2/SFA/SFB1/SFB2

Change "TMALDG" in this comment to "TMA load".

tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32
)
# (T2R_M, T2R_N, EPI_M, EPI_N)
tTR_rAcc2 = cute.make_fragment(

Replace this with make_rmem_tensor_like.
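
A hedged sketch of what the replacement could look like, assuming cute.make_rmem_tensor_like accepts a source tensor plus an element type in the same way the existing cute.make_fragment call takes a shape and a dtype; the exact signature should be checked against the installed CuTe DSL release:

# (T2R_M, T2R_N, EPI_M, EPI_N)
# Assumed signature: make_rmem_tensor_like(src_tensor, element_type)
tTR_rAcc2 = cute.make_rmem_tensor_like(
    tTR_gC[None, None, None, None, 0, 0, 0], cutlass.Float32
)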


# Release tensor memory allocation lock
if warp_idx == 0:
    cute.arch.relinquish_tmem_alloc_permit()

Remove this.

m, n, k, l = problem_size

# Setup attributes that depend on gemm inputs
cta_tile_shape_mnk = (

Remove cta_tile_shape_mnk and use mma_tiler_mnk directly.

)
atom_thr_size = cute.size(tiled_mma.thr_id.shape)

# TMA load for A

Replace this comment with "Setup TMA for A".

key_size_b = lambda item: item[1][1] * item[1][2]
key_size_c = lambda item: item[1][0] * item[1][1]
# Find the indices of the groups with the smallest tensor sizes
min_a_idx, _ = min(enumerate(problem_sizes), key=key_size_a)

Remove this logic and use random shapes instead.
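
A hedged sketch of the suggested replacement: draw per-group shapes from a seeded RNG instead of picking the smallest groups. The dimension range and 128-alignment below are illustrative assumptions, not the benchmark's actual constraints:

import random

def random_group_shapes(num_groups, seed, align=128):
    # Seeded RNG keeps test runs reproducible while still varying per-group shapes
    rng = random.Random(seed)
    return [
        (rng.randrange(align, 1024 + 1, align),   # m
         rng.randrange(align, 1024 + 1, align),   # n
         rng.randrange(align, 1024 + 1, align))   # k
        for _ in range(num_groups)
    ]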
