Skip to content
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ elseif(BUILD_HIP)

list(APPEND SRC_FILES ${GPU_FILES})

# 4-bit GEMM: build the SIMT kernel + C dispatch on ROCm.
set_source_files_properties(csrc/gemm_4bit_simt.cu csrc/gemm_4bit.cu PROPERTIES LANGUAGE HIP)
list(APPEND SRC_FILES csrc/gemm_4bit_simt.cu csrc/gemm_4bit.cu)

string(APPEND BNB_OUTPUT_NAME "_rocm")

# get hip version
Expand Down
284 changes: 153 additions & 131 deletions bitsandbytes/backends/cuda/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@

from ..._ops import register_kernel
from ...cextension import lib
from ..default.ops import _gemm_4bit_default_impl
from ..utils import _get_4bit_code


def _setup_ctypes(names, argtypes, restype=None):
Expand Down Expand Up @@ -583,7 +581,7 @@ def _gemv_4bit_impl(


@functools.cache
def _gemm_4bit_use_custom(device_index, dtype, M, N, K):
def _gemm_4bit_use_custom_cuda(device_index, dtype, M, N, K):
"""Custom kernel vs dequant+F.linear heuristic for M in [5, 1536].

Per-arch notes (bf16/fp16, M >= 8, large weight):
Expand All @@ -595,6 +593,9 @@ def _gemm_4bit_use_custom(device_index, dtype, M, N, K):
sm100 (B200/B300, HBM3e): exits early at top of function.
sm120 (RTX 5000, GDDR7): dedicated block; medium-N tiers differ from sm89.
"""
if M <= _GEMM_4BIT_CUSTOM_FLOOR_M:
return True

num_sms, major, minor = _gpu_dispatch_props(device_index)
n_blocks = (N + 63) // 64

Expand Down Expand Up @@ -800,140 +801,157 @@ def _gemm_4bit_use_custom(device_index, dtype, M, N, K):
return M <= (16 if (tall_k_2xn or n_blocks < 48) else 8)


if torch.version.hip is None:
@functools.cache
def _gemm_4bit_use_custom_rocm(device_index, dtype, M, N, K):
Comment thread
sstamenk marked this conversation as resolved.
"""
Fused SIMT kernel vs dequant+F.linear heuristic for ROCm.

@register_kernel("bitsandbytes::gemm_4bit", "cuda")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
bias: Optional[torch.Tensor] = None,
absmax_8bit: Optional[torch.Tensor] = None,
absmax_code: Optional[torch.Tensor] = None,
absmax_offset: Optional[torch.Tensor] = None,
) -> torch.Tensor:
K = A.shape[-1]
M = A.numel() // K
N = shapeB[0]

# M>1536: dequant+F.linear wins (dequant savings negligible at very large batch).
# M<=4: always custom (custom kernel wins universally at small batch).
# M in [5, 1536]: shape/arch-dependent; cached per (device, dtype, M, N, K).
if M > 1536:
use_custom = False
elif K % blocksize != 0:
warn(
f"inner dimension ({K}) is not aligned for fast kernel "
f"with blocksize={blocksize}, falling back to slower implementation.",
UserWarning,
)
use_custom = False
else:
use_custom = M <= 4 or _gemm_4bit_use_custom(A.device.index, A.dtype, M, N, K)

if not use_custom:
if absmax_8bit is not None:
absmax_dq = torch.empty_like(absmax_8bit, dtype=torch.float32)
_dequantize_blockwise_impl(absmax_8bit, absmax, absmax_code, 256, torch.float32, out=absmax_dq)
absmax = absmax_dq + absmax_offset
B_dq = torch.empty(shapeB, dtype=A.dtype, device=A.device)
_dequantize_4bit_impl(B, absmax, blocksize, quant_type, A.dtype, out=B_dq)
return torch.nn.functional.linear(A, B_dq, bias)

if K != shapeB[1]:
raise RuntimeError(f"A inner dim ({K}) does not match weight ({shapeB[1]})")
if absmax.dtype != torch.float32:
raise RuntimeError(f"absmax must be float32, got {absmax.dtype}")
if bias is not None:
if bias.ndim != 1:
raise RuntimeError(f"bias must be 1D, got {bias.ndim}D")
if bias.dtype != A.dtype:
raise RuntimeError(f"bias dtype ({bias.dtype}) must match A dtype ({A.dtype})")

quant_type_int = 1 if quant_type == "fp4" else 2

out = torch.empty((*A.shape[:-1], N), dtype=A.dtype, device=A.device)
stream = _get_raw_stream(A.device.index)

if A.dtype == torch.bfloat16:
fn = lib.cgemm_4bit_bf16
elif A.dtype == torch.float16:
fn = lib.cgemm_4bit_fp16
elif A.dtype == torch.float32:
fn = lib.cgemm_4bit_fp32
else:
raise RuntimeError(f"unsupported dtype {A.dtype}")

# Offset is expected to be a float32 tensor.
absmax_offset_f32 = absmax_offset.to(dtype=torch.float32) if absmax_offset is not None else None

with _cuda_device_of(A):
fn(
A.data_ptr(),
B.data_ptr(),
absmax.data_ptr(),
absmax_8bit.data_ptr() if absmax_8bit is not None else None,
absmax_code.data_ptr() if absmax_code is not None else None,
absmax_offset_f32.data_ptr() if absmax_offset_f32 is not None else None,
out.data_ptr(),
bias.data_ptr() if bias is not None else None,
M,
N,
K,
blocksize,
quant_type_int,
stream,
)
RDNA3/RDNA4 calibration keeps the SIMT kernel through ~M=8.
CDNA/gfx9 is calibrated on MI308X (gfx942): bf16/fp16 win through M<=4
after the SIMT math-path tuning, while fp32 only has a broad win through M<=2.

TODO: revisit once WMMA/MFMA kernels land.
"""
if M <= _GEMM_4BIT_CUSTOM_FLOOR_M and dtype != torch.float32:
return True

arch = _rocm_gfx_arch(device_index)
if arch.startswith("gfx11") or arch.startswith("gfx12"): # RDNA3 / RDNA4
return M <= 8
if arch.startswith("gfx9"): # CDNA / MI-series
return M <= (2 if dtype == torch.float32 else 4)
return M <= 4 # unknown ROCm arch: conservative tiny-batch floor


@functools.cache
def _rocm_gfx_arch(device_index):
    """Return the base gfx architecture name of a ROCm device (e.g. 'gfx1100').

    The name is read from the ``gcnArchName`` device property and truncated at
    the first ``:``, which drops the feature-flag suffix. An empty string is
    returned when the property is missing or empty. Results are cached per
    ``device_index``.
    """
    props = torch.cuda.get_device_properties(device_index)
    # `or ""` also covers a property that exists but is None/empty.
    full_name = getattr(props, "gcnArchName", "") or ""
    base_arch, _, _ = full_name.partition(":")
    return base_arch


def _gemm_4bit_kernel_impl(
A, B, shapeB, absmax, blocksize, quant_type, bias=None, absmax_8bit=None, absmax_code=None, absmax_offset=None
):
"""Invoke the fused cgemm_4bit_* kernel (shared by the CUDA and ROCm dispatch; the
C dispatch in gemm_4bit.cu picks SIMT vs MMA per arch/shape). A is made contiguous
because the kernel reads it as row-major (stride K)."""
K = A.shape[-1]
M = A.numel() // K
N = shapeB[0]

if K != shapeB[1]:
raise RuntimeError(f"A inner dim ({K}) does not match weight ({shapeB[1]})")
if absmax.dtype != torch.float32:
raise RuntimeError(f"absmax must be float32, got {absmax.dtype}")
if bias is not None:
if bias.ndim != 1:
raise RuntimeError(f"bias must be 1D, got {bias.ndim}D")
if bias.dtype != A.dtype:
raise RuntimeError(f"bias dtype ({bias.dtype}) must match A dtype ({A.dtype})")

A = A.contiguous()
quant_type_int = 1 if quant_type == "fp4" else 2
out = torch.empty((*A.shape[:-1], N), dtype=A.dtype, device=A.device)
stream = _get_raw_stream(A.device.index)

if A.dtype == torch.bfloat16:
fn = lib.cgemm_4bit_bf16
elif A.dtype == torch.float16:
fn = lib.cgemm_4bit_fp16
elif A.dtype == torch.float32:
fn = lib.cgemm_4bit_fp32
else:
raise RuntimeError(f"unsupported dtype {A.dtype}")

# Offset is expected to be a float32 tensor.
absmax_offset_f32 = absmax_offset.to(dtype=torch.float32) if absmax_offset is not None else None

with _cuda_device_of(A):
fn(
A.data_ptr(),
B.data_ptr(),
absmax.data_ptr(),
absmax_8bit.data_ptr() if absmax_8bit is not None else None,
absmax_code.data_ptr() if absmax_code is not None else None,
absmax_offset_f32.data_ptr() if absmax_offset_f32 is not None else None,
out.data_ptr(),
bias.data_ptr() if bias is not None else None,
M,
N,
K,
blocksize,
quant_type_int,
stream,
)

return out

return out

def _dequant_linear_fallback(
    A, B, shapeB, absmax, blocksize, quant_type, bias=None, absmax_8bit=None, absmax_code=None, absmax_offset=None
):
    """Unfused gemm_4bit path used by both the CUDA and ROCm dispatch.

    Steps:
      1. If ``absmax_8bit`` is given, rebuild the float32 absmax: dequantize
         ``absmax_8bit`` blockwise (with ``absmax`` and ``absmax_code``) into a
         freshly allocated buffer, then add ``absmax_offset``.
      2. Dequantize the 4-bit weight ``B`` into a preallocated ``shapeB`` buffer
         with A's dtype, on A's device.
      3. Return ``torch.nn.functional.linear(A, <dequantized weight>, bias)``.
    """
    scales = absmax
    if absmax_8bit is not None:
        nested = torch.empty_like(absmax_8bit, dtype=torch.float32)
        # NOTE(review): 256 is presumably the block size of the nested absmax
        # quantization — confirm against _dequantize_blockwise_impl.
        _dequantize_blockwise_impl(absmax_8bit, absmax, absmax_code, 256, torch.float32, out=nested)
        scales = nested + absmax_offset

    weight = torch.empty(shapeB, dtype=A.dtype, device=A.device)
    _dequantize_4bit_impl(B, scales, blocksize, quant_type, A.dtype, out=weight)
    return torch.nn.functional.linear(A, weight, bias)


# Unified CUDA/ROCm dispatch for bitsandbytes::gemm_4bit. This module only decides
# *whether* a custom kernel is used; the choice among custom kernels (CUDA SIMT vs
# MMA; ROCm SIMT) is made in the C dispatch (csrc/gemm_4bit.cu).
_GEMM_4BIT_CUSTOM_FLOOR_M = 4

# Backend-specific heuristic and the hard upper bound on M for the custom path:
#   CUDA: dequant+F.linear wins past M=1536 (dequant savings are negligible at
#         very large batch).
#   ROCm: the custom path is SIMT-only today and the per-arch heuristic owns the
#         RDNA/CDNA thresholds; 256 stays as a hard cap while WMMA/MFMA paths are
#         absent.
# Only the selected branch of the conditional is evaluated.
_gemm_4bit_use_custom_fn, _gemm_4bit_custom_max_m = (
    (_gemm_4bit_use_custom_cuda, 1536) if torch.version.hip is None else (_gemm_4bit_use_custom_rocm, 256)
)

@register_kernel("bitsandbytes::gemm_4bit", "cuda")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
bias: Optional[torch.Tensor] = None,
absmax_8bit: Optional[torch.Tensor] = None,
absmax_code: Optional[torch.Tensor] = None,
absmax_offset: Optional[torch.Tensor] = None,
) -> torch.Tensor:
K = A.shape[-1]
M = A.numel() // K
N = shapeB[0]

