Skip to content

[BUG] T.atomic_add cannot compile in some cases #1885

@ColmaLiu

Description

@ColmaLiu

Required prerequisites

What version of TileLang are you using?

0.1.8+cuda.git2a85f09d

System information

3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] linux
0.1.8+cuda.git2a85f09d
2.10.0+cu128

Problem description

T.atomic_add fails to compile when generating AtomicAddx2, and the generated kernel code does not perform type casting.

Reproducible example code

The Python snippets:

import tilelang
from tilelang import language as T

N = 2

@tilelang.jit()
def tl_matmul_streamk():
    C = T.empty((N,), T.float16)
    with T.Kernel(threads=1):
        C_local = T.alloc_fragment((N,), T.float32)
        for i in T.Parallel(N):
            T.atomic_add(C[i], C_local[i])
    return C

def main():
    C = tl_matmul_streamk()
    print(C)
    print(tl_matmul_streamk.get_kernel_source())
main()

Traceback

Traceback (most recent call last):
  File "/home/lyn/workspace/tilelang/examples/gemm_streamk/minimized.py", line 19, in <module>
    main()
  File "/home/lyn/workspace/tilelang/examples/gemm_streamk/minimized.py", line 16, in main
    C = tl_matmul_streamk()
        ^^^^^^^^^^^^^^^^^^^
  File "/home/lyn/workspace/tilelang/tilelang/jit/__init__.py", line 448, in __call__
    kernel = self.compile(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lyn/workspace/tilelang/tilelang/jit/__init__.py", line 378, in compile
    kernel_result = compile(
                    ^^^^^^^^
  File "/home/lyn/workspace/tilelang/tilelang/jit/__init__.py", line 98, in compile
    return cached(
           ^^^^^^^
  File "/home/lyn/workspace/tilelang/tilelang/cache/__init__.py", line 74, in cached
    return _dispatch_map[execution_backend].cached(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lyn/workspace/tilelang/tilelang/cache/kernel_cache.py", line 264, in cached
    kernel = JITKernel(
             ^^^^^^^^^^
  File "/home/lyn/workspace/tilelang/tilelang/jit/kernel.py", line 137, in __init__
    adapter = self._compile_and_create_adapter(func, out_idx)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lyn/workspace/tilelang/tilelang/jit/kernel.py", line 241, in _compile_and_create_adapter
    artifact = tilelang.lower(
               ^^^^^^^^^^^^^^^
  File "/home/lyn/workspace/tilelang/tilelang/engine/lower.py", line 269, in lower
    codegen_mod = device_codegen(device_mod, target) if enable_device_compile else device_codegen_without_compile(device_mod, target)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lyn/workspace/tilelang/tilelang/engine/lower.py", line 188, in device_codegen
    device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
  File "<unknown>", line 0, in tvm::codegen::BuildTileLangCUDA(tvm::IRModule, tvm::Target)
  File "python/tvm_ffi/cython/function.pxi", line 1077, in tvm_ffi.core.tvm_ffi_callback
  File "/home/lyn/workspace/tilelang/tilelang/engine/lower.py", line 103, in tilelang_callback_cuda_compile
    ptx = nvcc.compile_cuda(

  File "/home/lyn/workspace/tilelang/tilelang/contrib/nvcc.py", line 114, in compile_cuda
    raise RuntimeError(msg)

RuntimeError: #include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif

extern "C" __global__ void tl_matmul_streamk_kernel(half_t* __restrict__ C);
extern "C" __global__ void tl_matmul_streamk_kernel(half_t* __restrict__ C) {
  float C_local[2];
  AtomicAddx2((&(C[0])), *(float2*)(C_local + 0));
}


Compilation error:
/home/lyn/workspace/tilelang/3rdparty/../src/tl_templates/cuda/./instruction/../common.h(554): warning #20012-D: __device__ annotation is ignored on a function("float_e4m3_t") that is explicitly defaulted on its first declaration
    __inline__ __attribute__((always_inline)) __attribute__((device)) __attribute__((host))
                                                             ^

Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"

/home/lyn/workspace/tilelang/3rdparty/../src/tl_templates/cuda/./instruction/../common.h(554): warning #20012-D: __host__ annotation is ignored on a function("float_e4m3_t") that is explicitly defaulted on its first declaration
    __inline__ __attribute__((always_inline)) __attribute__((device)) __attribute__((host))
                                                                                     ^

/home/lyn/workspace/tilelang/3rdparty/../src/tl_templates/cuda/./instruction/../common.h(568): warning #20012-D: __device__ annotation is ignored on a function("float_e5m2_t") that is explicitly defaulted on its first declaration
    __inline__ __attribute__((always_inline)) __attribute__((device)) __attribute__((host))
                                                             ^

/home/lyn/workspace/tilelang/3rdparty/../src/tl_templates/cuda/./instruction/../common.h(568): warning #20012-D: __host__ annotation is ignored on a function("float_e5m2_t") that is explicitly defaulted on its first declaration
    __inline__ __attribute__((always_inline)) __attribute__((device)) __attribute__((host))
                                                                                     ^

/tmp/tmpbq6ygcys/tvm_kernels.cu(14): error: no instance of overloaded function "AtomicAddx2" matches the argument list
            argument types are: (cutlass::half_t *, float2)
    AtomicAddx2((&(C[0])), *(float2*)(C_local + 0));
    ^

1 error detected in the compilation of "/tmp/tmpbq6ygcys/tvm_kernels.cu".

Command: /usr/local/cuda/bin/nvcc --cubin -O3 -lineinfo -arch=sm_90a -std=c++17 -I/home/lyn/workspace/tilelang/3rdparty/../src -I/home/lyn/workspace/tilelang/3rdparty/cutlass/include -o /tmp/tmpbq6ygcys/tvm_kernels.cubin /tmp/tmpbq6ygcys/tvm_kernels.cu

Expected behavior

I think it should generate code like below:

#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif

extern "C" __global__ void tl_matmul_streamk_kernel(half_t* __restrict__ C);
extern "C" __global__ void tl_matmul_streamk_kernel(half_t* __restrict__ C) {
  float C_local[2];
  half_t C_local_cast[2];
  uint1 __1;
  float2 v_ = *(float2*)(C_local + 0);
  ((half2*)(&__1))[0] = __float22half2_rn(((float2*)(&v_))[0]);
  *(uint1*)(C_local_cast + 0) = __1;
  AtomicAddx2((&(C[0])), (uint1*)(C_local_cast + 0));
}

Additional context

No response

Metadata

Metadata

Assignees

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions