-
Notifications
You must be signed in to change notification settings - Fork 20
Add GEMV INT 4 #101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
albiol2004
wants to merge
5
commits into
amd:devel
Choose a base branch
from
albiol2004:gemv-int4
base: devel
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add GEMV INT 4 #101
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
4a3314d
first working version of INT 4 GEMV
albiol2004 9cd050a
1.6x speedup, GROUP_SIZE at compile time
albiol2004 6c48321
double-pump kernel + K at compile time
albiol2004 1db9c31
Add fused INT4 dequant-GEMV operator
albiol2004 32571c6
ran clang formatter
albiol2004 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,158 @@ | ||
| // SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| // Fused INT4 dequantization + GEMV kernel for AIE2+. | ||
| // | ||
| // Loads INT4-packed weights, dequantizes in-register, and performs | ||
| // matrix-vector multiplication in a single pass. | ||
| // | ||
| // Weight layout per tile (m rows x K cols, group_size G): | ||
| // [m * K / 2 bytes of packed uint4 weights] | ||
| // [m * (K / G) bf16 scale factors, stored as (m * K / G * 2) bytes] | ||
| // | ||
| // Dequantization: w_bf16 = scale * unpack_uint4_to_bf16(w_uint4) | ||
| // | ||
| // The unpack chain matches the existing dequant kernel (expand.cc): | ||
| // uint4 -> uint8 (aie::unpack) -> uint16 (aie::unpack) -> bf16 (aie::to_float) | ||
| // | ||
| // Optimization: double-pump — process 2 groups (64 elements) per iteration | ||
| // so the compiler can interleave the two independent unpack chains, hiding | ||
| // the dequant latency behind computation. | ||
|
|
||
| #define NOCPP | ||
|
|
||
| #include "../aie_kernel_utils.h" | ||
|
|
||
| #include <aie_api/aie.hpp> | ||
| #include <stdint.h> | ||
| #include <type_traits> | ||
|
|
||
// Fused INT4 dequantize + GEMV over one tile of `m` output rows.
//
// Template parameters:
//   block_size: dequant vector width (must be 32 to match aie::unpack)
//   G:  quantization group size (compile-time so the group loop can be
//       fully pipelined; must be a multiple of block_size)
//   DK: K dimension (compile-time so loop trip counts are known)
//
// Runtime parameters:
//   m     -- number of output rows in this tile
//   a_in  -- packed tile: m*DK/2 bytes of uint4 weights, followed by
//            m*(DK/G) bf16 scale factors (see file header for layout)
//   b_in  -- activation vector of DK bf16 values
//   c_out -- m bf16 outputs, one per row
template <uint32_t block_size, uint32_t G, uint32_t DK>
void fused_dequant_matvec(uint32_t m,
                          const uint8_t *__restrict a_in,
                          const bfloat16 *__restrict b_in,
                          bfloat16 *__restrict c_out)
{
    static_assert(block_size == 32, "block_size must be 32 to match dequant vector width");
    static_assert(G % block_size == 0, "group_size must be a multiple of block_size");
    constexpr uint32_t blocks_per_group = G / block_size;
    constexpr uint32_t groups_per_row = DK / G;
    // Double-pump: process 2 groups per iteration when the group count allows
    // it, giving the compiler two independent dequant chains to interleave.
    constexpr bool can_double_pump = (groups_per_row >= 2) && (groups_per_row % 2 == 0);
    constexpr uint32_t pump_groups = can_double_pump ? 2 : 1;
    constexpr uint32_t loop_iters = groups_per_row / pump_groups;

    // Round-to-nearest-even for the float -> bf16 conversions below.
    ::aie::set_rounding(aie::rounding_mode::conv_even);

    // Tile layout: packed uint4 weights first, then the bf16 scales.
    const uint4 *weights_packed = reinterpret_cast<const uint4 *>(a_in);
    const uint8_t *scale_bytes = a_in + m * DK / 2;
    const bfloat16 *scales = reinterpret_cast<const bfloat16 *>(scale_bytes);

    event0(); // trace/profiling marker: kernel body start
    for (uint32_t row = 0; row < m; row++) {
        // NOTE(review): per-row advance is DK/2 and per-load advance is
        // block_size/2, i.e. uint4* arithmetic here is assumed to step in
        // byte-sized units (2 nibbles per step) -- confirm against the
        // aie_api uint4 pointer semantics used by expand.cc.
        const uint4 *row_weights = weights_packed + row * DK / 2;
        const bfloat16 *row_scales = scales + row * groups_per_row;
        const bfloat16 *b_ptr = b_in;

        // One float accumulator lane per vector element; reduced at row end.
        aie::accum<accfloat, block_size> acc = aie::zeros<accfloat, block_size>();

        if constexpr (can_double_pump && blocks_per_group == 1) {
            // Optimized path: 2 groups per iteration, 1 block per group.
            // Loads/broadcasts for both groups are issued before either
            // unpack chain so the two chains can overlap in the pipeline.
            AIE_LOOP_MIN_ITERATION_COUNT(loop_iters)
            for (uint32_t g = 0; g < groups_per_row; g += 2)
                AIE_PREPARE_FOR_PIPELINING
            {
                // --- Chain A: group g ---
                bfloat16 sf_a = row_scales[g];
                aie::vector<bfloat16, block_size> sf_a_bc = aie::broadcast<bfloat16, block_size>(sf_a);

                aie::vector<uint4, block_size> I0_a = aie::load_v<block_size>(row_weights);
                row_weights += block_size / 2;

                // --- Chain B: group g+1 (independent of chain A) ---
                bfloat16 sf_b = row_scales[g + 1];
                aie::vector<bfloat16, block_size> sf_b_bc = aie::broadcast<bfloat16, block_size>(sf_b);

                aie::vector<uint4, block_size> I0_b = aie::load_v<block_size>(row_weights);
                row_weights += block_size / 2;

                // Unpack chain A: uint4 -> uint8 -> uint16 -> bf16, then scale.
                aie::vector<uint8, block_size> a8_a = aie::unpack(I0_a);
                aie::vector<uint16, block_size> a16_a = aie::unpack(a8_a);
                aie::vector<bfloat16, block_size> abf_a = aie::to_float<bfloat16>(a16_a, 0);
                aie::vector<bfloat16, block_size> w_a = aie::mul(abf_a, sf_a_bc).template to_vector<bfloat16>();

                // Unpack chain B (same chain, independent registers).
                aie::vector<uint8, block_size> a8_b = aie::unpack(I0_b);
                aie::vector<uint16, block_size> a16_b = aie::unpack(a8_b);
                aie::vector<bfloat16, block_size> abf_b = aie::to_float<bfloat16>(a16_b, 0);
                aie::vector<bfloat16, block_size> w_b = aie::mul(abf_b, sf_b_bc).template to_vector<bfloat16>();

                // Load the matching activation slices and accumulate.
                aie::vector<bfloat16, block_size> b_a = aie::load_v<block_size>(b_ptr);
                b_ptr += block_size;
                acc = aie::mac(acc, w_a, b_a);

                aie::vector<bfloat16, block_size> b_b = aie::load_v<block_size>(b_ptr);
                b_ptr += block_size;
                acc = aie::mac(acc, w_b, b_b);
            }
        } else {
            // Generic path: 1 group per outer iteration, blocks_per_group
            // inner iterations sharing one broadcast scale factor.
            AIE_LOOP_MIN_ITERATION_COUNT(loop_iters)
            for (uint32_t g = 0; g < groups_per_row; g++)
                AIE_PREPARE_FOR_PIPELINING
            {
                bfloat16 sf = row_scales[g];
                aie::vector<bfloat16, block_size> sf_broadcast = aie::broadcast<bfloat16, block_size>(sf);

                AIE_LOOP_MIN_ITERATION_COUNT(blocks_per_group)
                for (uint32_t blk = 0; blk < blocks_per_group; blk++) {
                    aie::vector<uint4, block_size> I0 = aie::load_v<block_size>(row_weights);
                    row_weights += block_size / 2;

                    // uint4 -> uint8 -> uint16 -> bf16, then apply group scale.
                    aie::vector<uint8, block_size> as_int8 = aie::unpack(I0);
                    aie::vector<uint16, block_size> as_int16 = aie::unpack(as_int8);
                    aie::vector<bfloat16, block_size> as_bf16 = aie::to_float<bfloat16>(as_int16, 0);
                    aie::vector<bfloat16, block_size> w_dequant =
                        aie::mul(as_bf16, sf_broadcast).template to_vector<bfloat16>();

                    aie::vector<bfloat16, block_size> b_vec = aie::load_v<block_size>(b_ptr);
                    b_ptr += block_size;

                    acc = aie::mac(acc, w_dequant, b_vec);
                }
            }
        }

        // Horizontal reduction of the row accumulator to one bf16 output.
        *c_out = static_cast<bfloat16>(aie::reduce_add(acc.template to_vector<float>()));
        c_out++;
    }
    event1(); // trace/profiling marker: kernel body end
}
|
|
||
| #ifndef GROUP_SIZE | ||
| #define GROUP_SIZE 32 | ||
| #endif | ||
|
|
||
| #ifndef DIM_K | ||
| #define DIM_K 2048 | ||
| #endif | ||
|
|
||
| extern "C" { | ||
|
|
||
| void fused_dequant_matvec_bf16(uint32_t m, | ||
| uint32_t row_offset, | ||
| const uint8_t *__restrict a_in, | ||
| const bfloat16 *__restrict b_in, | ||
| bfloat16 *__restrict c_out) | ||
| { | ||
| c_out += row_offset; | ||
| fused_dequant_matvec<32, GROUP_SIZE, DIM_K>(m, a_in, b_in, c_out); | ||
| } | ||
|
|
||
| } // extern "C" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """ | ||
| Fused INT4 dequantization matrix-vector design. | ||
|
|
||
| Performs a fused dequantize-GEMV where the weight matrix is stored in packed | ||
| INT4 format (two 4-bit values per uint8 byte) with per-group bfloat16 scale | ||
| factors. The activation vector and output are bfloat16. | ||
|
|
||
| Each AIE column processes a contiguous block of output rows. Within a column, | ||
| the worker iterates over tiles of packed weight rows, acquires the full | ||
| activation vector once per outer iteration, and calls the fused dequant-matvec | ||
| kernel which unpacks, dequantizes, and accumulates in a single pass. | ||
|
|
||
| Buffer layout for A (packed weights, uint8): | ||
| For each tile of m_input rows: [m_input * K / 2 bytes of packed weights] | ||
| [m_input * (K / group_size) * 2 bytes of scales] | ||
| """ | ||
|
|
||
| import numpy as np | ||
| from ml_dtypes import bfloat16 | ||
|
|
||
| import aie.dialects.index as index | ||
| from aie.dialects.aie import T | ||
| from aie.helpers.dialects.scf import _for as range_ | ||
| from aie.helpers.taplib import TensorAccessPattern | ||
| from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker | ||
| from aie.iron.placers import SequentialPlacer | ||
|
|
||
|
|
||
def my_fused_dequant_matvec(
    dev, cols, M, K, m_input, m_output=None, group_size=32
):
    """Build the fused INT4 dequant-GEMV IRON program.

    Args:
        dev: target AIE device object.
        cols: number of AIE columns to spread the M rows over.
        M: total number of output rows (weight-matrix rows).
        K: inner dimension (activation-vector length).
        m_input: rows per weight tile handed to one kernel call.
        m_output: rows per output buffer drained to DDR; must be a multiple
            of m_input. Defaults to m_input.
        group_size: INT4 quantization group size (multiple of 32).

    Returns:
        The resolved IRON Program (placed with SequentialPlacer).
    """
    if m_output is None:
        m_output = m_input

    # --- Assertions: tiling parameters must evenly partition M and K ---
    assert (
        m_output % m_input == 0 and m_output >= m_input
    ), "m_output must be a multiple of m_input"
    assert m_output <= M // cols, "m_output must be less than or equal to M/cols"
    assert (M // cols) % m_output == 0, "m_output must evenly divide M/cols"
    assert m_input <= M // cols, "m_input must be less than or equal to M/cols"
    assert (M // cols) % m_input == 0, "m_input must evenly divide M/cols"
    assert K % group_size == 0, "K must be divisible by group_size"
    assert group_size % 32 == 0, "group_size must be a multiple of 32"
    assert M % cols == 0, "M must be divisible by cols"

    # --- Data types ---
    # A is moved as raw bytes (packed uint4 + bf16 scales); B/C are bf16.
    dtype_in = np.dtype[np.uint8]
    dtype_vec = np.dtype[bfloat16]
    dtype_out = np.dtype[bfloat16]

    # --- Per-tile sizes (in uint8 bytes) ---
    # One tile = m_input rows of packed weights followed by their scales
    # (2 bytes per bf16 scale, K/group_size scales per row).
    num_groups_per_row = K // group_size
    packed_tile_bytes = m_input * K // 2 + m_input * num_groups_per_row * 2
    rows_per_col = M // cols
    tiles_per_col = rows_per_col // m_input
    bytes_per_col = tiles_per_col * packed_tile_bytes
    packed_total_bytes = cols * bytes_per_col

    # --- L1 (on-chip) tensor types ---
    L1_A_ty = np.ndarray[(packed_tile_bytes,), dtype_in]
    L1_B_ty = np.ndarray[(K,), dtype_vec]
    L1_C_ty = np.ndarray[(m_output,), dtype_out]

    # --- L3 (DDR) tensor types ---
    L3_A_ty = np.ndarray[(packed_total_bytes,), dtype_in]
    L3_B_ty = np.ndarray[(K,), dtype_vec]
    L3_C_ty = np.ndarray[(M,), dtype_out]

    # --- Kernel declaration ---
    # K and group_size are baked into the object file via -DDIM_K/-DGROUP_SIZE,
    # hence the parameterized object-file name.
    fused_matvec = Kernel(
        "fused_dequant_matvec_bf16",
        f"fused_dequant_gemv_{K}k_g{group_size}.o",
        [np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty],
    )

    # --- ObjectFIFOs ---
    # A/C are double-buffered (depth=2) to overlap DMA with compute; B holds
    # the single activation vector reused across all tiles (depth=1).
    A_L3L1_fifos = [
        ObjectFifo(L1_A_ty, name=f"A_L3L1_{i}", depth=2) for i in range(cols)
    ]
    B_L3L1_fifos = [
        ObjectFifo(L1_B_ty, name=f"B_L3L1_{i}", depth=1) for i in range(cols)
    ]
    C_L1L3_fifos = [
        ObjectFifo(L1_C_ty, name=f"C_L1L3_{i}", depth=2) for i in range(cols)
    ]

    # --- Worker core body ---
    # Number of output buffers produced per column per sequence run.
    N_div_n = tiles_per_col // (m_output // m_input)

    def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, fused_matvec_fn):
        # Outer range_ is effectively "run forever"; each pass consumes one
        # activation vector and produces this column's full set of outputs.
        for _ in range_(0xFFFFFFFF):
            b = B_L3L1_fifo.acquire(1)
            for i_idx in range_(N_div_n):
                c = C_L1L3_fifo.acquire(1)
                # m_output // m_input weight tiles fill one output buffer,
                # each writing at its own row offset within it.
                for j_idx in range_(m_output // m_input):
                    j_i32 = index.casts(T.i32(), j_idx)
                    output_row_offset = j_i32 * m_input
                    a = A_L3L1_fifo.acquire(1)
                    fused_matvec_fn(
                        m_input, output_row_offset, a, b, c
                    )
                    A_L3L1_fifo.release(1)
                C_L1L3_fifo.release(1)
            B_L3L1_fifo.release(1)

    # One worker per column, each wired to its own A/B/C FIFOs.
    workers = [
        Worker(
            core_body,
            [
                A_L3L1_fifos[i].cons(),
                B_L3L1_fifos[i].cons(),
                C_L1L3_fifos[i].prod(),
                fused_matvec,
            ],
        )
        for i in range(cols)
    ]

    # --- TensorAccessPatterns ---
    # A: each column reads a contiguous chunk of bytes_per_col packed bytes.
    A_taps = [
        TensorAccessPattern(
            tensor_dims=L3_A_ty.__args__[0],
            offset=col * bytes_per_col,
            sizes=[1, 1, 1, bytes_per_col],
            strides=[0, 0, 0, 1],
        )
        for col in range(cols)
    ]

    # C: each column writes contiguous rows_per_col bfloat16 values.
    C_taps = [
        TensorAccessPattern(
            tensor_dims=L3_C_ty.__args__[0],
            offset=col * rows_per_col,
            sizes=[1, 1, 1, rows_per_col],
            strides=[0, 0, 0, 1],
        )
        for col in range(cols)
    ]

    # --- Runtime sequence ---
    # Fill all A/B FIFOs, then drain C; only the drains wait so fills and
    # compute overlap. B has no tap: every column gets the full vector.
    rt = Runtime()
    with rt.sequence(L3_A_ty, L3_B_ty, L3_C_ty) as (A, B, C):
        rt.start(*workers)
        tg = rt.task_group()
        for i in range(cols):
            rt.fill(A_L3L1_fifos[i].prod(), A, A_taps[i], task_group=tg)
            rt.fill(B_L3L1_fifos[i].prod(), B, task_group=tg)
        for i in range(cols):
            rt.drain(
                C_L1L3_fifos[i].cons(), C, C_taps[i], task_group=tg, wait=True
            )
        rt.finish_task_group(tg)

    return Program(dev, rt).resolve_program(SequentialPlacer())
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How much of this is similar to other GEMM ops? Can we find a way to reuse code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is roughly 70% similar to gemv, but with a different layout and kernel. I could either import _build_gemv_program() from gemv to reuse it,
or factor it out as a common helper (only two ops would use it for now, I think), which I would leave for a refactor in a separate PR.
Basically, the fused kernel is the important part, so it depends on what you prefer for readability and maintenance.