diff --git a/.gitignore b/.gitignore index 10550772..6a343dc6 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,5 @@ web/**/public tmp/ AGENTS.md + +.env diff --git a/examples/ffi/Example-FlashInfer-Trace/definitions/gemm_n4096_k4096.json b/examples/ffi/Example-FlashInfer-Trace/definitions/gemm_n4096_k4096.json new file mode 100644 index 00000000..b3b4f616 --- /dev/null +++ b/examples/ffi/Example-FlashInfer-Trace/definitions/gemm_n4096_k4096.json @@ -0,0 +1,48 @@ +{ + "name": "gemm_n4096_k4096", + "description": "General matrix multiply (GEMM) C = A @ B.T. Captured from Llama 3.1 8B attn.o_proj.", + "op_type": "gemm", + "tags": [ + "status:verified", + "model:llama-3.1-8b" + ], + "axes": { + "M": { + "type": "var" + }, + "N": { + "type": "const", + "value": 4096 + }, + "K": { + "type": "const", + "value": 4096 + } + }, + "inputs": { + "A": { + "shape": [ + "M", + "K" + ], + "dtype": "float16" + }, + "B": { + "shape": [ + "N", + "K" + ], + "dtype": "float16" + } + }, + "outputs": { + "C": { + "shape": [ + "M", + "N" + ], + "dtype": "float16" + } + }, + "reference": "import torch\n\ndef run(A, B):\n C = torch.matmul(A, B.T)\n return C" +} diff --git a/examples/ffi/Example-FlashInfer-Trace/solutions/example_agent_solution.json b/examples/ffi/Example-FlashInfer-Trace/solutions/example_agent_solution.json new file mode 100644 index 00000000..da154e15 --- /dev/null +++ b/examples/ffi/Example-FlashInfer-Trace/solutions/example_agent_solution.json @@ -0,0 +1,24 @@ +{ + "name": "example_agent_solution", + "definition": "gemm_n4096_k4096", + "description": "example agent vibecoded kernel generated by gpt-5-2025-08-07 (reasoning effort: high)", + "author": "gpt-5-2025-08-07", + "spec": { + "language": "cuda", + "target_hardware": [ + "B200" + ], + "entry_point": "kernel.cu::gemm_n_4096_k_4096", + "dependencies": [] + }, + "sources": [ + { + "path": "kernel.h", + "content": "#ifndef GEMM_N_4096_K_4096_KERNEL_H\n#define GEMM_N_4096_K_4096_KERNEL_H\n\n#include \n#include \n#include \n\n// Constants fixed by specification\nconstexpr int GEMM_N_CONST = 4096;\nconstexpr int GEMM_K_CONST = 4096;\n\n// Utility: ceiling division\ninline int ceil_div(int a, int b) { return (a + b - 1) / b; }\n\n#endif // GEMM_N_4096_K_4096_KERNEL_H" + }, + { + "path": "kernel.cu", + "content": "#include \"kernel.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace nvcuda;\n\n// Error check macro\n#ifndef CUDA_CHECK\n#define CUDA_CHECK(expr) \\\n do { \\\n cudaError_t _err = (expr); \\\n if (_err != cudaSuccess) { \\\n fprintf(stderr, \"CUDA Error %s at %s:%d: %s\\n\", #expr, __FILE__, __LINE__, cudaGetErrorString(_err)); \\\n abort(); \\\n } \\\n } while (0)\n#endif\n\n// Kernel configuration tuned for B200\n// - Block tile: 128 x 256 (M x N)\n// - K tile: 64\n// - 8 warps per block (256 threads), each warp computes a 64x64 sub-tile via WMMA (4x4 tiles of 16x16)\n// - Accumulate in FP32, convert to FP16 on store\nconstexpr int BLOCK_M = 128;\nconstexpr int BLOCK_N = 256;\nconstexpr int BLOCK_K = 64;\n\nconstexpr int WARPS_PER_BLOCK = 8;\nconstexpr int THREADS_PER_BLOCK = WARPS_PER_BLOCK * 32;\n\nconstexpr int WARP_TILE_M = 64;\nconstexpr int WARP_TILE_N = 64;\n\nconstexpr int WMMA_M = 16;\nconstexpr int WMMA_N = 16;\nconstexpr int WMMA_K = 16;\n\n// Padding to avoid shared memory bank conflicts (in elements)\nconstexpr int SKEW_HALF = 8; // for half elements\nconstexpr int SKEW_FLOAT = 8; // for float elements\n\n// Align pointer p up to 'alignment' bytes\n__device__ __forceinline__ 
char* align_up(char* p, size_t alignment) {\n uintptr_t ip = reinterpret_cast(p);\n ip = (ip + (alignment - 1)) & ~(alignment - 1);\n return reinterpret_cast(ip);\n}\n\n__global__ __launch_bounds__(THREADS_PER_BLOCK, 2)\nvoid gemm_n_4096_k_4096_kernel(const __half* __restrict__ A,\n const __half* __restrict__ B,\n __half* __restrict__ C,\n int M) {\n // Shared memory layout (dynamically allocated):\n // [A_smem (half) | B_smem (half) | C_smem (float)]\n extern __shared__ char smem_raw[];\n char* smem_ptr = smem_raw;\n\n // Compute sizes\n const int A_smem_elems = BLOCK_M * (BLOCK_K + SKEW_HALF);\n const int B_smem_elems = BLOCK_N * (BLOCK_K + SKEW_HALF);\n const int C_smem_elems = BLOCK_M * (BLOCK_N + SKEW_FLOAT);\n\n const size_t A_smem_bytes = A_smem_elems * sizeof(__half);\n const size_t B_smem_bytes = B_smem_elems * sizeof(__half);\n const size_t C_smem_bytes = C_smem_elems * sizeof(float);\n\n __half* A_smem = reinterpret_cast<__half*>(smem_ptr);\n smem_ptr = align_up(smem_ptr + A_smem_bytes, 16);\n __half* B_smem = reinterpret_cast<__half*>(smem_ptr);\n smem_ptr = align_up(smem_ptr + B_smem_bytes, 16);\n float* C_smem = reinterpret_cast(smem_ptr);\n\n // Block coordinates\n const int block_m = blockIdx.y; // along M\n const int block_n = blockIdx.x; // along N\n const int m0 = block_m * BLOCK_M;\n const int n0 = block_n * BLOCK_N;\n\n // Early exit if out of range (shouldn't happen due to gridDim.y, but guard anyway)\n if (m0 >= M) return;\n\n // Global strides (row-major)\n const int lda = GEMM_K_CONST; // 4096\n const int ldb = GEMM_K_CONST; // 4096\n const int ldc = GEMM_N_CONST; // 4096\n\n // Thread identifiers\n const int tid = threadIdx.x;\n const int warp_id = tid / 32;\n const int lane_id = tid % 32;\n\n // Warp tile coordinates within the block\n const int WARPS_N = BLOCK_N / WARP_TILE_N; // 256/64 = 4\n const int warp_m_tile = warp_id / WARPS_N; // 0..1\n const int warp_n_tile = warp_id % WARPS_N; // 0..3\n\n // Initialize accumulators\n wmma::fragment c_frag[WARP_TILE_M / WMMA_M][WARP_TILE_N / WMMA_N];\n#pragma unroll\n for (int i = 0; i < (WARP_TILE_M / WMMA_M); ++i) {\n#pragma unroll\n for (int j = 0; j < (WARP_TILE_N / WMMA_N); ++j) {\n wmma::fill_fragment(c_frag[i][j], 0.0f);\n }\n }\n\n // Loop over K dimension in tiles of BLOCK_K\n for (int k0 = 0; k0 < GEMM_K_CONST; k0 += BLOCK_K) {\n\n // Load A tile into shared memory: [BLOCK_M x BLOCK_K] with stride (BLOCK_K + SKEW_HALF)\n {\n const int total_vec = (BLOCK_M * BLOCK_K) / 8; // 1024\n#pragma unroll\n for (int v = 0; v < (total_vec / THREADS_PER_BLOCK); ++v) {\n const int vec_idx = tid + v * THREADS_PER_BLOCK;\n const int elem_idx = vec_idx * 8;\n const int row = elem_idx / BLOCK_K;\n const int col = elem_idx % BLOCK_K;\n const int g_row = m0 + row;\n const int g_col = k0 + col;\n\n const __half* gptr = A + g_row * lda + g_col;\n int4 data;\n\n if (g_row < M) {\n data = *reinterpret_cast(gptr);\n } else {\n data = {0, 0, 0, 0};\n }\n\n __half* sptr = A_smem + row * (BLOCK_K + SKEW_HALF) + col;\n *reinterpret_cast(sptr) = data;\n }\n }\n\n // Load B tile into shared memory as [BLOCK_N x BLOCK_K] row-major with stride (BLOCK_K + SKEW_HALF)\n {\n const int total_vec = (BLOCK_N * BLOCK_K) / 8; // 2048\n#pragma unroll\n for (int v = 0; v < (total_vec / THREADS_PER_BLOCK); ++v) {\n const int vec_idx = tid + v * THREADS_PER_BLOCK;\n const int elem_idx = vec_idx * 8;\n const int n = elem_idx / BLOCK_K;\n const int kk = elem_idx % BLOCK_K;\n\n const __half* gptr = B + (n0 + n) * ldb + (k0 + kk);\n int4 data = 
*reinterpret_cast(gptr);\n\n __half* sptr = B_smem + n * (BLOCK_K + SKEW_HALF) + kk;\n *reinterpret_cast(sptr) = data;\n }\n }\n\n __syncthreads();\n\n // Compute using WMMA over BLOCK_K split into 16-wide k-steps\n#pragma unroll\n for (int kk = 0; kk < BLOCK_K; kk += WMMA_K) {\n // Preload 4 B fragments for this warp (across N within the warp tile)\n wmma::fragment b_frag[WARP_TILE_N / WMMA_N];\n#pragma unroll\n for (int j = 0; j < (WARP_TILE_N / WMMA_N); ++j) {\n const int n_off = warp_n_tile * WARP_TILE_N + j * WMMA_N;\n const __half* b_tile_ptr = B_smem + n_off * (BLOCK_K + SKEW_HALF) + kk;\n wmma::load_matrix_sync(b_frag[j], b_tile_ptr, (BLOCK_K + SKEW_HALF));\n }\n\n // For each of 4 A subtiles in M within the warp tile, multiply with 4 B fragments\n#pragma unroll\n for (int i = 0; i < (WARP_TILE_M / WMMA_M); ++i) {\n const int m_off = warp_m_tile * WARP_TILE_M + i * WMMA_M;\n const __half* a_tile_ptr = A_smem + m_off * (BLOCK_K + SKEW_HALF) + kk;\n\n wmma::fragment a_frag;\n wmma::load_matrix_sync(a_frag, a_tile_ptr, (BLOCK_K + SKEW_HALF));\n\n#pragma unroll\n for (int j = 0; j < (WARP_TILE_N / WMMA_N); ++j) {\n wmma::mma_sync(c_frag[i][j], a_frag, b_frag[j], c_frag[i][j]);\n }\n }\n }\n\n __syncthreads();\n }\n\n // Store accumulators to shared C_smem (float), then cooperatively convert/store to global as half\n#pragma unroll\n for (int i = 0; i < (WARP_TILE_M / WMMA_M); ++i) {\n#pragma unroll\n for (int j = 0; j < (WARP_TILE_N / WMMA_N); ++j) {\n const int row = warp_m_tile * WARP_TILE_M + i * WMMA_M;\n const int col = warp_n_tile * WARP_TILE_N + j * WMMA_N;\n float* c_tile_ptr = C_smem + row * (BLOCK_N + SKEW_FLOAT) + col;\n wmma::store_matrix_sync(c_tile_ptr, c_frag[i][j], (BLOCK_N + SKEW_FLOAT), wmma::mem_row_major);\n }\n }\n\n __syncthreads();\n\n // Cooperative conversion and store to global memory\n const int total_elems = BLOCK_M * BLOCK_N; // 32768\n#pragma unroll 4\n for (int idx = tid; idx < total_elems; idx += THREADS_PER_BLOCK) {\n const int row = idx / BLOCK_N;\n const int col = idx % BLOCK_N;\n const int g_row = m0 + row;\n const int g_col = n0 + col;\n\n if (g_row < M) {\n float val = C_smem[row * (BLOCK_N + SKEW_FLOAT) + col];\n __half h = __float2half_rn(val);\n C[g_row * ldc + g_col] = h;\n }\n }\n}\n\n// TVM FFI binding function\nvoid gemm_n_4096_k_4096(tvm::ffi::TensorView A, tvm::ffi::TensorView B, tvm::ffi::TensorView C) {\n // Validate inputs\n TVM_FFI_ICHECK_EQ(A.ndim(), 2) << \"A must be 2D [M, 4096]\";\n TVM_FFI_ICHECK_EQ(B.ndim(), 2) << \"B must be 2D [4096, 4096]\";\n TVM_FFI_ICHECK_EQ(C.ndim(), 2) << \"C must be 2D [M, 4096]\";\n \n TVM_FFI_ICHECK_EQ(A.size(1), GEMM_K_CONST) << \"A.shape[1] must be 4096 (K)\";\n TVM_FFI_ICHECK_EQ(B.size(0), GEMM_N_CONST) << \"B.shape[0] must be 4096 (N)\";\n TVM_FFI_ICHECK_EQ(B.size(1), GEMM_K_CONST) << \"B.shape[1] must be 4096 (K)\";\n \n const int64_t M = A.size(0);\n TVM_FFI_ICHECK_EQ(C.size(0), M) << \"C.shape[0] must match A.shape[0]\";\n TVM_FFI_ICHECK_EQ(C.size(1), GEMM_N_CONST) << \"C.shape[1] must be 4096 (N)\";\n \n // Check dtype\n DLDataType dt_a = A.dtype();\n DLDataType dt_b = B.dtype();\n DLDataType dt_c = C.dtype();\n \n if (dt_a.code != kDLFloat || dt_a.bits != 16) {\n TVM_FFI_THROW(TypeError) << \"A must be float16\";\n }\n if (dt_b.code != kDLFloat || dt_b.bits != 16) {\n TVM_FFI_THROW(TypeError) << \"B must be float16\";\n }\n if (dt_c.code != kDLFloat || dt_c.bits != 16) {\n TVM_FFI_THROW(TypeError) << \"C must be float16\";\n }\n \n // Check contiguous\n TVM_FFI_ICHECK(A.IsContiguous()) << \"A 
must be contiguous\";\n TVM_FFI_ICHECK(B.IsContiguous()) << \"B must be contiguous\";\n TVM_FFI_ICHECK(C.IsContiguous()) << \"C must be contiguous\";\n \n // Check device\n DLDevice dev = A.device();\n TVM_FFI_ICHECK_EQ(dev.device_type, kDLCUDA) << \"Tensors must be on CUDA device\";\n TVM_FFI_ICHECK_EQ(B.device().device_type, kDLCUDA) << \"Tensors must be on CUDA device\";\n TVM_FFI_ICHECK_EQ(C.device().device_type, kDLCUDA) << \"Tensors must be on CUDA device\";\n \n if (M <= 0) return;\n \n // Get data pointers\n const __half* A_ptr = reinterpret_cast(A.data_ptr());\n const __half* B_ptr = reinterpret_cast(B.data_ptr());\n __half* C_ptr = reinterpret_cast<__half*>(C.data_ptr());\n \n // Get CUDA stream from environment\n cudaStream_t stream = static_cast(\n TVMFFIEnvGetStream(dev.device_type, dev.device_id));\n \n // Launch configuration\n dim3 block(THREADS_PER_BLOCK, 1, 1);\n dim3 grid(GEMM_N_CONST / BLOCK_N, ceil_div(static_cast(M), BLOCK_M), 1);\n \n // Dynamic shared memory size\n const int A_smem_elems = BLOCK_M * (BLOCK_K + SKEW_HALF);\n const int B_smem_elems = BLOCK_N * (BLOCK_K + SKEW_HALF);\n const int C_smem_elems = BLOCK_M * (BLOCK_N + SKEW_FLOAT);\n \n const size_t shmem_bytes =\n A_smem_elems * sizeof(__half) +\n B_smem_elems * sizeof(__half) +\n C_smem_elems * sizeof(float);\n \n // Opt-in to large dynamic shared memory if needed\n CUDA_CHECK(cudaFuncSetAttribute(gemm_n_4096_k_4096_kernel,\n cudaFuncAttributeMaxDynamicSharedMemorySize,\n (int)shmem_bytes));\n \n gemm_n_4096_k_4096_kernel<<>>(A_ptr, B_ptr, C_ptr, static_cast(M));\n CUDA_CHECK(cudaGetLastError());\n}\n\n// Export the function with TVM FFI\nTVM_FFI_DLL_EXPORT_TYPED_FUNC(gemm_n_4096_k_4096, gemm_n_4096_k_4096);" + } + ] +} diff --git a/examples/ffi/Example-FlashInfer-Trace/workloads/gemm_n4096_k4096.jsonl b/examples/ffi/Example-FlashInfer-Trace/workloads/gemm_n4096_k4096.jsonl new file mode 100644 index 00000000..eb8cfa0a --- /dev/null +++ b/examples/ffi/Example-FlashInfer-Trace/workloads/gemm_n4096_k4096.jsonl @@ -0,0 +1,43 @@ +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "280860e6-08f0-427c-b7c5-9cffcfab1a10", "axes": {"M": 256}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "2e90109a-282e-484f-b94b-61f49e72fde2", "axes": {"M": 248}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "e7c939ae-2083-4b6f-a51a-8c76ffd08926", "axes": {"M": 240}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "3ab6479f-c1d0-4743-b2d4-1a46b01f9db7", "axes": {"M": 232}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "29581017-6470-4d78-9d02-554adecd9822", "axes": {"M": 224}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "e97cc8b3-9a2f-4d0d-aaa9-7522413c78da", "axes": {"M": 216}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "c69f1da6-7c62-46f5-867e-3cf5ed3aac04", "axes": {"M": 208}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, 
"evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "38a98eb2-3a86-41ce-994c-b6a2cec932a6", "axes": {"M": 200}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "62f7844c-b1f5-4e08-a057-70e55f092931", "axes": {"M": 192}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "1342c570-b505-4ccf-9fc2-b377d25b397e", "axes": {"M": 184}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "beb5da20-954c-44ef-8f1b-35a5dd848b06", "axes": {"M": 176}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "c8c9178e-65f2-4124-a322-e66987aa1b34", "axes": {"M": 168}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "c103a7db-34d0-4bc0-abb0-2833c2458c50", "axes": {"M": 160}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "ed0bbc00-57e9-46e0-af02-249a64a46fa0", "axes": {"M": 152}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "83deae78-6557-46e9-b3ef-2ed254192d13", "axes": {"M": 144}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "76fde2b6-3f6e-484f-bcd4-fca79272a690", "axes": {"M": 136}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "29ebd771-0f1c-4894-8532-7265275a02b1", "axes": {"M": 128}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "73dd121d-72c7-4dda-9b6e-37d55c6ed867", "axes": {"M": 120}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "27fa4c34-6e08-459a-a8c4-f37b7cdb037b", "axes": {"M": 112}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "faac134d-36e3-4f99-809c-9e544ea5216f", "axes": {"M": 104}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "1948bb55-6253-4b9c-aa17-4fb13ffef7d0", "axes": {"M": 96}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "3a7e6db4-4127-45c1-9e54-c8c7cc25d632", "axes": {"M": 88}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "dcb04a7e-2faa-4858-a495-71c658b299ad", "axes": {"M": 80}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": 
null, "workload": {"uuid": "ef996a93-b3b6-4702-aae2-f28fdbcfdd48", "axes": {"M": 72}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "67d4c8f3-2ff5-4838-8f07-d5d16f602eb3", "axes": {"M": 64}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "54d34708-309d-462b-829a-74c90243093c", "axes": {"M": 56}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "59ca23f5-a523-4cc4-9c1f-db510753d3f4", "axes": {"M": 48}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "54062a8b-a9ca-47d8-b5ef-7f6f0325ef39", "axes": {"M": 40}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "5230e6ed-48b8-4765-bc9f-a7cdaabed615", "axes": {"M": 32}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "73212638-6584-476b-848d-2cb8ce0b829c", "axes": {"M": 24}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "404d21b1-2237-4e3c-b3ff-9b68878e5d70", "axes": {"M": 16}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "4cdab8cd-cb6b-4e73-8fe0-75b55fd784b5", "axes": {"M": 8}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "f439da26-2483-406c-977b-be185901207f", "axes": {"M": 4}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "6c2f4ba8-94d3-4e8f-997f-b7454242695a", "axes": {"M": 2}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "e39649a6-6f42-4a1b-9731-b45a9a87f7a5", "axes": {"M": 1}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "094ef833-829f-4efa-925d-d5bae9d6a116", "axes": {"M": 7}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "4c6bdefd-dd94-48b6-be3e-7eb25658eefd", "axes": {"M": 35}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "da2a2234-f5e9-4332-b62d-39865128153c", "axes": {"M": 972}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "897b6544-56c2-4d96-98fd-453ae3418e4b", "axes": {"M": 70}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "8a8311fa-8bb6-487d-8a36-7378e9680df8", "axes": {"M": 
2053}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "b626104c-94cc-436b-9d2c-1d31432c1a87", "axes": {"M": 8192}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "547e0ce5-e484-4e0c-8b38-f153fb7ce6d4", "axes": {"M": 2379}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} +{"definition": "gemm_n4096_k4096", "solution": null, "workload": {"uuid": "339de815-896d-4d4c-8060-07208d559276", "axes": {"M": 15}, "inputs": {"A": {"type": "random"}, "B": {"type": "random"}}}, "evaluation": null} diff --git a/examples/ffi/Makefile b/examples/ffi/Makefile new file mode 100644 index 00000000..8da5a200 --- /dev/null +++ b/examples/ffi/Makefile @@ -0,0 +1,103 @@ +# Makefile for building the C++ TVM-FFI example +# +# Usage: +# make # Build the example +# make run # Build and run the example +# make clean # Clean build artifacts + +# Compiler settings +CXX := g++ +NVCC := nvcc +CXXFLAGS := -std=c++17 -O3 -Wall + +# TVM-FFI paths +# You can override these by setting environment variables: +# make TVM_FFI_HOME=/path/to/tvm-ffi +# Or specify custom paths directly: +# make TVM_FFI_INCLUDE=/path/to/include TVM_FFI_LIB=/path/to/lib +PYTHON ?= python3 + +TVM_FFI_PKG := $(shell $(PYTHON) -c "import tvm_ffi; print(tvm_ffi.__path__[0])" 2>/dev/null) + +ifeq ($(TVM_FFI_PKG),) + $(warning tvm_ffi not found via Python, falling back to /usr/local) + TVM_FFI_PKG := /usr/local + TVM_FFI_INCLUDE ?= $(TVM_FFI_PKG)/include + TVM_FFI_LIB ?= $(TVM_FFI_PKG)/lib + DLPACK_INCLUDE ?= $(TVM_FFI_PKG)/include +else + TVM_FFI_INCLUDE ?= $(TVM_FFI_PKG)/include + TVM_FFI_LIB ?= $(TVM_FFI_PKG)/lib + DLPACK_INCLUDE ?= $(TVM_FFI_PKG)/../3rdparty/dlpack/include +endif + +# CUDA paths +CUDA_HOME ?= /usr/local/cuda +CUDA_INCLUDE := $(CUDA_HOME)/include +CUDA_LIB := $(CUDA_HOME)/lib64 + +# Include and library flags +INCLUDES := -I$(TVM_FFI_INCLUDE) -I$(DLPACK_INCLUDE) -I$(CUDA_INCLUDE) +LDFLAGS := -L$(TVM_FFI_LIB) -L$(CUDA_LIB) +LIBS := -ltvm_ffi -lcuda -lcudart +RPATH := -Wl,-rpath=$(TVM_FFI_LIB) -Wl,-rpath=$(CUDA_LIB) + +TARGET := cpp_example +SOURCE := cpp_example.cc + +.PHONY: all run clean help check + +all: $(TARGET) + +check: + @echo "Detected Configuration:" + @echo " Python: $(PYTHON)" + @echo " TVM-FFI Package: $(TVM_FFI_PKG)" + @echo " TVM-FFI Include: $(TVM_FFI_INCLUDE)" + @echo " TVM-FFI Library: $(TVM_FFI_LIB)" + @echo " DLPack Include: $(DLPACK_INCLUDE)" + @echo " CUDA Home: $(CUDA_HOME)" + @echo " CUDA Include: $(CUDA_INCLUDE)" + @echo " CUDA Library: $(CUDA_LIB)" + +$(TARGET): $(SOURCE) + @echo "Building $(TARGET)..." + @echo " TVM-FFI include: $(TVM_FFI_INCLUDE)" + @echo " DLPack include: $(DLPACK_INCLUDE)" + @echo " TVM-FFI lib: $(TVM_FFI_LIB)" + @echo " CUDA include: $(CUDA_INCLUDE)" + @echo " CUDA lib: $(CUDA_LIB)" + $(CXX) $(CXXFLAGS) $(INCLUDES) $(SOURCE) -o $(TARGET) $(LDFLAGS) $(LIBS) $(RPATH) + @echo "✓ Build successful: ./$(TARGET)" + +run: $(TARGET) + @echo "Running $(TARGET)..." 
+ @echo "" + ./$(TARGET) + +clean: + rm -f $(TARGET) + @echo "Cleaned build artifacts" + +help: + @echo "TVM-FFI C++ Example Makefile" + @echo "" + @echo "Targets:" + @echo " make - Build the C++ example" + @echo " make run - Build and run the example" + @echo " make check - Show detected paths and configuration" + @echo " make clean - Remove build artifacts" + @echo " make help - Show this help message" + @echo "" + @echo "Environment Variables:" + @echo " PYTHON - Python executable to use for detection (default: python3)" + @echo " TVM_FFI_INCLUDE - Path to TVM-FFI headers (default: auto-detect)" + @echo " TVM_FFI_LIB - Path to TVM-FFI libraries (default: auto-detect)" + @echo " DLPACK_INCLUDE - Path to DLPack headers (default: auto-detect)" + @echo " CUDA_HOME - Path to CUDA installation (default: /usr/local/cuda)" + @echo "" + @echo "Examples:" + @echo " make # Auto-detect all paths" + @echo " make PYTHON=python3.12 # Use specific Python version" + @echo " make CUDA_HOME=/usr/local/cuda-12.0 # Use specific CUDA version" + @echo " make TVM_FFI_INCLUDE=/custom/include TVM_FFI_LIB=/custom/lib # Custom paths" diff --git a/examples/ffi/README.md b/examples/ffi/README.md new file mode 100644 index 00000000..7389fd97 --- /dev/null +++ b/examples/ffi/README.md @@ -0,0 +1,132 @@ +# FlashInfer Bench (TVM-FFI) Kernel Distribution Example + +This directory contains examples demonstrating how to build, distribute, and load agent generated CUDA kernels using TVM-FFI across different environments. + + +## Overview + +The workflow consists of 4 main stages: + +``` +┌──────────────────────────────────────────────────────────────────────┐ +│ Stage 1: Problem Definition │ +│ - Definition in FlashInfer-Trace dataset │ +└────────────────────────────┬─────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────────────┐ +│ Stage 2: LLM Kernel Generation │ +│ - LLM reads the Definition + agent_vibecode.md prompt │ +│ - Generates CUDA kernel with TVM-FFI bindings │ +│ - Outputs Solution JSON with embedded source code │ +└────────────────────────────┬─────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────────────┐ +│ Stage 3: Build & Distribution │ +│ - TVMFFIBuilder compiles the Solution │ +│ - Generates framework-agnostic .so binary │ +│ - Extracts to distributed/ folder with metadata │ +└────────────────────────────┬─────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────────────┐ +│ Stage 4: Cross-Framework Usage │ +│ - JAX/PyTorch/C++ load the same .so file │ +│ - Execute kernel without recompilation │ +│ - Benchmarking or Apply the kernel │ +└──────────────────────────────────────────────────────────────────────┘ +``` + + +## Installation Prerequisites + +### Python Dependencies +```bash +pip install flashinfer-bench tvm-ffi torch +``` + +### For JAX Example +```bash +pip install jax[cuda13-local] jax-tvm-ffi +``` + +### For C++ Example +- CUDA Toolkit +- TVM-FFI C++ headers and libraries +- C++17 compatible compiler + +## Usage + +### 1. Generate Kernel with Agent + +You have two options to generate a CUDA kernel solution: + +**Option A: IDE Coding Agent** + +Use your preferred IDE coding agent to generate a GEMM kernel solution: + +```bash +# Open the instructions in your IDE +cat agent_vibecode.md +``` + +Follow the instructions in `agent_vibecode.md` to have the agent generate the solution interactively. 
+ +**Option B: Kernel Generator** + +Use the kernel generator agent to generate solutions: + +```bash +# Ensure .env is configured in examples/kernel_generator/ +# Required: LLM_API_KEY and BASE_URL +python kernel_generator_example.py +``` + +Configure generation parameters in the script: +- `model_name`: LLM model to use (default: `gpt-5-2025-08-07`) +- `target_gpu`: Target GPU architecture (default: `B200`) +- `gen_rounds`: Number of refinement rounds (default: `10`) + +### 2. Build and Distribute +```bash +cd /flashinfer-bench/examples/ffi +python distribute_kernel.py +``` + +This builds the kernel and extracts `kernel.so` to `distributed/` folder. + +### 3. Run Examples + +**JAX:** +```bash +python jax_example.py +``` + +**PyTorch:** +```bash +python pytorch_example.py +``` + +**C++:** +```bash +make run +``` + +Each example loads the distributed kernel, executes it, and prints output shape and elements. + +## How It Works + +The kernel is built using `TVMFFIBuilder`, producing a self-contained `.so` file. This binary can be loaded across different runtimes: + +- **JAX**: Uses `tvm_ffi.load_module()` and `jax.ffi.ffi_call()` +- **PyTorch**: Uses `torch.utils.cpp_extension.load()` with custom CUDA extensions +- **C++**: Uses `ffi::Module::LoadFromFile()` + +The same `.so` file works across all frameworks without recompilation. + +## Notes + +- Kernels use destination-passing style (pre-allocated outputs) +- All examples use CUDA tensors on GPU device 0 +- Entry point format: `file.cu::function_name` diff --git a/examples/ffi/agent_vibecode.md b/examples/ffi/agent_vibecode.md new file mode 100644 index 00000000..5628ae29 --- /dev/null +++ b/examples/ffi/agent_vibecode.md @@ -0,0 +1,418 @@ +# Agent Instructions: CUDA GEMM Implementation with TVM FFI + +## Task Overview + +Write a complete CUDA implementation solving the GEMM definition `gemm_n4096_k4096` and output it as a JSON file to: + +**Output Path**: `Example-FlashInfer-Trace/solutions/agent_vibecode_gemm.json` + +The implementation must use TVM FFI bindings and conform to the Solution JSON schema. + +## Target Operation + +**Operation**: General Matrix Multiply (GEMM) +**Formula**: `C = A @ B.T` +**Shapes**: +- A: `[M, K]` where M is variable, K = 4096 +- B: `[N, K]` where N = 4096, K = 4096 +- C: `[M, N]` (output) +- **Data type**: `float16` (FP16) + +**Note**: This is computing `A @ B.T` (transpose of B), not `A @ B`. + +## Solution Structure Requirements + +Your solution **must** include exactly 3 source files with these names: + +1. **`kernel.h`**: Header file with function declarations and shared definitions +2. **`kernel.cu`**: CUDA kernel device code implementation +3. **`main.cpp`**: TVM FFI host code with bindings + +## TVM FFI Requirements + +### Required Headers in main.cpp +```cpp +#include // TensorView: tensor arguments +#include // TVM_FFI_DLL_EXPORT_TYPED_FUNC +#include // TVM_FFI_ICHECK, TVM_FFI_THROW +#include // TVMFFIEnvGetStream +#include +#include "kernel.h" +``` + +### Function Signature +The exported function **must** be named `run` and match the definition's input/output names: + +```cpp +void run(tvm::ffi::TensorView A, tvm::ffi::TensorView B, tvm::ffi::TensorView C); +``` + +**Important**: The function takes A, B, and C as parameters. C is pre-allocated by the caller. 
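+
+For context, the compiled solution is always invoked in destination-passing style. The sketch below shows the caller side, mirroring `pytorch_example.py` in this directory; the path `solution.so` is illustrative and assumes the solution has already been built with the entry point named `run` as required above:
+
+```python
+# Caller-side sketch: C is allocated by the caller, never by the kernel.
+import torch
+import tvm_ffi
+
+mod = tvm_ffi.load_module("solution.so")  # illustrative path to the built solution
+
+M = 1024
+A = torch.randn(M, 4096, dtype=torch.float16, device="cuda")     # [M, K]
+B = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")  # [N, K]
+C = torch.empty(M, 4096, dtype=torch.float16, device="cuda")     # pre-allocated output [M, N]
+
+mod.run(A, B, C)  # writes the GEMM result into C
+max_diff = (C - torch.matmul(A, B.T)).abs().max().item()
+print(f"Max diff vs reference: {max_diff:.6f}")
+```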
+ +### TVM FFI Binding +Use TVM_FFI_DLL_EXPORT_TYPED_FUNC to expose the function: + +```cpp +TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, run); +``` + +### Input Validation +Validate inputs using TVM FFI error handling: + +```cpp +// Check dimensions +TVM_FFI_ICHECK_EQ(A.ndim(), 2) << "A must be 2D"; +TVM_FFI_ICHECK_EQ(B.ndim(), 2) << "B must be 2D"; +TVM_FFI_ICHECK_EQ(C.ndim(), 2) << "C must be 2D"; + +// Check shapes +TVM_FFI_ICHECK_EQ(A.size(1), 4096) << "A.shape[1] must be 4096 (K)"; +TVM_FFI_ICHECK_EQ(B.size(0), 4096) << "B.shape[0] must be 4096 (N)"; +TVM_FFI_ICHECK_EQ(B.size(1), 4096) << "B.shape[1] must be 4096 (K)"; + +// Check shape compatibility +TVM_FFI_ICHECK_EQ(A.size(1), B.size(1)) << "K dimension mismatch"; +TVM_FFI_ICHECK_EQ(C.size(0), A.size(0)) << "M dimension mismatch"; +TVM_FFI_ICHECK_EQ(C.size(1), B.size(0)) << "N dimension mismatch"; + +// Check data types (float16) +TVM_FFI_ICHECK_EQ(A.dtype().code, kDLFloat) << "A must be float type"; +TVM_FFI_ICHECK_EQ(A.dtype().bits, 16) << "A must be float16"; +TVM_FFI_ICHECK_EQ(B.dtype().code, kDLFloat) << "B must be float type"; +TVM_FFI_ICHECK_EQ(B.dtype().bits, 16) << "B must be float16"; +TVM_FFI_ICHECK_EQ(C.dtype().code, kDLFloat) << "C must be float type"; +TVM_FFI_ICHECK_EQ(C.dtype().bits, 16) << "C must be float16"; + +// Check device (must be CUDA) +TVM_FFI_ICHECK_EQ(A.device().device_type, kDLCUDA) << "A must be on CUDA"; +TVM_FFI_ICHECK_EQ(B.device().device_type, kDLCUDA) << "B must be on CUDA"; +TVM_FFI_ICHECK_EQ(C.device().device_type, kDLCUDA) << "C must be on CUDA"; +``` + +### CUDA Stream Management +Get the CUDA stream from TVM FFI environment: + +```cpp +DLDevice dev = A.device(); +cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(dev.device_type, dev.device_id)); + +// Launch kernel on the stream +kernel_launch<<>>(args...); +``` + +### Memory Access +Access tensor data through TensorView API: + +```cpp +const __half* A_data = static_cast(A.data_ptr()); +const __half* B_data = static_cast(B.data_ptr()); +__half* C_data = static_cast<__half*>(C.data_ptr()); + +int64_t M = A.size(0); +int64_t K = A.size(1); +int64_t N = B.size(0); +``` + +## CUDA Kernel Implementation Guidelines + +### Recommended Approach +Implement a tiled GEMM kernel optimized for float16: + +1. **Use shared memory** for tile caching +2. **Leverage Tensor Cores** if targeting modern GPUs (use `__half` or `half2`) +3. **Thread block tiling**: Typical tile sizes like 128×128 or 256×128 +4. **Handle transposition**: Since we compute `A @ B.T`, adjust memory access patterns + +### Kernel Signature Example +```cpp +__global__ void gemm_kernel_device( + const half* __restrict__ A, + const half* __restrict__ B, + half* __restrict__ C, + int M, int N, int K +); +``` + +### Performance Considerations +- Use `__half` or `half2` types for FP16 operations +- Ensure coalesced memory access +- Minimize bank conflicts in shared memory +- Consider using warp-level primitives for reductions + +## File Organization + +### Required File Structure + +**File 1: `kernel.h`** +- CUDA kernel function declarations +- Host launcher function declarations +- Shared constants and type definitions +- Include guards + +**File 2: `kernel.cu`** +- `__global__` kernel implementations +- `__device__` helper functions +- Host-side kernel launcher function +- CUDA-specific optimizations (shared memory, tensor cores, etc.) 
+ +**File 3: `main.cpp`** +- TVM FFI bindings +- `run` function that matches definition signature: `void run(TensorView A, TensorView B, TensorView C)` +- Input validation using `TVM_FFI_ICHECK_*` macros +- Stream management via `TVMFFIEnvGetStream()` +- `TVM_FFI_DLL_EXPORT_TYPED_FUNC` for function export + +## JSON Schema Format + +The output JSON must conform to the Solution schema and be written to: + +**`Example-FlashInfer-Trace/solutions/agent_vibecode_gemm.json`** + +### JSON Structure + +```json +{ + "name": "agent_example_gemm", + "definition": "gemm_n4096_k4096", + "description": "High-performance CUDA GEMM implementation for C = A @ B.T using TVM FFI bindings", + "author": "vibecode-agent", + "spec": { + "language": "cuda", + "target_hardware": [ + "NVIDIA_H100", + "NVIDIA_A100" + ], + "dependencies": [], + "entry_point": "main.cpp::run" + }, + "sources": [ + { + "path": "kernel.h", + "content": "... complete header file content as string ..." + }, + { + "path": "kernel.cu", + "content": "... complete CUDA kernel code as string ..." + }, + { + "path": "main.cpp", + "content": "... complete TVM FFI binding code as string ..." + } + ] +} +``` + +### Critical Schema Fields + +| Field | Value | Notes | +|-------|-------|-------| +| `name` | `"agent_vibecode_gemm"` | Unique identifier for this solution | +| `definition` | `"gemm_n4096_k4096"` | **Must** match the definition name exactly | +| `language` | `"cuda"` | Lowercase, primary language | +| `target_hardware` | Array of strings | e.g., `["NVIDIA_H100", "NVIDIA_A100"]` | +| `entry_point` | `"main.cpp::run"` | Format: `{filename}::{function_name}` | +| `sources` | Array of 3 file objects | Each with `path` and `content` fields | + +### Entry Point Convention + +The entry point specifies which function the benchmarker will call: +- Format: `"main.cpp::run"` +- The function `run` must be exposed via `TVM_FFI_DLL_EXPORT_TYPED_FUNC` +- The benchmarker will: + 1. Compile all source files into a TVM FFI shared library + 2. Load the compiled module using TVM FFI + 3. Call the `run` function with test inputs `A`, `B`, and pre-allocated `C` + 4. 
Validate the output C against the reference + +## Complete Implementation Example + +Below is a skeleton showing the structure of all three files: + +### kernel.h +```cpp +#ifndef GEMM_N4096_K4096_KERNEL_H +#define GEMM_N4096_K4096_KERNEL_H + +#include +#include +#include + +// Constants from definition +constexpr int GEMM_N_CONST = 4096; +constexpr int GEMM_K_CONST = 4096; + +// Kernel launcher function +void gemm_n4096_k4096_launch( + const __half* A, + const __half* B, + __half* C, + int M, + cudaStream_t stream +); + +#endif // GEMM_N4096_K4096_KERNEL_H +``` + +### kernel.cu +```cpp +#include "kernel.h" +#include // For tensor cores + +using namespace nvcuda; + +// Kernel configuration +constexpr int BLOCK_M = 128; +constexpr int BLOCK_N = 256; +constexpr int BLOCK_K = 64; + +__global__ void gemm_kernel( + const __half* __restrict__ A, + const __half* __restrict__ B, + __half* __restrict__ C, + int M +) { + // Implement optimized GEMM with: + // - Shared memory tiling + // - WMMA/Tensor Core operations + // - Coalesced memory access + // - Proper synchronization + + // C = A @ B.T + // A is [M, 4096], B is [4096, 4096], C is [M, 4096] +} + +void gemm_n4096_k4096_launch( + const __half* A, + const __half* B, + __half* C, + int M, + cudaStream_t stream +) { + if (M <= 0) return; + + dim3 block(256); + dim3 grid((GEMM_N_CONST + BLOCK_N - 1) / BLOCK_N, + (M + BLOCK_M - 1) / BLOCK_M); + + gemm_kernel<<>>(A, B, C, M); +} +``` + +### main.cpp +```cpp +#include +#include +#include +#include +#include +#include "kernel.h" + +void run(tvm::ffi::TensorView A, tvm::ffi::TensorView B, tvm::ffi::TensorView C) { + // Input validation - dimensions + TVM_FFI_ICHECK_EQ(A.ndim(), 2) << "A must be 2D"; + TVM_FFI_ICHECK_EQ(B.ndim(), 2) << "B must be 2D"; + TVM_FFI_ICHECK_EQ(C.ndim(), 2) << "C must be 2D"; + + // Get dimensions + int64_t M = A.size(0); + int64_t K = A.size(1); + int64_t N = B.size(0); + + // Check shapes + TVM_FFI_ICHECK_EQ(K, 4096) << "A.shape[1] must be 4096 (K)"; + TVM_FFI_ICHECK_EQ(N, 4096) << "B.shape[0] must be 4096 (N)"; + TVM_FFI_ICHECK_EQ(B.size(1), 4096) << "B.shape[1] must be 4096 (K)"; + TVM_FFI_ICHECK_EQ(C.size(0), M) << "C.shape[0] must match A.shape[0] (M)"; + TVM_FFI_ICHECK_EQ(C.size(1), N) << "C.shape[1] must be 4096 (N)"; + + // Check data types (float16) + TVM_FFI_ICHECK_EQ(A.dtype().code, kDLFloat) << "A must be float type"; + TVM_FFI_ICHECK_EQ(A.dtype().bits, 16) << "A must be float16"; + TVM_FFI_ICHECK_EQ(B.dtype().code, kDLFloat) << "B must be float type"; + TVM_FFI_ICHECK_EQ(B.dtype().bits, 16) << "B must be float16"; + TVM_FFI_ICHECK_EQ(C.dtype().code, kDLFloat) << "C must be float type"; + TVM_FFI_ICHECK_EQ(C.dtype().bits, 16) << "C must be float16"; + + // Check device (must be CUDA) + TVM_FFI_ICHECK_EQ(A.device().device_type, kDLCUDA) << "A must be on CUDA"; + TVM_FFI_ICHECK_EQ(B.device().device_type, kDLCUDA) << "B must be on CUDA"; + TVM_FFI_ICHECK_EQ(C.device().device_type, kDLCUDA) << "C must be on CUDA"; + + // Get data pointers + const __half* A_data = static_cast(A.data_ptr()); + const __half* B_data = static_cast(B.data_ptr()); + __half* C_data = static_cast<__half*>(C.data_ptr()); + + // Get CUDA stream from TVM FFI environment + DLDevice dev = A.device(); + cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(dev.device_type, dev.device_id)); + + // Launch kernel + gemm_n4096_k4096_launch(A_data, B_data, C_data, static_cast(M), stream); +} + +// Export the function with TVM FFI +TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, run); +``` + +## Performance 
Optimization Guidelines + +Your CUDA kernel should include: + +1. **Tensor Core Usage (WMMA)**: Use `nvcuda::wmma` for 16x16x16 matrix operations +2. **Shared Memory Tiling**: Cache tiles of A and B in shared memory +3. **Memory Coalescing**: Ensure threads access consecutive memory addresses +4. **Bank Conflict Avoidance**: Add padding to shared memory arrays +5. **Compute Intensity**: Maximize compute-to-memory-access ratio +6. **Register Optimization**: Minimize register usage for higher occupancy +7. **Stream Pipelining**: Overlap compute and memory operations + +## Output Format + +Write the complete JSON solution to: +**`Example-FlashInfer-Trace/solutions/agent_vibecode_gemm.json`** + +The JSON must be valid and contain: +- All required schema fields +- Complete source code for all 3 files in the `content` fields +- Properly escaped strings (use JSON encoding) + +## Validation Checklist + +Before finalizing, verify: +- [ ] File names are exactly: `kernel.h`, `kernel.cu`, `main.cpp` +- [ ] Entry point is `"main.cpp::run"` +- [ ] Function signature: `void run(tvm::ffi::TensorView A, tvm::ffi::TensorView B, tvm::ffi::TensorView C)` +- [ ] TVM_FFI_DLL_EXPORT_TYPED_FUNC exposes the `run` function +- [ ] All three files included in `sources` array +- [ ] Input validation with `TVM_FFI_ICHECK_*` macros +- [ ] Kernel implements `C = A @ B.T` (transpose of B) +- [ ] Data type is `__half` (float16) +- [ ] CUDA stream from `TVMFFIEnvGetStream()` +- [ ] Checks that all tensors are on CUDA device +- [ ] JSON is valid and properly formatted +- [ ] All TVM FFI headers included correctly + +## Expected Agent Behavior + +1. **Read** the GEMM definition from `definitions/gemm_n4096_k4096.json` +2. **Understand** the operation: `C = A @ B.T` with shapes [M,K] × [N,K] → [M,N] +3. **Implement** a high-performance CUDA kernel with tiling and tensor cores +4. **Create** TVM FFI bindings following the API guidelines +5. **Package** all source code into the Solution JSON format +6. **Write** the JSON to `Example-FlashInfer-Trace/solutions/agent_vibecode_gemm.json` + +The JSON file should be ready to be consumed by the flashinfer-bench benchmarking system. + +## Summary + +This agent.md provides complete instructions for generating a CUDA GEMM kernel implementation using TVM FFI bindings. The key points are: + +- **3 files required**: `kernel.h`, `kernel.cu`, `main.cpp` +- **Entry point**: `main.cpp::run` with signature `void run(TensorView A, TensorView B, TensorView C)` +- **TVM FFI export**: Use `TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, run)` +- **Validation**: Use `TVM_FFI_ICHECK_*` macros for input validation +- **Stream management**: Get stream via `TVMFFIEnvGetStream()` +- **Output**: Write complete JSON to `Example-FlashInfer-Trace/solutions/agent_vibecode_gemm.json` diff --git a/examples/ffi/cpp_example.cc b/examples/ffi/cpp_example.cc new file mode 100644 index 00000000..0f28c1d0 --- /dev/null +++ b/examples/ffi/cpp_example.cc @@ -0,0 +1,173 @@ +/** + * C++ example: Load and run the distributed .so kernel using TVM-FFI C++ API. 
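+ *
+ * The example registers a CUDA allocator and stream with the TVM-FFI environment,
+ * reads the entry symbol from distributed/kernel_metadata.txt, loads
+ * distributed/kernel.so, and calls the exported GEMM on float16 CUDA tensors.
+ *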
+ * Build with: + * g++ -std=c++17 cpp_example.cc -o cpp_example \ + * -I/path/to/tvm-ffi/include \ + * -L/path/to/tvm-ffi/lib \ + * -ltvm_ffi -lcuda -lcublas \ + * -Wl,-rpath=/path/to/tvm-ffi/lib + */ + + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + + namespace ffi = tvm::ffi; + + std::string read_entry_symbol(const std::string& metadata_path) { + std::ifstream file(metadata_path); + std::string line; + while (std::getline(file, line)) { + if (line.find("Entry Symbol:") == 0) { + size_t colon_pos = line.find(":"); + if (colon_pos != std::string::npos) { + std::string symbol = line.substr(colon_pos + 1); + size_t start = symbol.find_first_not_of(" \t"); + size_t end = symbol.find_last_not_of(" \t\r\n"); + if (start != std::string::npos && end != std::string::npos) { + return symbol.substr(start, end - start + 1); + } + } + } + } + return ""; + } + + void init_random_tensor(ffi::TensorView tensor) { + size_t num_elements = 1; + for (int i = 0; i < tensor.ndim(); ++i) { + num_elements *= tensor.shape()[i]; + } + + std::vector<__half> host_data(num_elements); + + for (size_t i = 0; i < num_elements; ++i) { + host_data[i] = __float2half(static_cast(rand()) / RAND_MAX); + } + + cudaMemcpy(tensor.data_ptr(), host_data.data(), + num_elements * sizeof(__half), cudaMemcpyHostToDevice); + } + + int cuda_tensor_alloc(DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, + void (*SetError)(void* error_ctx, const char* kind, const char* message)) { + size_t num_bytes = 1; + for (int i = 0; i < prototype->ndim; ++i) { + num_bytes *= prototype->shape[i]; + } + num_bytes *= (prototype->dtype.bits * prototype->dtype.lanes + 7) / 8; + + void* ptr; + cudaError_t err = cudaMalloc(&ptr, num_bytes); + if (err != cudaSuccess) { + if (SetError) { + SetError(error_ctx, "RuntimeError", cudaGetErrorString(err)); + } + return -1; + } + + int64_t* shape = new int64_t[prototype->ndim]; + int64_t* strides = nullptr; + for (int i = 0; i < prototype->ndim; ++i) { + shape[i] = prototype->shape[i]; + } + if (prototype->strides) { + strides = new int64_t[prototype->ndim]; + for (int i = 0; i < prototype->ndim; ++i) { + strides[i] = prototype->strides[i]; + } + } + + // Allocate DLManageTensorVersioned structure + DLManagedTensorVersioned* managed = new DLManagedTensorVersioned(); + managed->version = {1, 0}; + managed->manager_ctx = nullptr; + managed->flags = 0; + + // Setup deleter + managed->deleter = [](DLManagedTensorVersioned* self) { + if (self->dl_tensor.data) { + cudaFree(self->dl_tensor.data); + } + delete[] self->dl_tensor.shape; + if (self->dl_tensor.strides) { + delete[] self->dl_tensor.strides; + } + delete self; + }; + + // Setup DLTensor + managed->dl_tensor = *prototype; + managed->dl_tensor.data = ptr; + managed->dl_tensor.shape = shape; + managed->dl_tensor.strides = strides; + + *out = managed; + return 0; + } + + ffi::Tensor allocate_cuda_tensor(std::vector shape, DLDataType dtype) { + DLDevice device{kDLCUDA, 0}; + ffi::ShapeView shape_view(shape.data(), shape.size()); + return ffi::Tensor::FromEnvAlloc(TVMFFIEnvTensorAlloc, shape_view, dtype, device); + } + + int main() { + cudaSetDevice(0); + + DLDevice device{kDLCUDA, 0}; + TVMFFIEnvSetDLPackManagedTensorAllocator(cuda_tensor_alloc, 1, nullptr); + + cudaStream_t stream; + cudaStreamCreate(&stream); + TVMFFIEnvSetStream(device.device_type, device.device_id, stream, nullptr); + + const std::string dist_dir = "distributed"; + std::string entry_symbol = 
read_entry_symbol(dist_dir + "/kernel_metadata.txt"); + + ffi::Module mod = ffi::Module::LoadFromFile(dist_dir + "/kernel.so"); + ffi::Function kernel_fn = mod->GetFunction(entry_symbol).value(); + + std::cout << "Loaded kernel: " << entry_symbol << std::endl; + + // Prepare inputs: C = A @ B.T + const int64_t M = 1024, N = 4096, K = 4096; + DLDataType dtype{kDLFloat, 16, 1}; + + ffi::Tensor A = allocate_cuda_tensor({M, K}, dtype); + ffi::Tensor B = allocate_cuda_tensor({N, K}, dtype); + ffi::Tensor C = allocate_cuda_tensor({M, N}, dtype); + + init_random_tensor(A); + init_random_tensor(B); + cudaMemset(C.data_ptr(), 0, M * N * sizeof(__half)); + + kernel_fn(A, B, C); + cudaDeviceSynchronize(); + + std::vector<__half> host_output(M * N); + cudaMemcpy(host_output.data(), C.data_ptr(), + M * N * sizeof(__half), cudaMemcpyDeviceToHost); + + std::cout << "Output shape: (" << M << ", " << N << ")" << std::endl; + + std::cout << "First 10 elements: ["; + size_t num_to_print = std::min(size_t(10), host_output.size()); + for (size_t i = 0; i < num_to_print; ++i) { + std::cout << __half2float(host_output[i]); + if (i < num_to_print - 1) std::cout << ", "; + } + std::cout << "]" << std::endl; + + cudaStreamDestroy(stream); + + return 0; + } diff --git a/examples/ffi/distribute_kernel.py b/examples/ffi/distribute_kernel.py new file mode 100644 index 00000000..9831135a --- /dev/null +++ b/examples/ffi/distribute_kernel.py @@ -0,0 +1,38 @@ +import shutil +from pathlib import Path + +from flashinfer_bench import TraceSet +from flashinfer_bench.compile.builders.tvm_ffi_builder import TVMFFIBuilder + + +def main(): + traceset = TraceSet.from_path("Example-FlashInfer-Trace") + + definition_name = "gemm_n4096_k4096" + definition = traceset.definitions[definition_name] + + solutions = list(traceset.solutions[definition_name]) + solution = solutions[0] + print(f"Building solution: {solution.name}") + + builder = TVMFFIBuilder() + runnable = builder.build(definition, solution) + + so_path = runnable.meta["binary"] + entry_symbol = runnable.meta["symbol"] + + dist_dir = Path("distributed") + dist_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(so_path, dist_dir / "kernel.so") + + with open(dist_dir / "kernel_metadata.txt", "w") as f: + f.write(f"Entry Symbol: {entry_symbol}\n") + f.write(f"Definition: {definition.name}\n") + f.write(f"Solution: {solution.name}\n") + + print(f"Built kernel: {dist_dir / 'kernel.so'}") + print(f"Entry symbol: {entry_symbol}") + + +if __name__ == "__main__": + main() diff --git a/examples/ffi/jax_example.py b/examples/ffi/jax_example.py new file mode 100644 index 00000000..8e155e38 --- /dev/null +++ b/examples/ffi/jax_example.py @@ -0,0 +1,56 @@ +""" +JAX example: Load and run the distributed .so kernel using jax-tvm-ffi. 
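+
+The kernel is registered as a JAX FFI target and invoked through jax.ffi.ffi_call,
+then checked against a jnp.matmul(A, B.T) reference.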
+ +Requirements: + pip install jax jax-tvm-ffi +""" + +from pathlib import Path + +import jax +import jax.numpy as jnp +import jax_tvm_ffi +import tvm_ffi + + +def main(): + dist_dir = Path("distributed") + so_path = dist_dir / "kernel.so" + + entry_symbol = None + for line in (dist_dir / "kernel_metadata.txt").read_text().split("\n"): + if line.startswith("Entry Symbol:"): + entry_symbol = line.split(":", 1)[1].strip() + break + + # Load and register kernel + mod = tvm_ffi.load_module(str(so_path)) + kernel_fn = getattr(mod, entry_symbol) + jax_tvm_ffi.register_ffi_target(entry_symbol, kernel_fn, platform="gpu") + + print(f"Loaded kernel: {entry_symbol}") + + # Prepare inputs: C = A @ B.T + M, N, K = 1024, 4096, 4096 + jax_device = jax.devices("gpu")[0] + A = jnp.array( + jax.random.normal(jax.random.PRNGKey(0), (M, K)), dtype=jnp.float16, device=jax_device + ) + B = jnp.array( + jax.random.normal(jax.random.PRNGKey(1), (N, K)), dtype=jnp.float16, device=jax_device + ) + + result = jax.ffi.ffi_call( + entry_symbol, jax.ShapeDtypeStruct((M, N), jnp.float16), vmap_method="broadcast_all" + )(A, B) + + # Verify + reference = jnp.matmul(A, B.T) + max_diff = jnp.abs(result - reference).max() + + print(f"Max diff vs reference: {max_diff:.6f}") + print(f"Correct: {max_diff < 1e-2}") + + +if __name__ == "__main__": + main() diff --git a/examples/ffi/kernel_generator_example.py b/examples/ffi/kernel_generator_example.py new file mode 100644 index 00000000..2ca1f642 --- /dev/null +++ b/examples/ffi/kernel_generator_example.py @@ -0,0 +1,118 @@ +import json +import os +import sys +from pathlib import Path + +from dotenv import load_dotenv + +from flashinfer_bench import TraceSet +from flashinfer_bench.data import save_json_file + +# Get the path to examples/kernel_generator +script_dir = Path(__file__).parent # examples/ffi +examples_dir = script_dir.parent # examples +kernel_gen_dir = examples_dir / "kernel_generator" + +sys.path.insert(0, str(kernel_gen_dir)) +from kernel_generator import KernelGenerator + +load_dotenv(kernel_gen_dir / ".env") + + +def main(): + """ + Generate optimized CUDA solutions for FFI bindings. 
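+
+    Loads the example trace set, runs KernelGenerator with beam search on the
+    gemm_n4096_k4096 definition, and saves the resulting Solution JSON under
+    Example-FlashInfer-Trace/solutions/.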
+ """ + # TODO: select model, target gpu, definition + model_name = "gpt-5-2025-08-07" # Choose model + language = "cuda" # Target CUDA for FFI bindings + target_gpu = "B200" # Target GPU + + print(f"Loading Example TraceSet") + traceset_path = script_dir / "Example-FlashInfer-Trace" + traceset = TraceSet.from_path(traceset_path) + + definition_name = "gemm_n4096_k4096" + definition = traceset.definitions[definition_name] + + api_key = os.getenv("LLM_API_KEY") + base_url = os.getenv("BASE_URL") + if not api_key: + print( + "Please set LLM_API_KEY environment variable or modify this script to pass api_key explicitly" + ) + return + + generator = KernelGenerator( + model_name=model_name, + language=language, + target_gpu=target_gpu, + api_key=api_key, + base_url=base_url, + reasoning_effort="high", + ) + + print(f"\n{'='*60}") + print(f"Generating CUDA solution for: {definition_name}") + print(f"Definition type: {definition.op_type}") + print(f"Target GPU: {target_gpu}") + print(f"{'='*60}") + + # Get workloads for this definition + workloads = traceset.workloads.get(definition_name, []) + if not workloads: + print(f"No workloads found for definition '{definition_name}'") + return + + print(f"Found {len(workloads)} workloads for this definition") + + # Generate solution with beam search + solution = None + max_attempts = 2 + + for attempt in range(1, max_attempts + 1): + try: + print(f"\nAttempt {attempt}/{max_attempts}") + + solution = generator.generate( + traceset=traceset, + definition=definition, + gen_rounds=10, # search depth + beam=True, + beam_width=3, + ) + + print(f"Successfully generated solution for {definition_name}") + break + + except Exception as e: + print(f"Attempt {attempt} failed: {e}") + if attempt < max_attempts: + print(f"Retrying... ({attempt + 1}/{max_attempts})") + else: + print(f"All attempts failed - aborting") + return + + # Save the solution + if solution: + try: + solutions_dir = Path(traceset_path) / "solutions" + solutions_dir.mkdir(parents=True, exist_ok=True) + + solution_filename = f"{solution.name}.json" + solution_path = solutions_dir / solution_filename + + save_json_file(solution, solution_path) + + print(f"\n{'='*60}") + print(f"SUCCESS!") + print(f"{'='*60}") + print(f"Solution saved to: {solution_path}") + + except Exception as e: + print(f"Failed to save solution: {e}") + return + + +if __name__ == "__main__": + main() diff --git a/examples/ffi/pytorch_example.py b/examples/ffi/pytorch_example.py new file mode 100644 index 00000000..1a2ecede --- /dev/null +++ b/examples/ffi/pytorch_example.py @@ -0,0 +1,52 @@ +""" +PyTorch example: Load and run the distributed .so kernel using tvm-ffi. 
+""" + +from pathlib import Path + +import torch +import tvm_ffi + + +def main(): + dist_dir = Path("distributed") + so_path = dist_dir / "kernel.so" + + entry_symbol = None + for line in (dist_dir / "kernel_metadata.txt").read_text().split("\n"): + if line.startswith("Entry Symbol:"): + entry_symbol = line.split(":", 1)[1].strip() + break + + if entry_symbol is None: + raise ValueError("Entry symbol not found in metadata") + + mod = tvm_ffi.load_module(str(so_path)) + kernel_fn = getattr(mod, entry_symbol) + + print(f"Loaded kernel: {entry_symbol}") + + # Prepare inputs: C = A @ B.T + M, N, K = 1024, 4096, 4096 + + torch.manual_seed(0) + A = torch.randn(M, K, dtype=torch.float16, device="cuda") + + torch.manual_seed(1) + B = torch.randn(N, K, dtype=torch.float16, device="cuda") + + C = torch.empty(M, N, dtype=torch.float16, device="cuda") + + # Run kernel: C = A @ B.T + kernel_fn(A, B, C) + + # Verify against PyTorch reference + reference = torch.matmul(A, B.T) + max_diff = torch.abs(C - reference).max().item() + + print(f"Max diff vs reference: {max_diff:.6f}") + print(f"Correct: {max_diff < 1e-2}") + + +if __name__ == "__main__": + main() diff --git a/flashinfer_bench/agents/ffi_prompt.py b/flashinfer_bench/agents/ffi_prompt.py new file mode 100644 index 00000000..ab80df33 --- /dev/null +++ b/flashinfer_bench/agents/ffi_prompt.py @@ -0,0 +1,1868 @@ +FFI_PROMPT_SIMPLE = """ +Use TVM FFI format for your generated kernel host function and bindings + +Use the following headers: +#include // TensorView: tensor arguments +#include // TVM_FFI_DLL_EXPORT_TYPED_FUNC +#include // TVM_FFI_ICHECK, TVM_FFI_THROW + +Include when using CUDA / env-managed streams or allocators: +#include // TVMFFIEnvGetStream, TVMFFIEnvTensorAlloc + +# TVM FFI API Reference + +## TensorView (tvm/ffi/container/tensor.h) + +**Purpose:** Non-owning view of tensor data. Use `tvm::ffi::TensorView` for function parameters. + +### Essential Methods +```cpp +void* data_ptr() const // Raw data pointer +DLDevice device() const // Device info (.device_type, .device_id) +DLDataType dtype() const // Data type (.code, .bits, .lanes) +int32_t ndim() const // Number of dimensions +int64_t size(int64_t idx) const // Size of dimension idx (negative idx: count from end) +int64_t numel() const // Total number of elements +bool IsContiguous() const // Check if memory is contiguous +``` + +### ShapeView Access +```cpp +ShapeView shape() const // Get shape (array-like) +ShapeView strides() const // Get strides (array-like) + +// ShapeView can be indexed like an array: +// tensor.shape()[0], tensor.shape()[1], etc. +``` + +### Device Type Constants (DLDevice.device_type) +```cpp +kDLCPU = 1 // CPU +kDLCUDA = 2 // CUDA GPU +kDLCUDAHost = 3 // CUDA pinned memory +kDLROCM = 10 // ROCm/HIP +``` + +### Data Type Constants (DLDataType) +```cpp +// DLDataType has: uint8_t code, uint8_t bits, uint16_t lanes +// Common .code values: +kDLInt = 0 // Signed integer +kDLUInt = 1 // Unsigned integer +kDLFloat = 2 // IEEE floating point +kDLBfloat = 4 // Brain floating point + +// Example: float32 has code=2 (kDLFloat), bits=32, lanes=1 +// Example: half/fp16 has code=2, bits=16, lanes=1 +``` + +## Function Export (tvm/ffi/function.h) + +### Export Macro +```cpp +// Use this to export your C++ function for FFI +TVM_FFI_DLL_EXPORT_TYPED_FUNC(export_name, cpp_function) + +// Example: +void MyKernel(tvm::ffi::TensorView a, tvm::ffi::TensorView b) { ... 
} +TVM_FFI_DLL_EXPORT_TYPED_FUNC(my_kernel, MyKernel); +``` + +**Supported function signatures:** +- `void func(TensorView t1, TensorView t2, ...)` +- `void func(TensorView t, int64_t size, float alpha, ...)` +- `int64_t func(TensorView t)` +- Any combination of: `TensorView`, `int32_t`, `int64_t`, `float`, `double`, `bool`, `std::string` + +## Error Handling (tvm/ffi/error.h) + +### Throwing Errors +```cpp +// Throw with custom error kind +TVM_FFI_THROW(ValueError) << "Invalid input: " << x; +TVM_FFI_THROW(RuntimeError) << "CUDA error: " << cudaGetErrorString(err); +TVM_FFI_THROW(TypeError) << "Expected float32, got int32"; +``` + +### Assertions (for internal logic errors) +```cpp +TVM_FFI_ICHECK(condition) << "message" // General check +TVM_FFI_ICHECK_EQ(x, y) << "x must equal y" // x == y +TVM_FFI_ICHECK_NE(x, y) << "x must not equal y" // x != y +TVM_FFI_ICHECK_LT(x, y) << "x must be less than y" // x < y +TVM_FFI_ICHECK_LE(x, y) << "x must be at most y" // x <= y +TVM_FFI_ICHECK_GT(x, y) << "x must be greater than y" // x > y +TVM_FFI_ICHECK_GE(x, y) << "x must be at least y" // x >= y +``` + +### User Input Validation (use TVM_FFI_THROW instead) +```cpp +// For user-facing errors, use TVM_FFI_THROW with appropriate error kind: +if (x.ndim() != 2) { + TVM_FFI_THROW(ValueError) << "Expected 2D tensor, got " << x.ndim() << "D"; +} +if (x.dtype().code != kDLFloat || x.dtype().bits != 32) { + TVM_FFI_THROW(TypeError) << "Expected float32 dtype"; +} +``` + +## CUDA Stream Management (tvm/ffi/extra/c_env_api.h) + +```cpp +// Get the current CUDA stream for a device +TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) + +// Usage: +DLDevice dev = tensor.device(); +cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(dev.device_type, dev.device_id)); + +// Launch kernel on the stream: +my_kernel<<>>(...); +``` + +Example Usage of FFI binding: +```cpp +// File: add_one_cuda.cu +#include +#include +#include +#include + +namespace my_kernels { + +__global__ void AddOneKernel(float* x, float* y, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + y[idx] = x[idx] + 1.0f; + } +} + +void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) { + // Input validation + TVM_FFI_ICHECK_EQ(x.ndim(), 1) << "x must be 1D"; + TVM_FFI_ICHECK_EQ(y.ndim(), 1) << "y must be 1D"; + TVM_FFI_ICHECK_EQ(x.size(0), y.size(0)) << "Shape mismatch"; + + // Get data pointers + float* x_data = static_cast(x.data_ptr()); + float* y_data = static_cast(y.data_ptr()); + int64_t n = x.size(0); + + // Get CUDA stream from environment + DLDevice dev = x.device(); + cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(dev.device_type, dev.device_id)); + + // Launch kernel + int64_t threads = 256; + int64_t blocks = (n + threads - 1) / threads; + AddOneKernel<<>>(x_data, y_data, n); +} + +// Export the function with name "add_one_cuda" +TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, AddOne); + +} // namespace my_kernels +``` +""" + +FFI_FULL_PROMPT = """ +# TVM FFI Complete API Reference + +This guide provides comprehensive documentation for writing CUDA kernels with TVM FFI bindings. 
+ +## Required Headers + +```cpp +#include // Tensor, TensorView +#include // Array +#include // Map +#include // Tuple<...> +#include // Shape, ShapeView +#include // Variant<...> +#include // String +#include // DLDataType utilities +#include // Function export macros +#include // Error handling +#include // CUDA stream management +#include // Module loading +``` + +--- + +# Part 1: Data Container Types + +## 1.1 Tensor & TensorView (tvm/ffi/container/tensor.h) + +### TensorView - Lightweight Non-Owning View + +**Purpose:** Non-owning view of tensor data. Always use `tvm::ffi::TensorView` for function parameters. + +**Constructors:** +```cpp +TensorView(const Tensor& tensor) // From Tensor +TensorView(const DLTensor* tensor) // From DLTensor pointer +``` + +**Essential Methods:** +```cpp +// Data access +void* data_ptr() const // Raw data pointer +int64_t byte_offset() const // Byte offset from base pointer + +// Device information +DLDevice device() const // Returns device (.device_type, .device_id) + +// Type information +DLDataType dtype() const // Returns data type (.code, .bits, .lanes) + +// Dimension information +int32_t ndim() const // Number of dimensions +int64_t size(int64_t idx) const // Size of dimension idx (supports negative indexing) +int64_t stride(int64_t idx) const // Stride of dimension idx (supports negative indexing) +int64_t numel() const // Total number of elements + +// Shape and strides access +ShapeView shape() const // Get shape (array-like, indexable: shape()[0], shape()[1], ...) +ShapeView strides() const // Get strides (array-like) + +// Memory properties +bool IsContiguous() const // Check if memory layout is contiguous +bool IsAligned(size_t alignment) const // Check if data pointer is aligned +``` + +**Usage Example:** +```cpp +void MyKernel(tvm::ffi::TensorView input, tvm::ffi::TensorView output) { + // Get dimensions + int64_t batch = input.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + + // Get data pointer + float* in_data = static_cast(input.data_ptr()); + float* out_data = static_cast(output.data_ptr()); + + // Check device + DLDevice dev = input.device(); + if (dev.device_type == kDLCUDA) { + // CUDA-specific code + } +} +``` + +### Tensor - Managed Owning Container + +**Purpose:** Reference-counted tensor with memory ownership. Use when you need to allocate or store tensors. 
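+
+A minimal allocation sketch (illustrative only; the 4096 second dimension is a placeholder chosen for the example, and the factory method and environment allocator it uses are listed below):
+
+```cpp
+// Allocate a float32 output tensor on the same device as `input`,
+// using the environment allocator. The 4096 is a placeholder dimension.
+tvm::ffi::Tensor AllocOutput(tvm::ffi::TensorView input) {
+  tvm::ffi::Shape out_shape = {input.size(0), 4096};
+  DLDataType f32 = tvm::ffi::StringToDLDataType("float32");
+  return tvm::ffi::Tensor::FromEnvAlloc(
+      TVMFFIEnvTensorAlloc, out_shape, f32, input.device());
+}
+```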
+ +**All TensorView methods** (Tensor has same interface as TensorView) + +**Static Factory Methods:** +```cpp +// From DLPack +static Tensor FromDLPack(DLManagedTensor* tensor, + size_t require_alignment = 0, + bool require_contiguous = false) + +static Tensor FromDLPackVersioned(DLManagedTensorVersioned* tensor, + size_t require_alignment = 0, + bool require_contiguous = false) + +// Allocate from environment (recommended for kernel outputs) +static Tensor FromEnvAlloc(int (*env_alloc)(DLTensor*, TVMFFIObjectHandle*), + ffi::ShapeView shape, + DLDataType dtype, + DLDevice device) + +// Example: Allocate tensor using environment allocator +ffi::Tensor output = ffi::Tensor::FromEnvAlloc( + TVMFFIEnvTensorAlloc, shape, dtype, device); +``` + +**Conversion Methods:** +```cpp +DLManagedTensor* ToDLPack() const // Export to DLPack +DLManagedTensorVersioned* ToDLPackVersioned() const +const DLTensor* GetDLTensorPtr() const // Get underlying DLTensor +``` + +### Utility Functions +```cpp +// Check if tensor is contiguous +bool IsContiguous(const DLTensor& arr) + +// Check alignment +bool IsAligned(const DLTensor& arr, size_t alignment) + +// Check if device supports direct addressing (CPU, CUDA, etc.) +bool IsDirectAddressDevice(const DLDevice& device) + +// Calculate data size in bytes +size_t GetDataSize(size_t numel, DLDataType dtype) +size_t GetDataSize(const DLTensor& arr) +size_t GetDataSize(const Tensor& tensor) +size_t GetDataSize(const TensorView& tensor) +``` + +### Device Type Constants +```cpp +kDLCPU = 1 // CPU +kDLCUDA = 2 // CUDA GPU +kDLCUDAHost = 3 // CUDA pinned memory +kDLROCM = 10 // ROCm/HIP +kDLMetal = 8 // Metal (Apple) +kDLVulkan = 7 // Vulkan +``` + +### Data Type Constants (DLDataType) +```cpp +// DLDataType structure: { uint8_t code, uint8_t bits, uint16_t lanes } + +// Type codes (code field) +kDLInt = 0 // Signed integer +kDLUInt = 1 // Unsigned integer +kDLFloat = 2 // IEEE floating point +kDLBfloat = 4 // Brain floating point +kDLBool = 5 // Boolean +kDLFloat8_e4m3fn = 6 // FP8 E4M3 +kDLFloat8_e5m2 = 7 // FP8 E5M2 +// ... and more FP8 variants + +// Common data types (examples): +// float32: code=kDLFloat, bits=32, lanes=1 +// float16: code=kDLFloat, bits=16, lanes=1 +// int32: code=kDLInt, bits=32, lanes=1 +// uint8: code=kDLUInt, bits=8, lanes=1 +// bfloat16: code=kDLBfloat, bits=16, lanes=1 +``` + +--- + +## 1.2 Array (tvm/ffi/container/array.h) + +**Purpose:** Dynamic array container with copy-on-write semantics. Similar to `std::vector` but FFI-compatible. + +**Type Parameter:** `T` must be compatible with `tvm::ffi::Any` (ObjectRef types, primitives via TypeTraits) + +**Constructors:** +```cpp +Array() // Empty array +Array(size_t n, const T& val) // n copies of val +Array(std::initializer_list init) // From initializer list +Array(const std::vector& vec) // From std::vector +Array(IterType first, IterType last) // From iterator range +``` + +**Element Access:** +```cpp +T operator[](int64_t i) const // Read i-th element +T front() const // First element +T back() const // Last element +T at(int64_t i) const // Bounds-checked access +``` + +**Capacity:** +```cpp +size_t size() const // Number of elements +size_t capacity() const // Allocated capacity +bool empty() const // Check if empty +void reserve(int64_t n) // Reserve capacity +void resize(int64_t n) // Resize array +``` + +**Modifiers (Copy-on-Write):** +```cpp +void push_back(const T& item) // Add element to end +void emplace_back(Args&&... 
args) // Construct element in-place +void pop_back() // Remove last element +void insert(iterator pos, const T& val) // Insert at position +void erase(iterator pos) // Remove at position +void erase(iterator first, iterator last) // Remove range +void clear() // Remove all elements +void Set(int64_t i, T value) // Set i-th element +``` + +**Iterators:** +```cpp +iterator begin() const // Begin iterator +iterator end() const // End iterator +reverse_iterator rbegin() const // Reverse begin +reverse_iterator rend() const // Reverse end +``` + +**Functional Operations:** +```cpp +// Map function over array +template +Array Map(F fmap) const // Returns Array + +// Mutate array in-place (if unique owner) +template +void MutateByApply(F fmutate) // F: T -> T +``` + +**Static Methods:** +```cpp +// Aggregate multiple arrays/elements +template +static Array Agregate(Args... args) // Combine T and Array args +``` + +**Usage Example:** +```cpp +// Create array +Array dims = {1, 3, 224, 224}; + +// Access elements +int64_t batch = dims[0]; + +// Modify (copy-on-write) +dims.push_back(512); +dims.Set(0, 2); + +// Map operation +Array doubled = dims.Map([](int64_t x) { return x * 2; }); + +// Iterate +for (int64_t d : dims) { + // process d +} +``` + +--- + +## 1.3 Map (tvm/ffi/container/map.h) + +**Purpose:** Hash map container with copy-on-write semantics. Similar to `std::unordered_map`. + +**Type Parameters:** +- `K` (key type) must be hashable and compatible with Any +- `V` (value type) must be compatible with Any + +**Constructors:** +```cpp +Map() // Empty map +Map(std::initializer_list> init) +Map(IterType first, IterType last) // From iterator range +Map(const std::unordered_map& map) // From std::unordered_map +``` + +**Element Access:** +```cpp +V at(const K& key) const // Throws if key not found +V operator[](const K& key) const // Throws if key not found +Optional Get(const K& key) const // Returns Optional (nullopt if not found) +V Get(const K& key, const V& default_value) const // Returns default if not found +``` + +**Capacity:** +```cpp +size_t size() const // Number of elements +size_t count(const K& key) const // 0 or 1 (check if key exists) +bool empty() const // Check if empty +``` + +**Modifiers (Copy-on-Write):** +```cpp +void Set(const K& key, const V& value) // Insert or update +void erase(const K& key) // Remove key +void clear() // Remove all elements +``` + +**Iterators:** +```cpp +iterator begin() const // Begin iterator +iterator end() const // End iterator + +// Iterator dereference returns std::pair +for (auto kv : map) { + K key = kv.first; + V value = kv.second; +} +``` + +**Static Methods:** +```cpp +// Construct map from keys and values +static Map FromItems(const Array>& items) +``` + +**Usage Example:** +```cpp +// Create map +Map config = { + {"batch_size", 32}, + {"num_layers", 12} +}; + +// Access +int64_t batch = config["batch_size"]; +Optional opt = config.Get("hidden_dim"); +if (opt.defined()) { + int64_t hidden = opt.value(); +} + +// Modify +config.Set("hidden_dim", 768); +config.erase("num_layers"); + +// Iterate +for (auto kv : config) { + String key = kv.first; + int64_t value = kv.second; +} +``` + +--- + +## 1.4 Tuple (tvm/ffi/container/tuple.h) + +**Purpose:** Fixed-size heterogeneous container. Similar to `std::tuple` but FFI-compatible. + +**Constructors:** +```cpp +Tuple() // Default construct +Tuple(T1 v1, T2 v2, ...) 
// From values +``` + +**Element Access:** +```cpp +// Compile-time index access +std::get<0>(tuple) // Get first element +std::get<1>(tuple) // Get second element + +// Runtime index access (returns Any) +Any operator[](int64_t idx) const // Get element at runtime index +``` + +**Capacity:** +```cpp +static constexpr size_t size() // Number of elements (compile-time) +``` + +**Structured Binding:** +```cpp +auto [a, b, c] = tuple; // C++17 structured binding +``` + +**Usage Example:** +```cpp +// Create tuple +Tuple result = {42, 3.14f, "success"}; + +// Access elements +int64_t code = std::get<0>(result); +float value = std::get<1>(result); +String msg = std::get<2>(result); + +// Or with structured binding +auto [code, value, msg] = result; +``` + +--- + +## 1.5 Shape & ShapeView (tvm/ffi/container/shape.h) + +### ShapeView - Lightweight Non-Owning View + +**Purpose:** Non-owning view of shape dimensions. Use for passing shapes without allocation. + +**Constructors:** +```cpp +ShapeView() // Empty +ShapeView(const int64_t* data, size_t size) // From pointer and size +ShapeView(std::initializer_list init) // From initializer list +``` + +**Element Access:** +```cpp +int64_t operator[](size_t idx) const // Access dimension +int64_t at(size_t idx) const // Bounds-checked access +int64_t front() const // First dimension +int64_t back() const // Last dimension +``` + +**Properties:** +```cpp +const int64_t* data() const // Data pointer +size_t size() const // Number of dimensions +bool empty() const // Check if empty +int64_t Product() const // Product of all dimensions +``` + +**Iterators:** +```cpp +const int64_t* begin() const +const int64_t* end() const +``` + +### Shape - Managed Owning Container + +**Purpose:** Reference-counted shape with memory ownership. + +**Constructors:** +```cpp +Shape() // Empty +Shape(std::initializer_list init) +Shape(std::vector vec) +Shape(Array arr) +Shape(ShapeView view) +Shape(IterType first, IterType last) +``` + +**All ShapeView methods** (same interface) + +**Static Factory:** +```cpp +// Create strides from shape (row-major) +static Shape StridesFromShape(ShapeView shape) +``` + +**Conversion:** +```cpp +operator ShapeView() const // Implicit conversion to ShapeView +``` + +**Usage Example:** +```cpp +// Create shape +Shape shape = {2, 3, 224, 224}; // NCHW + +// Access +int64_t batch = shape[0]; +int64_t channels = shape[1]; + +// Calculate total elements +int64_t numel = shape.Product(); // 2 * 3 * 224 * 224 + +// Create strides +Shape strides = Shape::StridesFromShape(shape); +// strides = {3*224*224, 224*224, 224, 1} +``` + +--- + +## 1.6 Variant (tvm/ffi/container/variant.h) + +**Purpose:** Type-safe union that can hold one of several types. Similar to `std::variant`. 
+ +**Constructors:** +```cpp +Variant(T1 val) // Construct with T1 +Variant(T2 val) // Construct with T2 +``` + +**Type Checking and Casting:** +```cpp +// Try cast (returns std::optional) +template +std::optional as() const // Returns nullopt if wrong type + +// Cast (throws on failure) +template +T get() const& // Copy value +T get() && // Move value + +// Get type information +std::string GetTypeKey() const // Get type key string +``` + +**Usage Example:** +```cpp +using Value = Variant; + +Value v1 = 42; +Value v2 = 3.14f; +Value v3 = String("hello"); + +// Check and extract +if (auto opt_int = v1.as()) { + int64_t val = *opt_int; +} + +// Or direct cast (throws if wrong type) +int64_t val = v1.get(); +``` + +--- + +## 1.7 String (tvm/ffi/string.h) + +**Purpose:** FFI-compatible string with small-string optimization. + +**Constructors:** +```cpp +String() // Empty string +String(const char* str) +String(const std::string& str) +String(std::string&& str) // Move from std::string +String(const char* data, size_t size) +``` + +**Properties:** +```cpp +const char* data() const // Data pointer +const char* c_str() const // Null-terminated C string +size_t size() const // Length +size_t length() const // Length (same as size) +bool empty() const // Check if empty +``` + +**Element Access:** +```cpp +char at(size_t pos) const // Bounds-checked access +``` + +**Comparison:** +```cpp +int compare(const String& other) const +int compare(const std::string& other) const +int compare(const char* other) const + +// Operators: ==, !=, <, >, <=, >= +``` + +**Concatenation:** +```cpp +String operator+(const String& lhs, const String& rhs) +String operator+(const String& lhs, const char* rhs) +// ... and other combinations +``` + +**Conversion:** +```cpp +operator std::string() const // Convert to std::string +``` + +**Utility:** +```cpp +String EscapeString(const String& value) // Escape for JSON/C++ +``` + +--- + +# Part 2: Data Type Utilities (tvm/ffi/dtype.h) + +## DLDataType Structure + +```cpp +struct DLDataType { + uint8_t code; // Type code (kDLInt, kDLFloat, etc.) + uint8_t bits; // Number of bits + uint16_t lanes; // Number of lanes (for vector types) +}; +``` + +## Type Code Constants + +```cpp +kDLInt = 0 // Signed integer +kDLUInt = 1 // Unsigned integer +kDLFloat = 2 // IEEE floating point +kDLOpaqueHandle = 3 // Opaque handle +kDLBfloat = 4 // Brain floating point +kDLBool = 5 // Boolean +kDLFloat8_e4m3fn = 6 +kDLFloat8_e5m2 = 7 +// ... 
many more FP8 variants +``` + +## Conversion Functions + +```cpp +// String to DLDataType +DLDataType StringToDLDataType(const String& str) +// Examples: "float32", "int64", "float16", "uint8" + +// DLDataType to String +String DLDataTypeToString(DLDataType dtype) + +// Get type code as string (internal) +const char* DLDataTypeCodeAsCStr(DLDataTypeCode type_code) +``` + +## Operators + +```cpp +// Comparison +bool operator==(const DLDataType& lhs, const DLDataType& rhs) +bool operator!=(const DLDataType& lhs, const DLDataType& rhs) + +// Output stream +std::ostream& operator<<(std::ostream& os, DLDataType dtype) +``` + +**Usage Example:** +```cpp +// Check dtype +DLDataType dt = tensor.dtype(); +if (dt.code == kDLFloat && dt.bits == 32) { + // float32 tensor +} + +// Convert from string +DLDataType float16_dtype = StringToDLDataType("float16"); + +// Convert to string +String dtype_str = DLDataTypeToString(dt); // "float32" +``` + +--- + +# Part 3: Error Handling (tvm/ffi/error.h) + +## Error Class + +```cpp +class Error : public std::exception { +public: + Error(std::string kind, std::string message, std::string backtrace); + + std::string kind() const; // Error category + std::string message() const; // Error message + std::string backtrace() const; // Stack trace + std::string TracebackMostRecentCallLast() const; // Python-style traceback + const char* what() const noexcept override; // Standard exception interface +}; +``` + +## Throwing Errors + +```cpp +// Throw with automatic backtrace +TVM_FFI_THROW(ErrorKind) << "message" << variable << "more text" + +// Common error kinds: +TVM_FFI_THROW(ValueError) << "Invalid value: " << x; +TVM_FFI_THROW(TypeError) << "Expected float32, got " << dtype; +TVM_FFI_THROW(RuntimeError) << "CUDA error: " << error_string; +TVM_FFI_THROW(InternalError) << "Unexpected condition"; +TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds"; + +// Log to stderr and throw (for startup functions) +TVM_FFI_LOG_AND_THROW(ErrorKind) << "message"; +``` + +## Assertions (Internal Logic Checks) + +```cpp +// General check +TVM_FFI_ICHECK(condition) << "message" + +// Comparison checks (more informative error messages) +TVM_FFI_ICHECK_EQ(x, y) << "x must equal y" // x == y +TVM_FFI_ICHECK_NE(x, y) << "x must not equal y" // x != y +TVM_FFI_ICHECK_LT(x, y) << "x must be less than y" // x < y +TVM_FFI_ICHECK_LE(x, y) << "x must be at most y" // x <= y +TVM_FFI_ICHECK_GT(x, y) << "x must be greater than y" // x > y +TVM_FFI_ICHECK_GE(x, y) << "x must be at least y" // x >= y +TVM_FFI_ICHECK_NOTNULL(ptr) << "ptr must not be null" // ptr != nullptr + +// Custom error type check +TVM_FFI_CHECK(condition, ErrorKind) << "message" +``` + +## Environment Error Handling + +```cpp +// Check for interrupts (e.g., Python Ctrl+C) +int TVMFFIEnvCheckSignals() + +// Exception for pre-existing errors in environment +throw tvm::ffi::EnvErrorAlreadySet(); + +// Usage in long-running functions +void LongRunningFunction() { + if (TVMFFIEnvCheckSignals() != 0) { + throw ::tvm::ffi::EnvErrorAlreadySet(); + } + // ... 
do work +} +``` + +## Safe Call Wrappers (for C API) + +```cpp +// Wrap C++ code for C API +TVM_FFI_SAFE_CALL_BEGIN(); +// C++ code that may throw +TVM_FFI_SAFE_CALL_END(); + +// Check C function return codes +TVM_FFI_CHECK_SAFE_CALL(function_call); +``` + +**Usage Examples:** +```cpp +// User input validation +void MyKernel(tvm::ffi::TensorView x, tvm::ffi::TensorView y) { + if (x.ndim() != 2) { + TVM_FFI_THROW(ValueError) << "Expected 2D tensor, got " << x.ndim() << "D"; + } + + DLDataType dt = x.dtype(); + if (dt.code != kDLFloat || dt.bits != 32) { + TVM_FFI_THROW(TypeError) << "Expected float32, got " << DLDataTypeToString(dt); + } + + // Internal consistency checks + TVM_FFI_ICHECK_EQ(x.size(1), y.size(0)) << "Dimension mismatch for matmul"; +} +``` + +--- + +# Part 4: Function Export (tvm/ffi/function.h) + +## Export Macro + +```cpp +// Export typed C++ function for FFI +TVM_FFI_DLL_EXPORT_TYPED_FUNC(export_name, cpp_function) + +// Example: +void MyKernel(tvm::ffi::TensorView a, tvm::ffi::TensorView b) { + // kernel implementation +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(my_kernel, MyKernel); + +// Or with lambda +TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, [](int x) { return x + 1; }); +``` + +## Supported Function Signatures + +The macro supports automatic type marshalling for: + +**Primitive Types:** +- `int32_t`, `int64_t` +- `uint32_t`, `uint64_t` +- `float`, `double` +- `bool` + +**FFI Types:** +- `tvm::ffi::TensorView`, `tvm::ffi::Tensor` +- `tvm::ffi::String`, `std::string`, `const char*` +- `tvm::ffi::Array`, `tvm::ffi::Map`, `tvm::ffi::Tuple<...>` +- `tvm::ffi::Shape`, `tvm::ffi::ShapeView` +- `tvm::ffi::Any`, `tvm::ffi::Optional` +- `DLDataType`, `DLDevice` + +**Return Types:** +- `void` (no return) +- Any of the above types + +**Example Signatures:** +```cpp +void func1(TensorView t1, TensorView t2) +int64_t func2(TensorView t) +Tensor func3(TensorView t, int64_t size, float alpha) +Array func4(Shape shape, String name) +Tuple func5(TensorView t) +``` + +## Function Class (for dynamic usage) + +```cpp +class Function { +public: + // Get global function by name + static std::optional GetGlobal(std::string_view name); + static Function GetGlobalRequired(std::string_view name); // Throws if not found + + // Set global function + static void SetGlobal(std::string_view name, Function func, bool override = false); + + // List all global function names + static std::vector ListGlobalNames(); + + // Remove global function + static void RemoveGlobal(const String& name); + + // Call function with arguments + template + Any operator()(Args&&... 
args) const; + + // Create from typed function + template + static Function FromTyped(TCallable callable); + + template + static Function FromTyped(TCallable callable, std::string name); +}; +``` + +**Usage Example:** +```cpp +// Get and call existing function +auto opt_func = Function::GetGlobal("my.function.name"); +if (opt_func.has_value()) { + Any result = (*opt_func)(arg1, arg2); +} + +// Create and register function +Function func = Function::FromTyped([](int64_t x) { return x * 2; }); +Function::SetGlobal("double", func); +``` + +## TypedFunction + +```cpp +// Type-safe function wrapper +TypedFunction add_func = [](int a, int b) { return a + b; }; + +// Call with type checking +int result = add_func(1, 2); + +// Convert to/from Function +Function erased = add_func; +TypedFunction typed = TypedFunction(erased); +``` + +--- + +# Part 5: Module Loading (tvm/ffi/extra/module.h) + +**Purpose:** Load and use functions from other compiled modules (for importing helper kernels). + +## Module Class + +```cpp +class Module { +public: + // Load module from file + static Module LoadFromFile(const String& file_name); + + // Get function from module + Optional GetFunction(const String& name, bool query_imports = true); + + // Check if function exists + bool ImplementsFunction(const String& name, bool query_imports = true); + + // Get function metadata + Optional GetFunctionDoc(const String& name, bool query_imports = true); + Optional GetFunctionMetadata(const String& name, bool query_imports = true); + + // Import another module + void ImportModule(const Module& other); + + // Export module + void WriteToFile(const String& file_name, const String& format) const; + Bytes SaveToBytes() const; + String InspectSource(const String& format) const; + Array GetWriteFormats() const; +}; +``` + +## Module Properties + +```cpp +enum ModulePropertyMask { + kBinarySerializable = 0b001, // Can be serialized to bytes + kRunnable = 0b010, // Has runnable functions + kCompilationExportable = 0b100 // Can export to .o/.cc/.cu +}; +``` + +## Symbol Names + +```cpp +namespace symbol { + constexpr const char* tvm_ffi_symbol_prefix = "__tvm_ffi_"; + constexpr const char* tvm_ffi_main = "__tvm_ffi_main"; + constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi__library_ctx"; + constexpr const char* tvm_ffi_library_bin = "__tvm_ffi__library_bin"; + constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi__metadata_"; +} +``` + +**Usage Example:** +```cpp +// Load helper module +Module helpers = Module::LoadFromFile("helpers.so"); + +// Get function from module +Optional opt_func = helpers->GetFunction("helper_kernel"); +if (opt_func.defined()) { + Function helper = opt_func.value(); + // Call helper function + helper(tensor1, tensor2); +} +``` + +--- + +# Part 6: Environment APIs (tvm/ffi/extra/c_env_api.h) + +## CUDA Stream Management + +```cpp +// Get current CUDA stream for device +TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) + +// Set current CUDA stream for device +int TVMFFIEnvSetStream(int32_t device_type, + int32_t device_id, + TVMFFIStreamHandle stream, + TVMFFIStreamHandle* opt_out_original_stream) + +// Usage in kernel +DLDevice dev = tensor.device(); +cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(dev.device_type, dev.device_id)); + +// Launch kernel on stream +my_kernel<<>>(...); +``` + +## Tensor Allocation + +```cpp +// Allocate tensor using environment allocator (respects TLS/global allocator) +int TVMFFIEnvTensorAlloc(DLTensor* prototype, TVMFFIObjectHandle* 
out) + +// Get/Set allocator +DLPackManagedTensorAllocator TVMFFIEnvGetDLPackManagedTensorAllocator() + +int TVMFFIEnvSetDLPackManagedTensorAllocator( + DLPackManagedTensorAllocator allocator, + int write_to_global_context, + DLPackManagedTensorAllocator* opt_out_original_allocator) + +// Usage: Allocate output tensor +ffi::Tensor output = ffi::Tensor::FromEnvAlloc( + TVMFFIEnvTensorAlloc, + shape, // ffi::ShapeView + dtype, // DLDataType + device // DLDevice +); +``` + +## Module Symbol Management + +```cpp +// Lookup function from module imports +int TVMFFIEnvModLookupFromImports(TVMFFIObjectHandle library_ctx, + const char* func_name, + TVMFFIObjectHandle* out) + +// Register context symbols (available when library is loaded) +int TVMFFIEnvModRegisterContextSymbol(const char* name, void* symbol) + +// Register system library symbols +int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* symbol) + +// Register C API symbols +int TVMFFIEnvRegisterCAPI(const char* name, void* symbol) +``` + +## Signal Checking + +```cpp +// Check for environment signals (e.g., Python Ctrl+C interrupts) +int TVMFFIEnvCheckSignals() + +// Usage in long-running loops +for (int i = 0; i < iterations; ++i) { + if (TVMFFIEnvCheckSignals() != 0) { + throw tvm::ffi::EnvErrorAlreadySet(); + } + // ... do work +} +``` + +--- + +# Common Patterns and Examples + +## Pattern 1: Basic CUDA Kernel + +```cpp +#include +#include +#include +#include + +__global__ void AddKernel(float* a, float* b, float* c, int64_t n) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] + b[idx]; + } +} + +void Add(tvm::ffi::TensorView a, tvm::ffi::TensorView b, tvm::ffi::TensorView c) { + // Validate inputs + TVM_FFI_ICHECK_EQ(a.ndim(), 1) << "a must be 1D"; + TVM_FFI_ICHECK_EQ(b.ndim(), 1) << "b must be 1D"; + TVM_FFI_ICHECK_EQ(c.ndim(), 1) << "c must be 1D"; + TVM_FFI_ICHECK_EQ(a.size(0), b.size(0)) << "Shape mismatch"; + TVM_FFI_ICHECK_EQ(a.size(0), c.size(0)) << "Shape mismatch"; + + // Check dtype + if (a.dtype().code != kDLFloat || a.dtype().bits != 32) { + TVM_FFI_THROW(TypeError) << "Expected float32 tensor"; + } + + // Get data + float* a_data = static_cast(a.data_ptr()); + float* b_data = static_cast(b.data_ptr()); + float* c_data = static_cast(c.data_ptr()); + int64_t n = a.size(0); + + // Get CUDA stream + DLDevice dev = a.device(); + cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(dev.device_type, dev.device_id)); + + // Launch kernel + int64_t threads = 256; + int64_t blocks = (n + threads - 1) / threads; + AddKernel<<>>(a_data, b_data, c_data, n); +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(add, Add); +``` + +## Pattern 2: Multi-Dimensional Tensor Processing + +```cpp +void Process2D(tvm::ffi::TensorView input, tvm::ffi::TensorView output) { + // Validate + TVM_FFI_ICHECK_EQ(input.ndim(), 2) << "Expected 2D tensor"; + TVM_FFI_ICHECK_EQ(output.ndim(), 2) << "Expected 2D tensor"; + + // Get dimensions + int64_t height = input.size(0); + int64_t width = input.size(1); + TVM_FFI_ICHECK_EQ(output.size(0), height); + TVM_FFI_ICHECK_EQ(output.size(1), width); + + // Check contiguous + TVM_FFI_ICHECK(input.IsContiguous()) << "Input must be contiguous"; + + // Access with strides (more general) + int64_t stride_h = input.stride(0); + int64_t stride_w = input.stride(1); + + // For contiguous tensors: element[i][j] at data[i * stride_h + j * stride_w] +} +``` + +## Pattern 3: Allocating Output Tensors + +```cpp +tvm::ffi::Tensor AllocateOutput(tvm::ffi::TensorView input, + 
tvm::ffi::ShapeView output_shape) { + // Create output with same device and dtype as input + DLDevice device = input.device(); + DLDataType dtype = input.dtype(); + + // Allocate using environment allocator + tvm::ffi::Tensor output = tvm::ffi::Tensor::FromEnvAlloc( + TVMFFIEnvTensorAlloc, + output_shape, + dtype, + device + ); + + return output; +} + +// Usage +void MyKernel(tvm::ffi::TensorView input) { + // Create output shape + Shape out_shape = {input.size(0), input.size(1) * 2}; + + // Allocate + Tensor output = AllocateOutput(input, out_shape); + + // Now use output.data_ptr() in kernel +} +``` + +## Pattern 4: Device-Specific Dispatch + +```cpp +void UniversalKernel(tvm::ffi::TensorView input, tvm::ffi::TensorView output) { + DLDevice dev = input.device(); + + switch (dev.device_type) { + case kDLCUDA: { + // CUDA implementation + cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(dev.device_type, dev.device_id)); + // Launch CUDA kernel + break; + } + case kDLCPU: { + // CPU implementation + float* in_data = static_cast(input.data_ptr()); + float* out_data = static_cast(output.data_ptr()); + // ... CPU code + break; + } + case kDLROCM: { + // ROCm/HIP implementation + break; + } + default: + TVM_FFI_THROW(RuntimeError) << "Unsupported device type: " << dev.device_type; + } +} +``` + +## Pattern 5: Using Configuration Parameters + +```cpp +void ConfigurableKernel(tvm::ffi::TensorView input, + tvm::ffi::TensorView output, + tvm::ffi::Map config) { + // Extract config with defaults + int64_t block_size = config.Get("block_size", 256); + int64_t num_threads = config.Get("num_threads", 1024); + + // Or check if key exists + if (config.count("custom_param") > 0) { + int64_t custom = config.at("custom_param"); + // use custom + } + + // Launch with config + // ... +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(configurable_kernel, ConfigurableKernel); +``` + +## Pattern 6: Returning Multiple Values + +```cpp +// Return tuple +tvm::ffi::Tuple +ComputeWithStats(tvm::ffi::TensorView input) { + // Allocate output + tvm::ffi::Tensor output = AllocateOutput(input, input.shape()); + + // Compute statistics + int64_t count = input.numel(); + float mean = 0.0f; // ... compute mean + + // Return multiple values + return {output, count, mean}; +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(compute_with_stats, ComputeWithStats); +``` + +## Pattern 7: Using Arrays for Variable Arguments + +```cpp +void BatchProcess(tvm::ffi::Array inputs, + tvm::ffi::Array outputs) { + TVM_FFI_ICHECK_EQ(inputs.size(), outputs.size()) << "Size mismatch"; + + for (size_t i = 0; i < inputs.size(); ++i) { + TensorView in = inputs[i]; + TensorView out = outputs[i]; + // Process each pair + } +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(batch_process, BatchProcess); +``` + +--- + +# Quick Reference Summary + +## Most Common Imports + +```cpp +#include // TensorView, Tensor +#include // TVM_FFI_DLL_EXPORT_TYPED_FUNC +#include // TVM_FFI_THROW, TVM_FFI_ICHECK +#include // TVMFFIEnvGetStream +``` + +## Most Common Function Signature + +```cpp +void MyKernel(tvm::ffi::TensorView input, tvm::ffi::TensorView output) { + // 1. Validate inputs + // 2. Get CUDA stream + // 3. 
Launch kernel +} +TVM_FFI_DLL_EXPORT_TYPED_FUNC(my_kernel, MyKernel); +``` + +## Most Common Validation Pattern + +```cpp +// Shape +TVM_FFI_ICHECK_EQ(tensor.ndim(), 2); +TVM_FFI_ICHECK_EQ(tensor.size(0), expected_dim0); + +// Dtype +DLDataType dt = tensor.dtype(); +TVM_FFI_ICHECK_EQ(dt.code, kDLFloat); +TVM_FFI_ICHECK_EQ(dt.bits, 32); + +// Device +TVM_FFI_ICHECK_EQ(tensor.device().device_type, kDLCUDA); + +// Memory layout +TVM_FFI_ICHECK(tensor.IsContiguous()); +``` + +## Most Common CUDA Launch Pattern + +```cpp +DLDevice dev = tensor.device(); +cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(dev.device_type, dev.device_id)); + +float* data = static_cast(tensor.data_ptr()); +int64_t n = tensor.numel(); + +int threads = 256; +int blocks = (n + threads - 1) / threads; +MyKernel<<>>(data, n); +``` + +--- + +This completes the comprehensive TVM FFI API reference. Use this as your guide for writing FFI-compatible CUDA kernels and host functions. +""" + +FFI_PROMPT = """ +You should use TVM FFI bindings for your code. + +# Required Headers + +```cpp +#include // Tensor, TensorView +#include // Function export macros +#include // Error handling +#include // Environment APIs (streams, allocators) +``` + +# Complete Example: CUDA Kernel Binding + +```cpp +// File: add_one_cuda.cu +#include +#include +#include +#include + +namespace my_kernels { + +__global__ void AddOneKernel(float* x, float* y, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + y[idx] = x[idx] + 1.0f; + } +} + +void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) { + // Input validation + TVM_FFI_ICHECK_EQ(x.ndim(), 1) << "x must be 1D"; + TVM_FFI_ICHECK_EQ(y.ndim(), 1) << "y must be 1D"; + TVM_FFI_ICHECK_EQ(x.size(0), y.size(0)) << "Shape mismatch"; + + // Get data pointers + float* x_data = static_cast(x.data_ptr()); + float* y_data = static_cast(y.data_ptr()); + int64_t n = x.size(0); + + // Get CUDA stream from environment + DLDevice dev = x.device(); + cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(dev.device_type, dev.device_id)); + + // Launch kernel + int64_t threads = 256; + int64_t blocks = (n + threads - 1) / threads; + AddOneKernel<<>>(x_data, y_data, n); +} + +// Export the function with name "add_one_cuda" +TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, AddOne); + +} // namespace my_kernels +``` + +# TVM FFI API Documentation + +## 1. Tensor Container API (tvm/ffi/container/tensor.h) + +### Tensor Class +A managed n-dimensional array with reference counting. + +**Methods:** +- `void* data_ptr() const` - Returns raw data pointer +- `DLDevice device() const` - Returns device info (`.device_type`, `.device_id`) +- `int32_t ndim() const` - Returns number of dimensions +- `DLDataType dtype() const` - Returns data type (`.code`, `.bits`, `.lanes`) +- `ShapeView shape() const` - Returns shape array (indexable: `shape()[0]`, `shape()[1]`, ...) 
+- `ShapeView strides() const` - Returns strides array +- `int64_t size(int64_t idx) const` - Returns size of dimension at idx (negative idx counts from end) +- `int64_t stride(int64_t idx) const` - Returns stride of dimension at idx +- `int64_t numel() const` - Returns total number of elements +- `uint64_t byte_offset() const` - Returns byte offset +- `bool IsContiguous() const` - Checks if tensor memory is contiguous +- `bool IsAligned(size_t alignment) const` - Checks if data is aligned to given bytes + +**Static Factory Methods:** +- `static Tensor FromDLPack(DLManagedTensor* tensor, size_t require_alignment = 0, bool require_contiguous = false)` +- `static Tensor FromDLPackVersioned(DLManagedTensorVersioned* tensor, size_t require_alignment = 0, bool require_contiguous = false)` +- `template static Tensor FromNDAlloc(TNDAlloc alloc, ffi::ShapeView shape, DLDataType dtype, DLDevice device, ExtraArgs&&... extra_args)` +- `static Tensor FromEnvAlloc(int (*env_alloc)(DLTensor*, TVMFFIObjectHandle*), ffi::ShapeView shape, DLDataType dtype, DLDevice device)` + +**Conversion Methods:** +- `DLManagedTensor* ToDLPack() const` - Convert to DLPack managed tensor +- `DLManagedTensorVersioned* ToDLPackVersioned() const` - Convert to versioned DLPack +- `const DLTensor* GetDLTensorPtr() const` - Get underlying DLTensor pointer + +### TensorView Class +Non-owning lightweight view of a Tensor. Kernel entrypoints should use `tvm::ffi::TensorView` (or `const TensorView&`) for tensor inputs/outputs. + +**Constructors:** +- `TensorView(const Tensor& tensor)` - From Tensor +- `TensorView(const DLTensor* tensor)` - From DLTensor pointer + +**Methods (same interface as Tensor):** +- `void* data_ptr() const` +- `DLDevice device() const` +- `int32_t ndim() const` +- `DLDataType dtype() const` +- `ShapeView shape() const` +- `ShapeView strides() const` +- `int64_t size(int64_t idx) const` +- `int64_t stride(int64_t idx) const` +- `int64_t numel() const` +- `uint64_t byte_offset() const` +- `bool IsContiguous() const` + +### Utility Functions +- `bool IsContiguous(const DLTensor& arr)` - Check if DLTensor is contiguous +- `bool IsAligned(const DLTensor& arr, size_t alignment)` - Check alignment +- `bool IsDirectAddressDevice(const DLDevice& device)` - Check if device uses direct addressing +- `size_t GetDataSize(size_t numel, DLDataType dtype)` - Calculate bytes for packed data +- `size_t GetDataSize(const DLTensor& arr)` - Calculate bytes in DLTensor +- `size_t GetDataSize(const Tensor& tensor)` - Calculate bytes in Tensor +- `size_t GetDataSize(const TensorView& tensor)` - Calculate bytes in TensorView + +### Device Type Constants (DLDevice.device_type) +```cpp +kDLCPU = 1 // CPU +kDLCUDA = 2 // CUDA GPU +kDLCUDAHost = 3 // CUDA pinned memory +kDLROCM = 10 // ROCm/HIP +``` + +### Data Type Constants (DLDataType) +```cpp +// DLDataType has: uint8_t code, uint8_t bits, uint16_t lanes +kDLInt = 0 // Signed integer +kDLUInt = 1 // Unsigned integer +kDLFloat = 2 // IEEE floating point +kDLBfloat = 4 // Brain floating point + +// Example: float32 has code=2 (kDLFloat), bits=32, lanes=1 +// Example: half/fp16 has code=2, bits=16, lanes=1 +``` + +## 2. Function API (tvm/ffi/function.h) + +### Function Class +Type-erased callable object. 
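+
+A minimal usage sketch (illustrative only; the constructor, factory, and registry listings follow below):
+
+```cpp
+// Register a typed lambda under a global name, then look it up and call it.
+tvm::ffi::Function twice = tvm::ffi::Function::FromTyped(
+    [](int64_t x) -> int64_t { return x * 2; }, "twice");
+tvm::ffi::Function::SetGlobal("demo.twice", twice);
+
+if (auto fn = tvm::ffi::Function::GetGlobal("demo.twice")) {
+  tvm::ffi::Any result = (*fn)(21);  // type-erased call; returns Any
+}
+```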
+ +**Constructors:** +- `Function(std::nullptr_t)` - Null constructor +- `template Function(TCallable packed_call)` - From callable (legacy) + +**Static Factory Methods:** +- `template static Function FromPacked(TCallable packed_call)` - From packed signature: `void(const AnyView*, int32_t, Any*)` or `void(PackedArgs, Any*)` +- `template static Function FromTyped(TCallable callable)` - From typed C++ function +- `template static Function FromTyped(TCallable callable, std::string name)` - With name for error messages +- `static Function FromExternC(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self))` - From C callback + +**Global Registry:** +- `static std::optional GetGlobal(std::string_view name)` +- `static std::optional GetGlobal(const std::string& name)` +- `static std::optional GetGlobal(const String& name)` +- `static std::optional GetGlobal(const char* name)` +- `static Function GetGlobalRequired(std::string_view name)` - Throws if not found +- `static Function GetGlobalRequired(const std::string& name)` +- `static Function GetGlobalRequired(const String& name)` +- `static Function GetGlobalRequired(const char* name)` +- `static void SetGlobal(std::string_view name, Function func, bool override = false)` +- `static std::vector ListGlobalNames()` +- `static void RemoveGlobal(const String& name)` + +**Invocation:** +- `template Any operator()(Args&&... args) const` - Call with unpacked args +- `void CallPacked(const AnyView* args, int32_t num_args, Any* result) const` +- `void CallPacked(PackedArgs args, Any* result) const` +- `template static Any InvokeExternC(void* handle, TVMFFISafeCallType safe_call, Args&&... args)` + +### TypedFunction Class +Type-safe wrapper around Function. + +**Constructors:** +- `TypedFunction()` - Default +- `TypedFunction(std::nullptr_t)` +- `TypedFunction(Function packed)` - From Function +- `template TypedFunction(FLambda typed_lambda)` - From lambda +- `template TypedFunction(FLambda typed_lambda, std::string name)` - With name + +**Methods:** +- `R operator()(Args... args) const` - Type-safe invocation +- `operator Function() const` - Convert to Function +- `const Function& packed() const&` - Get internal Function +- `Function&& packed() &&` - Move internal Function +- `static std::string TypeSchema()` - Get JSON type schema + +### PackedArgs Class +Represents packed arguments. + +**Constructor:** +- `PackedArgs(const AnyView* data, int32_t size)` + +**Methods:** +- `int size() const` - Number of arguments +- `const AnyView* data() const` - Raw argument array +- `PackedArgs Slice(int begin, int end = -1) const` - Get subset +- `AnyView operator[](int i) const` - Access argument +- `template static void Fill(AnyView* data, Args&&... args)` - Pack arguments + +### Export Macro +```cpp +// Export typed C++ function for FFI +TVM_FFI_DLL_EXPORT_TYPED_FUNC(export_name, cpp_function) + +// Example: +void MyKernel(tvm::ffi::TensorView a, tvm::ffi::TensorView b) { ... } +TVM_FFI_DLL_EXPORT_TYPED_FUNC(my_kernel, MyKernel); +``` + +**Supported function signatures:** +- `void func(TensorView t1, TensorView t2, ...)` +- `void func(TensorView t, int64_t size, float alpha, ...)` +- `int64_t func(TensorView t)` +- Any combination of: `TensorView`, `int32_t`, `int64_t`, `float`, `double`, `bool`, `std::string` + +## 3. Error Handling API (tvm/ffi/error.h) + +### Error Class +Exception object with stack trace. 
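+
+A short sketch of raising and inspecting an error (illustrative; the constructors and macros are detailed below):
+
+```cpp
+// Raise a user-facing error with an automatic backtrace, then inspect it.
+try {
+  TVM_FFI_THROW(ValueError) << "expected a 2D tensor";
+} catch (const tvm::ffi::Error& err) {
+  std::string kind = err.kind();        // error kind, e.g. "ValueError"
+  std::string message = err.message();  // the formatted message text
+}
+```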
+ +**Constructor:** +- `Error(std::string kind, std::string message, std::string backtrace)` +- `Error(std::string kind, std::string message, const TVMFFIByteArray* backtrace)` + +**Methods:** +- `std::string kind() const` - Error category +- `std::string message() const` - Error description +- `std::string backtrace() const` - Stack trace +- `std::string TracebackMostRecentCallLast() const` - Python-style traceback +- `const char* what() const noexcept override` - Standard exception interface +- `void UpdateBacktrace(const TVMFFIByteArray* backtrace_str, int32_t update_mode)` - Modify traceback + +### EnvErrorAlreadySet Exception +Thrown when error exists in frontend environment (e.g., Python interrupt). + +**Usage:** +```cpp +void LongRunningFunction() { + if (TVMFFIEnvCheckSignals() != 0) { + throw ::tvm::ffi::EnvErrorAlreadySet(); + } + // do work here +} +``` + +### Error Macros +```cpp +// Throw with backtrace +TVM_FFI_THROW(ErrorKind) << message + +// Log to stderr and throw (for startup functions) +TVM_FFI_LOG_AND_THROW(ErrorKind) << message + +// Check C function return code +TVM_FFI_CHECK_SAFE_CALL(func) + +// Wrap C++ code for C API +TVM_FFI_SAFE_CALL_BEGIN(); +// c++ code region here +TVM_FFI_SAFE_CALL_END(); +``` + +### Assertion Macros +```cpp +TVM_FFI_ICHECK(condition) << "message" // General check +TVM_FFI_CHECK(condition, ErrorKind) << "message" // Custom error type +TVM_FFI_ICHECK_EQ(x, y) << "x must equal y" // x == y +TVM_FFI_ICHECK_NE(x, y) << "x must not equal y" // x != y +TVM_FFI_ICHECK_LT(x, y) << "x must be less than y" // x < y +TVM_FFI_ICHECK_LE(x, y) << "x must be at most y" // x <= y +TVM_FFI_ICHECK_GT(x, y) << "x must be greater than y" // x > y +TVM_FFI_ICHECK_GE(x, y) << "x must be at least y" // x >= y +TVM_FFI_ICHECK_NOTNULL(ptr) << "ptr must not be null" // ptr != nullptr +``` + +**Common Error Kinds:** +- `ValueError` - Invalid argument values +- `TypeError` - Type mismatch +- `RuntimeError` - Runtime failures (CUDA errors, etc.) +- `InternalError` - Internal logic errors + +### Utility Functions +- `int32_t TypeKeyToIndex(std::string_view type_key)` - Get type index from key + +## 4. 
Environment API (tvm/ffi/extra/c_env_api.h) + +### Stream Management +```cpp +// Set current stream for a device +int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, + TVMFFIStreamHandle stream, + TVMFFIStreamHandle* opt_out_original_stream) + +// Get current stream for a device +TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) + +// Usage example: +DLDevice dev = tensor.device(); +cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(dev.device_type, dev.device_id)); +``` + +### Tensor Allocation +```cpp +// Set DLPack allocator (TLS or global) +int TVMFFIEnvSetDLPackManagedTensorAllocator( + DLPackManagedTensorAllocator allocator, + int write_to_global_context, + DLPackManagedTensorAllocator* opt_out_original_allocator) + +// Get current allocator +DLPackManagedTensorAllocator TVMFFIEnvGetDLPackManagedTensorAllocator() + +// Allocate tensor using environment allocator +int TVMFFIEnvTensorAlloc(DLTensor* prototype, TVMFFIObjectHandle* out) + +// Usage with FromEnvAlloc: +ffi::Tensor tensor = ffi::Tensor::FromEnvAlloc( + TVMFFIEnvTensorAlloc, shape, dtype, device); +``` + +### Module & Symbol Management +```cpp +// Lookup function from module imports +int TVMFFIEnvModLookupFromImports(TVMFFIObjectHandle library_ctx, + const char* func_name, + TVMFFIObjectHandle* out) + +// Register context symbol (available when library is loaded) +int TVMFFIEnvModRegisterContextSymbol(const char* name, void* symbol) + +// Register system library symbol +int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* symbol) + +// Register C API symbol +int TVMFFIEnvRegisterCAPI(const char* name, void* symbol) +``` + +### Utilities +```cpp +// Check for environment signals (e.g., Python Ctrl+C) +int TVMFFIEnvCheckSignals() +``` + +### Type Definitions +```cpp +typedef void* TVMFFIStreamHandle // Stream handle type +``` + +# Common Patterns + +## Dtype Validation +```cpp +void MyKernel(tvm::ffi::TensorView x) { + DLDataType dt = x.dtype(); + TVM_FFI_ICHECK_EQ(dt.code, kDLFloat) << "Expected float dtype"; + TVM_FFI_ICHECK_EQ(dt.bits, 32) << "Expected 32-bit dtype"; + + // Or for user-facing errors: + if (dt.code != kDLFloat || dt.bits != 32) { + TVM_FFI_THROW(TypeError) << "Expected float32, got " + << "code=" << dt.code << " bits=" << dt.bits; + } +} +``` + +## Shape Validation +```cpp +void MatMul(tvm::ffi::TensorView a, tvm::ffi::TensorView b, tvm::ffi::TensorView c) { + TVM_FFI_ICHECK_EQ(a.ndim(), 2) << "a must be 2D"; + TVM_FFI_ICHECK_EQ(b.ndim(), 2) << "b must be 2D"; + TVM_FFI_ICHECK_EQ(a.size(1), b.size(0)) << "Shape mismatch for matmul"; + TVM_FFI_ICHECK_EQ(c.size(0), a.size(0)) << "Output shape mismatch"; + TVM_FFI_ICHECK_EQ(c.size(1), b.size(1)) << "Output shape mismatch"; +} +``` + +## Multi-dimensional Indexing +```cpp +void Process2D(tvm::ffi::TensorView x) { + int64_t height = x.size(0); + int64_t width = x.size(1); + float* data = static_cast(x.data_ptr()); + + // If contiguous (row-major), access as: data[i * width + j] + TVM_FFI_ICHECK(x.IsContiguous()) << "Expected contiguous tensor"; + + // With strides: + int64_t stride0 = x.stride(0); + int64_t stride1 = x.stride(1); + // Access: data[i * stride0 + j * stride1] +} +``` + +## Device-specific Kernels +```cpp +void MyKernel(tvm::ffi::TensorView x) { + DLDevice dev = x.device(); + if (dev.device_type == kDLCUDA) { + // Launch CUDA kernel + cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(dev.device_type, dev.device_id)); + // kernel<<>>(...); + } else if (dev.device_type == kDLCPU) { + // CPU 
implementation + } else { + TVM_FFI_THROW(RuntimeError) << "Unsupported device type: " << dev.device_type; + } +} +``` + +## Allocating New Tensors +```cpp +void CreateOutput(tvm::ffi::TensorView input) { + // Create output tensor with same device and dtype as input + ffi::ShapeView shape = input.shape(); + DLDataType dtype = input.dtype(); + DLDevice device = input.device(); + + // Use environment allocator + ffi::Tensor output = ffi::Tensor::FromEnvAlloc( + TVMFFIEnvTensorAlloc, shape, dtype, device); +} +``` +""" diff --git a/tests/agent/test_prompt.py b/tests/agent/test_prompt.py new file mode 100644 index 00000000..66cf3561 --- /dev/null +++ b/tests/agent/test_prompt.py @@ -0,0 +1,261 @@ +""" +Test script to evaluate LLMs' ability to generate CUDA kernels with TVM FFI bindings. +Tests elementwise add kernel generation across multiple models. +""" + +import os +import re +from pathlib import Path + +import torch +import tvm_ffi.cpp +from dotenv import load_dotenv +from ffi_prompt import FFI_PROMPT_SIMPLE +from tvm_ffi import Module + +load_dotenv() + + +# System prompt for elementwise add +ELEMENTWISE_ADD_PROMPT = """Write a CUDA kernel function that performs elementwise addition of two tensors. + +The function should: +- Take three TensorView arguments: input tensor a, input tensor b, and output tensor c +- Compute c[i] = a[i] + b[i] for all elements +- Support 1D float32 tensors +- Use proper input validation +- Export the function with name "elementwise_add" +""" + + +def get_model_config(model_name: str): + """Get API configuration for a given model.""" + # OpenAI models + openai_key = os.getenv("OPENAI_API_KEY") + # Anthropic models + anthropic_key = os.getenv("ANTHROPIC_API_KEY") + + if model_name in ["gpt-5-2025-08-07", "o3", "gpt-5-mini-2025-08-07", "o4-mini-2025-04-16"]: + return {"provider": "openai", "api_key": openai_key, "model": model_name} + elif model_name in ["claude-opus-4-1-20250805", "claude-sonnet-4-5-20250805"]: + return {"provider": "anthropic", "api_key": anthropic_key, "model": model_name} + else: + raise ValueError(f"Unknown model: {model_name}") + + +def call_openai_model(model_name: str, api_key: str, prompt: str) -> str: + """Call OpenAI API to generate code.""" + import openai + + client = openai.OpenAI(api_key=api_key) + + if model_name in ["o3", "o4-mini-2025-04-16"]: + response = client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": prompt}], + reasoning_effort="high", + ) + else: + response = client.chat.completions.create( + model=model_name, + messages=[ + {"role": "system", "content": ELEMENTWISE_ADD_PROMPT}, + {"role": "user", "content": FFI_PROMPT_SIMPLE}, + ], + ) + + return response.choices[0].message.content + + +def call_anthropic_model(model_name: str, api_key: str, prompt: str) -> str: + import anthropic + + client = anthropic.Anthropic(api_key=api_key) + + response = client.messages.create( + model=model_name, + max_tokens=4096, + system=ELEMENTWISE_ADD_PROMPT, + messages=[{"role": "user", "content": FFI_PROMPT_SIMPLE}], + ) + + return response.content[0].text + + +def extract_cuda_code(response: str) -> str: + patterns = [r"```(?:cpp|cuda|c\+\+)\n(.*?)```", r"```\n(.*?)```"] + + for pattern in patterns: + matches = re.findall(pattern, response, re.DOTALL) + if matches: + for match in matches: + if "TVM_FFI" in match or "TensorView" in match: + return match.strip() + + return response.strip() + + +def test_kernel(mod: Module, test_name: str) -> bool: + """Test the generated kernel with simple test 
cases.""" + try: + print(f"\n Running {test_name}...") + + if test_name == "test_small": + # Test 1: Small tensor + a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32, device="cuda") + b = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0], dtype=torch.float32, device="cuda") + c = torch.empty_like(a) + + mod.elementwise_add(a, b, c) + expected = a + b + + torch.testing.assert_close(c, expected) + print(f" ✓ {test_name} passed") + return True + + elif test_name == "test_large": + # Test 2: Larger tensor + n = 10000 + a = torch.randn(n, dtype=torch.float32, device="cuda") + b = torch.randn(n, dtype=torch.float32, device="cuda") + c = torch.empty_like(a) + + mod.elementwise_add(a, b, c) + expected = a + b + + torch.testing.assert_close(c, expected) + print(f" ✓ {test_name} passed") + return True + + except Exception as e: + print(f" ✗ {test_name} failed: {e}") + return False + + +def test_model(model_name: str, output_dir: Path): + print(f"\n{'='*80}") + print(f"Testing model: {model_name}") + print(f"{'='*80}") + + try: + config = get_model_config(model_name) + + full_prompt = ELEMENTWISE_ADD_PROMPT + "\n\n" + FFI_PROMPT_SIMPLE + + print(f"Calling {config['provider']} API...") + if config["provider"] == "openai": + response = call_openai_model(config["model"], config["api_key"], full_prompt) + elif config["provider"] == "anthropic": + response = call_anthropic_model(config["model"], config["api_key"], full_prompt) + else: + raise ValueError(f"Unknown provider: {config['provider']}") + + print(f"Received response from {model_name}") + + cuda_code = extract_cuda_code(response) + print(f"Extracted {len(cuda_code)} characters of code") + + output_file = output_dir / f"{model_name}.txt" + with open(output_file, "w") as f: + f.write(f"Model: {model_name}\n") + f.write(f"{'='*80}\n\n") + f.write("Generated Code:\n") + f.write("=" * 80 + "\n") + f.write(cuda_code) + f.write("\n\n") + + print(f"Saved response to {output_file}") + + try: + mod: Module = tvm_ffi.cpp.load_inline( + name=f'elementwise_add_{model_name.replace("-", "_")}', cuda_sources=cuda_code + ) + print("Compilation successful!") + + # Run tests + print("\nRunning tests...") + test1_passed = test_kernel(mod, "test_small") + test2_passed = test_kernel(mod, "test_large") + + with open(output_file, "a") as f: + f.write("Test Results:\n") + f.write("=" * 80 + "\n") + f.write("Compilation: SUCCESS\n") + f.write(f"Test 1 (small tensor): {'PASS' if test1_passed else 'FAIL'}\n") + f.write(f"Test 2 (large tensor): {'PASS' if test2_passed else 'FAIL'}\n") + + all_passed = test1_passed and test2_passed + status = "ALL TESTS PASSED" if all_passed else "TESTS FAILED" + print(f"\n{status}") + + return { + "model": model_name, + "compilation": "success", + "test_small": test1_passed, + "test_large": test2_passed, + "all_passed": all_passed, + } + + except Exception as e: + print(f"Compilation failed: {e}") + with open(output_file, "a") as f: + f.write("Test Results:\n") + f.write("=" * 80 + "\n") + f.write("Compilation: FAILED\n") + f.write(f"Error: {str(e)}\n") + + return { + "model": model_name, + "compilation": "failed", + "error": str(e), + "all_passed": False, + } + + except Exception as e: + print(f"Error testing {model_name}: {e}") + return {"model": model_name, "compilation": "error", "error": str(e), "all_passed": False} + + +def main(): + models = [ + "gpt-5-2025-08-07", + "o3", + "claude-opus-4-1-20250805", + "claude-sonnet-4-5-20250805", + "gpt-5-mini-2025-08-07", + "o4-mini-2025-04-16", + ] + + output_dir = Path(__file__).parent / 
"test_results" + output_dir.mkdir(exist_ok=True) + + print("=" * 80) + print("Testing LLM FFI Code Generation") + print("=" * 80) + print(f"Models to test: {len(models)}") + print(f"Output directory: {output_dir}") + + results = [] + + for idx, model_name in enumerate(models, 1): + print(f"\n[{idx}/{len(models)}] Testing {model_name}...") + result = test_model(model_name, output_dir) + results.append(result) + + print(f"\n{'='*80}") + print("FINAL SUMMARY") + print(f"{'='*80}") + + for result in results: + model = result["model"] + status = "✓ PASS" if result.get("all_passed", False) else "✗ FAIL" + print(f"{model:30} {status}") + if "error" in result: + print(f" Error: {result['error'][:100]}") + + total_passed = sum(1 for r in results if r.get("all_passed", False)) + print(f"\nTotal: {total_passed}/{len(models)} models passed all tests") + + +if __name__ == "__main__": + main()