Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ class RetrievalWorkerBackend {
if (rc != 0) {
std::cerr << "Error binding memory to NUMA node " << numaId << std::endl;
}
#else
std::cerr << "NUMA support is disabled." << std::endl;
#endif
}

Expand Down
17 changes: 15 additions & 2 deletions ucm/sparse/esa/retrieval/retrieval_worker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import time
from collections import defaultdict

import numpy as np
import torch

# import retrieval_backend
from ucm.sparse.esa.retrieval import retrieval_backend
from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank


class RetrievalWorker:
Expand Down Expand Up @@ -42,7 +43,19 @@ def wait(self, req_id):
data = torch.rand(kv_cache_blocks, dim).to(torch.float32)
print("data created", data.shape)

backend = retrieval_backend.RetrievalWorkerBackend(data)
ratio = 0.75
total_tp_size = 4
local_tp_rank = 0
bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank(
total_tp_size, local_tp_rank, ratio=ratio
)

bind_info_dict = defaultdict(list)
for item in bind_info_list:
bind_info_dict[item[1]].append(item[0])
bind_info_dict = dict(bind_info_dict)

backend = retrieval_backend.RetrievalWorkerBackend(data, bind_info_dict)
worker = RetrievalWorker(backend)
topk = 3000
search_blocks_range = 8000
Expand Down
195 changes: 168 additions & 27 deletions ucm/sparse/kvcomp/hash_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,124 @@

logger = init_logger(__name__)

if hasattr(torch, "cuda") and torch.cuda.is_available():
from vllm.triton_utils import tl, triton

@triton.jit
def triton_hash_code_kernel(
x_ptr,
code_ptr,
pack_w_ptr,
hash_out_ptr,
M,
K,
N,
stride_xm,
stride_xk,
stride_codek,
stride_coden,
stride_pack_w,
stride_om,
stride_on,
BLOCK_M: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_N: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)

offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # sample dimension
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # hash_rbits dimension
offs_k = tl.arange(0, BLOCK_K) # input_dim dimension

# Matrix multiplication
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x = tl.load(
x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
mask=(offs_m[:, None] < M) & (offs_k[None, :] < K),
other=0.0,
)
code = tl.load(
code_ptr
+ offs_k[:, None] * stride_codek
+ offs_n[None, :] * stride_coden,
mask=(offs_k[:, None] < K) & (offs_n[None, :] < N),
other=0.0,
)
acc += tl.dot(x, code)
offs_k += BLOCK_K