if M == 1:
if K % blocksize == 0:
if absmax_8bit is not None:
absmax = (
torch.ops.bitsandbytes.dequantize_blockwise.default(
absmax_8bit, absmax, absmax_code, 256, torch.float32
)
+ absmax_offset
)

code = _get_4bit_code(quant_type, A.device)
out = torch.empty((*A.shape[:-1], N), dtype=A.dtype, device=A.device)
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)

if bias is not None:
out = out + bias
return out

warn(
f"inner dimension ({K}) is not aligned for fast kernel "
f"with blocksize={blocksize}, falling back to slower implementation.",
UserWarning,
)

return _gemm_4bit_default_impl(
@register_kernel("bitsandbytes::gemm_4bit", "cuda")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
bias: Optional[torch.Tensor] = None,
absmax_8bit: Optional[torch.Tensor] = None,
absmax_code: Optional[torch.Tensor] = None,
absmax_offset: Optional[torch.Tensor] = None,
) -> torch.Tensor:
K = A.shape[-1]
M = A.numel() // K
N = shapeB[0]

# The backend-specific heuristic owns tiny-M floors and per-arch thresholds.
# Past custom_max_m (or for blocksize-misaligned K), use the dequant+F.linear
# fallback.
if M > _gemm_4bit_custom_max_m:
use_custom = False
elif K % blocksize != 0:
warn(
f"inner dimension ({K}) is not aligned for fast kernel "
f"with blocksize={blocksize}, falling back to slower implementation.",
UserWarning,
)
use_custom = False
else:
use_custom = _gemm_4bit_use_custom_fn(A.device.index, A.dtype, M, N, K)

