
Commit 7ba9b69

[perf] Modify CUDA SIMD and add Triton hash encoder
1 parent 89c3736 commit 7ba9b69

File tree

5 files changed: +320 -128 lines

ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp
Lines changed: 0 additions & 2 deletions

@@ -58,8 +58,6 @@ class RetrievalWorkerBackend {
         if (rc != 0) {
             std::cerr << "Error binding memory to NUMA node " << numaId << std::endl;
         }
-#else
-        std::cerr << "NUMA support is disabled." << std::endl;
 #endif
     }

ucm/sparse/kvcomp/hash_encoder.py
Lines changed: 168 additions & 27 deletions

@@ -31,6 +31,124 @@
 
 logger = init_logger(__name__)
 
+if hasattr(torch, "cuda") and torch.cuda.is_available():
+    from vllm.triton_utils import tl, triton
+
+    @triton.jit
+    def triton_hash_code_kernel(
+        x_ptr,
+        code_ptr,
+        pack_w_ptr,
+        hash_out_ptr,
+        M,
+        K,
+        N,
+        stride_xm,
+        stride_xk,
+        stride_codek,
+        stride_coden,
+        stride_pack_w,
+        stride_om,
+        stride_on,
+        BLOCK_M: tl.constexpr,
+        BLOCK_K: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+    ):
+        pid_m = tl.program_id(0)
+        pid_n = tl.program_id(1)
+
+        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # sample dimension
+        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # hash_bits dimension
+        offs_k = tl.arange(0, BLOCK_K)  # input_dim dimension
+
+        # Matrix multiplication: acc = x @ code, tiled over K
+        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+        for k in range(0, tl.cdiv(K, BLOCK_K)):
+            x = tl.load(
+                x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
+                mask=(offs_m[:, None] < M) & (offs_k[None, :] < K),
+                other=0.0,
+            )
+            code = tl.load(
+                code_ptr
+                + offs_k[:, None] * stride_codek
+                + offs_n[None, :] * stride_coden,
+                mask=(offs_k[:, None] < K) & (offs_n[None, :] < N),
+                other=0.0,
+            )
+            acc += tl.dot(x, code)
+            offs_k += BLOCK_K
+
+        # Binarize and pack
+        bits = (acc > 0).to(tl.uint8)  # binarize by sign
+        bits = tl.reshape(bits, (BLOCK_M, BLOCK_N // 8, 8))  # reshape for packing
+
+        # Load the packing weights (ensure they have the correct shape)
+        pack_w = tl.load(pack_w_ptr + tl.arange(0, 8) * stride_pack_w)
+        packed = tl.sum(bits * pack_w[None, None, :], axis=-1).to(tl.uint8)
+
+        # Store results
+        offs_n = pid_n * (BLOCK_N // 8) + tl.arange(0, BLOCK_N // 8)
+        hash_out_ptrs = (
+            hash_out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+        )
+        tl.store(
+            hash_out_ptrs,
+            packed,
+            mask=(offs_m[:, None] < M) & (offs_n[None, :] < (N // 8)),
+        )
+
+    def triton_hash_code(x, code, pack_weight):
+        input_dim = x.shape[-1]
+        samples = x.shape[0]
+        hash_bits = code.shape[-1]
+        assert (pack_weight.shape[0] == 8) and (hash_bits % 8 == 0)
+        hash_out = torch.empty(
+            (samples, hash_bits // 8), dtype=pack_weight.dtype, device=x.device
+        )
+
+        grid = lambda opts: (
+            triton.cdiv(samples, opts["BLOCK_M"]),
+            triton.cdiv(hash_bits, opts["BLOCK_N"]),  # tile over the output hash_bits axis
+        )
+
+        triton_hash_code_kernel[grid](
+            x,
+            code,
+            pack_weight,
+            hash_out,
+            samples,
+            input_dim,
+            hash_bits,
+            x.stride(0),
+            x.stride(1),
+            code.stride(0),
+            code.stride(1),
+            pack_weight.stride(0),
+            hash_out.stride(0),
+            hash_out.stride(1),
+            BLOCK_M=32,
+            BLOCK_K=64,
+            BLOCK_N=16,
+        )
+
+        return hash_out.view(-1)  # [samples * hash_numbers]
+
+
+@torch.compile()
+def torch_hash_code(x, code, pack_weight):
+    # [N, hash_bits]
+    x = x @ code
+    m = x.shape[:-1]
+    # [N, hash_bits] -> [N, hash_bits // 8, 8]
+    x = (x > 0).to(torch.uint8).view(*m, -1, 8)
+    # Pack each group of 8 sign bits into one uint8:
+    # bits * pack_weight broadcasts [N, hash_numbers, 8] * [8] -> [N, hash_numbers, 8],
+    # then sum along the last dimension to get [N, hash_numbers]
+    x = torch.sum(x * pack_weight, dim=-1, dtype=torch.uint8)
+    x = x.view(-1)  # [N * hash_numbers]
+    return x
+
 
 class HashEncoder:
     """
@@ -105,8 +223,6 @@ def _init_bit_masks(self) -> None:
         self.bit_masks = torch.pow(
             2, torch.arange(8, dtype=torch.uint8, device=self.device)
         )
-        # shape (1, 1, 8)
-        self.bit_masks = self.bit_masks.unsqueeze(0).unsqueeze(0)
 
     def compute_hash(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -136,29 +252,24 @@ def compute_hash(self, x: torch.Tensor) -> torch.Tensor:
         if x_flat.dtype != self.dtype:
             x_flat = x_flat.to(self.dtype)
 
-        # [N, hash_bits]
-        xW = torch.matmul(x_flat, self.hash_weights)
-
-        # [N * hash_bits]
-        xW_flat = xW.view(-1)
-
         if self.device.type == "npu":
+            # [N, hash_bits]
+            xW = torch.matmul(x_flat, self.hash_weights)
+            # [N * hash_bits]
+            xW_flat = xW.view(-1)
             # [N*hash_numbers], where hash_numbers = hash_bits // 8
             packed_codes_flat = torch_npu.npu_sign_bits_pack(xW_flat, size=1)
-        elif self.device.type == "cuda" or self.device.type == "cpu":
-            # (TODO) improve performance later on CUDA ops and CPU SIMD instructions
-            # [N, hash_bits]
-            projected = (xW > 0).to(torch.uint8)
 
-            # [N, hash_numbers, 8]
-            binary_codes = projected.view(-1, self.hash_numbers, 8)
-
-            # binary_codes * self.bit_masks [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8]
-            # then sum along the last dimension to get [N, hash_numbers]
-            packed_codes_flat = torch.sum(
-                binary_codes * self.bit_masks, dim=-1, dtype=torch.uint8
-            )  # [N, hash_numbers]
-            packed_codes_flat = packed_codes_flat.view(-1)  # [N * hash_numbers]
+        elif self.device.type == "cuda":
+            packed_codes_flat = triton_hash_code(
+                x_flat, self.hash_weights, self.bit_masks
+            )  # [N * hash_numbers]
+
+        elif self.device.type == "cpu":
+            packed_codes_flat = torch_hash_code(
+                x_flat, self.hash_weights, self.bit_masks
+            )  # [N * hash_numbers]
+
         else:
             raise ValueError(f"Unsupported device type: {self.device.type}")
 
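For reference, a minimal sketch (not from the commit) of the helper contract the two new branches rely on: callers pass the raw activations plus the projection weights and bit masks, the matmul and packing happen inside the helper, and a flat uint8 tensor comes back:

    import torch

    x_flat = torch.randn(16, 64, dtype=torch.float32)         # [N, input_dim]
    weights = torch.randn(64, 32, dtype=torch.float32)        # [input_dim, hash_bits]
    masks = torch.pow(2, torch.arange(8, dtype=torch.uint8))  # as in _init_bit_masks

    codes = torch_hash_code(x_flat, weights, masks)           # projection + packing fused
    assert codes.shape == (16 * (32 // 8),)                   # [N * hash_numbers]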
@@ -213,7 +324,7 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor:
         )  # expand last dim to 8
 
         # (expanded & self.bit_masks) > 0 -> [N, hash_numbers, 8]
-        unpacked_bits = (expanded & self.bit_masks) > 0
+        unpacked_bits = (expanded & self.bit_masks.unsqueeze(0).unsqueeze(0)) > 0
 
         # 0 -> -1, 1 -> 1
         unpacked_bits = unpacked_bits * 2 - 1
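The explicit unsqueeze gives the masks the [1, 1, 8] shape this path expects. A small round-trip sketch of the unpack logic (illustration only; _unpack_hash expands rather than broadcasts, but the arithmetic is the same):

    import torch

    masks = torch.pow(2, torch.arange(8, dtype=torch.uint8))     # [8]
    bits = torch.randint(0, 2, (2, 1, 8), dtype=torch.uint8)     # [N, hash_numbers, 8]
    packed = torch.sum(bits * masks, dim=-1, dtype=torch.uint8)  # [N, hash_numbers]
    unpacked = (packed.unsqueeze(-1) & masks.unsqueeze(0).unsqueeze(0)) > 0
    assert torch.equal(unpacked.to(torch.uint8), bits)           # bits survive the round trip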
@@ -232,20 +343,22 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor:
 
 
 if __name__ == "__main__":
+    torch.manual_seed(42)
+
+    print("test HashEncoder...")
+    dtype = torch.float16
     if hasattr(torch, "npu") and torch.npu.is_available():
         device = torch.device("npu:0")
     elif hasattr(torch, "cuda") and torch.cuda.is_available():
         device = torch.device("cuda:0")
+        dtype = torch.float32
     else:
         device = torch.device("cpu")
 
     print("Using device:", device)
+    encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=dtype, device=device)
 
-    torch.manual_seed(42)
-
-    encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=torch.float16, device=device)
-
-    x = torch.randn(2, 8, device=device, dtype=torch.float16)
+    x = torch.randn(2, 8, device=device, dtype=dtype)
     print("x:", x)
 
     hash_codes = encoder.compute_hash(x)
@@ -262,3 +375,31 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor:
     print(
         f"hash_codes[1].item()={hash_codes[1].item()}, 8-bit binary form:{hash_codes[1].item():08b}"
     )
+
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        print("test cuda triton and torch hash code functions...")
+        x = torch.randn((1024, 512), device="cuda:0", dtype=torch.bfloat16)
+        code = torch.randn((512, 512), device="cuda:0", dtype=torch.bfloat16)
+        pack_weight = torch.tensor(
+            [128, 64, 32, 16, 8, 4, 2, 1], device="cuda:0", dtype=torch.uint8
+        )
+
+        torch_output = torch_hash_code(x, code, pack_weight)
+        triton_output = triton_hash_code(x, code, pack_weight)
+        assert torch_output.shape == triton_output.shape
+        print(f"x_shape: {x.shape} code_shape: {code.shape}")
+        print("torch_output", torch_output)
+        print("triton_output", triton_output)
+        print(
+            f"The maximum difference between Torch and Triton is"
+            f" {torch.max(torch.abs(torch_output.to(torch.int32) - triton_output.to(torch.int32)))}"
+        )
+        # benchmark
+        print(
+            "torch:",
+            triton.testing.do_bench(lambda: torch_hash_code(x, code, pack_weight)),
+        )
+        print(
+            "triton:",
+            triton.testing.do_bench(lambda: triton_hash_code(x, code, pack_weight)),
+        )