# Binarize and pack
bits = (acc > 0).to(tl.uint8) # Binarize
bits = tl.reshape(bits, (BLOCK_M, BLOCK_N // 8, 8)) # Reshape for packing

# Load the packing weights (ensure it has the correct shape)
pack_w = tl.load(pack_w_ptr + tl.arange(0, 8) * stride_pack_w)
packed = tl.sum(bits * pack_w[None, None, :], axis=-1).to(tl.uint8)

# Store results
offs_n = pid_n * (BLOCK_N // 8) + tl.arange(0, BLOCK_N // 8)
hash_out_ptrs = (
hash_out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
)
tl.store(
hash_out_ptrs,
packed,
mask=(offs_m[:, None] < M) & (offs_n[None, :] < (N // 8)),
)

def triton_hash_code(x, code, pack_weight):
input_dim = x.shape[-1]
samples = x.shape[0]
hash_bits = code.shape[-1]
assert (pack_weight.shape[0] == 8) and (hash_bits % 8 == 0)
hash_out = torch.empty(
(samples, hash_bits // 8), dtype=pack_weight.dtype, device=x.device
)

grid = lambda opts: (
triton.cdiv(samples, opts["BLOCK_M"]),
triton.cdiv(input_dim, opts["BLOCK_N"]),
)

triton_hash_code_kernel[grid](
x,
code,
pack_weight,
hash_out,
samples,
input_dim,
hash_bits,
x.stride(0),
x.stride(1),
code.stride(0),
code.stride(1),
pack_weight.stride(0),
hash_out.stride(0),
hash_out.stride(1),
BLOCK_M=32,
BLOCK_K=64,
BLOCK_N=16,
)

return hash_out.view(-1) # [samples * hash_numbers]


@torch.compile()
def torch_hash_code(x, code, pack_weight):
# [N, hash_bits]
x = x @ code
m = x.shape[:-1]
# [N, hash_bits] -- > [N, hash_bits // 8, 8]
x = (x > 0).to(torch.uint8).view(*m, -1, 8)
# 8bit -> 1bit
# binary_codes * self.bit_masks [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8]
# then sum along the last dimension to get [N, hash_numbers]
x = torch.sum(x * pack_weight, dim=-1, dtype=torch.uint8)
x = x.view(-1) # [N * hash_numbers]
return x


class HashEncoder:
"""
Expand Down Expand Up @@ -105,8 +223,6 @@ def _init_bit_masks(self) -> None:
self.bit_masks = torch.pow(
2, torch.arange(8, dtype=torch.uint8, device=self.device)
)
# shape (1, 1, 8)
self.bit_masks = self.bit_masks.unsqueeze(0).unsqueeze(0)

def compute_hash(self, x: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -136,29 +252,24 @@ def compute_hash(self, x: torch.Tensor) -> torch.Tensor:
if x_flat.dtype != self.dtype:
x_flat = x_flat.to(self.dtype)

# [N, hash_bits]
xW = torch.matmul(x_flat, self.hash_weights)

# [N * hash_bits]
xW_flat = xW.view(-1)

if self.device.type == "npu":
# [N, hash_bits]
xW = torch.matmul(x_flat, self.hash_weights)
# [N * hash_bits]
xW_flat = xW.view(-1)
# [N*hash_numbers], where hash_numbers = hash_bits // 8
packed_codes_flat = torch_npu.npu_sign_bits_pack(xW_flat, size=1)
elif self.device.type == "cuda" or self.device.type == "cpu":
# (TODO) improve performance later on CUDA ops and CPU SIMD instructions
# [N, hash_bits]
projected = (xW > 0).to(torch.uint8)

# [N, hash_numbers, 8]
binary_codes = projected.view(-1, self.hash_numbers, 8)

# binary_codes * self.bit_masks [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8]
# then sum along the last dimension to get [N, hash_numbers]
packed_codes_flat = torch.sum(
binary_codes * self.bit_masks, dim=-1, dtype=torch.uint8
) # [N, hash_numbers]
packed_codes_flat = packed_codes_flat.view(-1) # [N * hash_numbers]
elif self.device.type == "cuda":
packed_codes_flat = triton_hash_code(
x_flat, self.hash_weights, self.bit_masks
) # [N * hash_numbers]

elif self.device.type == "cpu":
packed_codes_flat = torch_hash_code(
x_flat, self.hash_weights, self.bit_masks
) # [N * hash_numbers]

else:
raise ValueError(f"Unsupported device type: {self.device.type}")

Expand Down Expand Up @@ -213,7 +324,7 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor:
) # expand last dim to 8

# (expanded & self.bit_masks) > 0 -> [N, hash_numbers, 8]
unpacked_bits = (expanded & self.bit_masks) > 0
unpacked_bits = (expanded & self.bit_masks.unsqueeze(0).unsqueeze(0)) > 0

# 0 -> -1, 1 -> 1
unpacked_bits = unpacked_bits * 2 - 1
Expand All @@ -232,20 +343,22 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor:


if __name__ == "__main__":
torch.manual_seed(42)

print("test HashEncoder...")
dtype = torch.float16
if hasattr(torch, "npu") and torch.npu.is_available():
device = torch.device("npu:0")
elif hasattr(torch, "cuda") and torch.cuda.is_available():
device = torch.device("cuda:0")
dtype = torch.float32
else:
device = torch.device("cpu")

print("Using device:", device)
encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=dtype, device=device)

torch.manual_seed(42)

encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=torch.float16, device=device)

x = torch.randn(2, 8, device=device, dtype=torch.float16)
x = torch.randn(2, 8, device=device, dtype=dtype)
print("x:", x)

hash_codes = encoder.compute_hash(x)
Expand All @@ -262,3 +375,31 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor:
print(
f"hash_codes[1].item()={hash_codes[1].item()}, 8-bit binary form:{hash_codes[1].item():08b}"
)

if hasattr(torch, "cuda") and torch.cuda.is_available():
print("test cuda triton and torch hash code functions...")
x = torch.randn((1024, 512), device="cuda:0", dtype=torch.bfloat16)
code = torch.randn((512, 512), device="cuda:0", dtype=torch.bfloat16)
pack_weight = torch.tensor(
[128, 64, 32, 16, 8, 4, 2, 1], device="cuda:0", dtype=torch.uint8
)

torch_output = torch_hash_code(x, code, pack_weight)
triton_output = triton_hash_code(x, code, pack_weight)
assert torch_output.shape == triton_output.shape
print(f"x_shape: {x.shape} code_shape: {code.shape}")
print("torch_output", torch_output)
print("triton_output", triton_output)
print(
f"The maximum difference between Torch and Triton is"
f" {torch.max(torch.abs(torch_output.to(torch.int32) - triton_output.to(torch.int32)))}"
)
# benchmark
print(
"torch:",
triton.testing.do_bench(lambda: torch_hash_code(x, code, pack_weight)),
)
print(
"triton:",
triton.testing.do_bench(lambda: triton_hash_code(x, code, pack_weight)),
)
Loading