if not use_custom:
return _dequant_linear_fallback(
A,
B,
shapeB,
Expand All @@ -946,6 +964,10 @@ def _(
absmax_offset=absmax_offset,
)

return _gemm_4bit_kernel_impl(
A, B, shapeB, absmax, blocksize, quant_type, bias, absmax_8bit, absmax_code, absmax_offset
)


"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
Expand Down
13 changes: 3 additions & 10 deletions csrc/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,13 @@
// Warp size

#if BNB_HIP
// CDNA (gfx9xx) = 64, RDNA (gfx10xx/gfx11xx/gfx12xx) = 32.
// __AMDGCN_WAVEFRONT_SIZE is not defined by all compiler versions (removed since ROCm 7.0),
// so fall back to architecture-family macros when it is absent.
// This is a macro that is defined by the compiler during each device-code pass and as such
// should only be used inside kernels.
#ifdef __AMDGCN_WAVEFRONT_SIZE
#define BNB_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
#elif defined(__GFX9__)
#if defined(__GFX9__)
#define BNB_WARP_SIZE 64 // CDNA
#else
#define BNB_WARP_SIZE 32 // RDNA and other
#define BNB_WARP_SIZE 32 // RDNA
#endif
#else
#define BNB_WARP_SIZE 32
#define BNB_WARP_SIZE 32 // Other
#endif

// BF16 availability
Expand Down
Loading
Loading