 
 logger = init_logger(__name__)
 
+if hasattr(torch, "cuda") and torch.cuda.is_available():
+    from vllm.triton_utils import tl, triton
+
+    @triton.jit
+    def triton_hash_code_kernel(
+        x_ptr,
+        code_ptr,
+        pack_w_ptr,
+        hash_out_ptr,
+        M,
+        K,
+        N,
+        stride_xm,
+        stride_xk,
+        stride_codek,
+        stride_coden,
+        stride_pack_w,
+        stride_om,
+        stride_on,
+        BLOCK_M: tl.constexpr,
+        BLOCK_K: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+    ):
+        pid_m = tl.program_id(0)
+        pid_n = tl.program_id(1)
+
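+        # Each program computes one BLOCK_M x BLOCK_N tile of x @ code and
+        # emits the corresponding BLOCK_N // 8 packed bytes of that tile.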
+        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # sample dimension
+        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # hash_bits dimension
+        offs_k = tl.arange(0, BLOCK_K)  # input_dim dimension
+
+        # Matrix multiplication: accumulate x @ code in float32 so the sign
+        # decision below is stable regardless of the input dtype.
+        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+        for k in range(0, tl.cdiv(K, BLOCK_K)):
+            x = tl.load(
+                x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
+                mask=(offs_m[:, None] < M) & (offs_k[None, :] < K),
+                other=0.0,
+            )
+            code = tl.load(
+                code_ptr
+                + offs_k[:, None] * stride_codek
+                + offs_n[None, :] * stride_coden,
+                mask=(offs_k[:, None] < K) & (offs_n[None, :] < N),
+                other=0.0,
+            )
+            acc += tl.dot(x, code)
+            offs_k += BLOCK_K
+
+        # Binarize and pack: every 8 consecutive sign bits become one uint8
+        bits = (acc > 0).to(tl.uint8)  # Binarize
+        bits = tl.reshape(bits, (BLOCK_M, BLOCK_N // 8, 8))  # Reshape for packing
+
+        # Load the 8-element packing-weight vector (powers of two)
+        pack_w = tl.load(pack_w_ptr + tl.arange(0, 8) * stride_pack_w)
+        packed = tl.sum(bits * pack_w[None, None, :], axis=-1).to(tl.uint8)
+
+        # Store results: N // 8 packed bytes per output row
+        offs_n = pid_n * (BLOCK_N // 8) + tl.arange(0, BLOCK_N // 8)
+        hash_out_ptrs = (
+            hash_out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+        )
+        tl.store(
+            hash_out_ptrs,
+            packed,
+            mask=(offs_m[:, None] < M) & (offs_n[None, :] < (N // 8)),
+        )
+
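+    # Expected call contract for the wrapper below (inferred from the kernel):
+    # x is [samples, input_dim], code is [input_dim, hash_bits] with
+    # hash_bits % 8 == 0, and pack_weight holds the 8 per-bit weights
+    # (powers of two); unpacking must use the same weights. For example,
+    # x [1024, 512] with code [512, 512] yields [1024, 64] uint8, flattened.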
+    def triton_hash_code(x, code, pack_weight):
+        input_dim = x.shape[-1]
+        samples = x.shape[0]
+        hash_bits = code.shape[-1]
+        assert (pack_weight.shape[0] == 8) and (hash_bits % 8 == 0)
+        hash_out = torch.empty(
+            (samples, hash_bits // 8), dtype=pack_weight.dtype, device=x.device
+        )
+
+        # One program per (BLOCK_M, BLOCK_N) output tile; the second grid
+        # axis tiles the hash_bits dimension of the output.
+        grid = lambda opts: (
+            triton.cdiv(samples, opts["BLOCK_M"]),
+            triton.cdiv(hash_bits, opts["BLOCK_N"]),
+        )
+
+        triton_hash_code_kernel[grid](
+            x,
+            code,
+            pack_weight,
+            hash_out,
+            samples,
+            input_dim,
+            hash_bits,
+            x.stride(0),
+            x.stride(1),
+            code.stride(0),
+            code.stride(1),
+            pack_weight.stride(0),
+            hash_out.stride(0),
+            hash_out.stride(1),
+            BLOCK_M=32,
+            BLOCK_K=64,
+            BLOCK_N=16,
+        )
+
+        return hash_out.view(-1)  # [samples * hash_numbers]
+
+
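+# torch_hash_code mirrors the Triton path with plain tensor ops and serves as
+# the CPU fallback; torch.compile should let the compare/reshape/sum chain fuse.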
+@torch.compile()
+def torch_hash_code(x, code, pack_weight):
+    # [N, hash_bits]
+    x = x @ code
+    m = x.shape[:-1]
+    # [N, hash_bits] --> [N, hash_bits // 8, 8]
+    x = (x > 0).to(torch.uint8).view(*m, -1, 8)
+    # 8 bits -> 1 byte: binary codes [N, hash_numbers, 8] * pack_weight [8]
+    # -> [N, hash_numbers, 8], then sum along the last dimension to get
+    # [N, hash_numbers]
+    x = torch.sum(x * pack_weight, dim=-1, dtype=torch.uint8)
+    x = x.view(-1)  # [N * hash_numbers]
+    return x
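+
+# Note: the Triton and torch implementations should agree bit-for-bit except
+# when a projection lands exactly on the sign boundary, where float
+# accumulation order may differ; the self-test below therefore prints the
+# max difference rather than asserting exact equality.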
+
 
 class HashEncoder:
     """
@@ -105,8 +223,6 @@ def _init_bit_masks(self) -> None:
         self.bit_masks = torch.pow(
             2, torch.arange(8, dtype=torch.uint8, device=self.device)
         )
-        # shape (1, 1, 8)
-        self.bit_masks = self.bit_masks.unsqueeze(0).unsqueeze(0)
 
     def compute_hash(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -136,29 +252,24 @@ def compute_hash(self, x: torch.Tensor) -> torch.Tensor:
         if x_flat.dtype != self.dtype:
             x_flat = x_flat.to(self.dtype)
 
-        # [N, hash_bits]
-        xW = torch.matmul(x_flat, self.hash_weights)
-
-        # [N * hash_bits]
-        xW_flat = xW.view(-1)
-
         if self.device.type == "npu":
+            # [N, hash_bits]
+            xW = torch.matmul(x_flat, self.hash_weights)
+            # [N * hash_bits]
+            xW_flat = xW.view(-1)
             # [N*hash_numbers], where hash_numbers = hash_bits // 8
             packed_codes_flat = torch_npu.npu_sign_bits_pack(xW_flat, size=1)
-        elif self.device.type == "cuda" or self.device.type == "cpu":
-            # (TODO) improve performance later on CUDA ops and CPU SIMD instructions
-            # [N, hash_bits]
-            projected = (xW > 0).to(torch.uint8)
 
-            # [N, hash_numbers, 8]
-            binary_codes = projected.view(-1, self.hash_numbers, 8)
-
-            # binary_codes * self.bit_masks: [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8]
-            # then sum along the last dimension to get [N, hash_numbers]
-            packed_codes_flat = torch.sum(
-                binary_codes * self.bit_masks, dim=-1, dtype=torch.uint8
-            )  # [N, hash_numbers]
-            packed_codes_flat = packed_codes_flat.view(-1)  # [N * hash_numbers]
+        elif self.device.type == "cuda":
+            packed_codes_flat = triton_hash_code(
+                x_flat, self.hash_weights, self.bit_masks
+            )  # [N * hash_numbers]
+
+        elif self.device.type == "cpu":
+            packed_codes_flat = torch_hash_code(
+                x_flat, self.hash_weights, self.bit_masks
+            )  # [N * hash_numbers]
+
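+        # Both packed paths take self.bit_masks (kept 1-D, shape [8]) as the
+        # pack weights, which is why _init_bit_masks no longer reshapes it
+        # to (1, 1, 8).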
         else:
             raise ValueError(f"Unsupported device type: {self.device.type}")
 
@@ -213,7 +324,7 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor:
         )  # expand last dim to 8
 
         # (expanded & self.bit_masks) > 0 -> [N, hash_numbers, 8]
-        unpacked_bits = (expanded & self.bit_masks) > 0
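+        # bit_masks is now 1-D, so broadcast it to (1, 1, 8) at unpack time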
+        unpacked_bits = (expanded & self.bit_masks.unsqueeze(0).unsqueeze(0)) > 0
 
         # 0 -> -1, 1 -> 1
         unpacked_bits = unpacked_bits * 2 - 1
@@ -232,20 +343,22 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor:
 
 
 if __name__ == "__main__":
+    torch.manual_seed(42)
+
+    print("test HashEncoder...")
+    dtype = torch.float16
     if hasattr(torch, "npu") and torch.npu.is_available():
         device = torch.device("npu:0")
     elif hasattr(torch, "cuda") and torch.cuda.is_available():
         device = torch.device("cuda:0")
+        dtype = torch.float32
     else:
         device = torch.device("cpu")
 
     print("Using device:", device)
+    encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=dtype, device=device)
 
-    torch.manual_seed(42)
-
-    encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=torch.float16, device=device)
-
-    x = torch.randn(2, 8, device=device, dtype=torch.float16)
+    x = torch.randn(2, 8, device=device, dtype=dtype)
     print("x:", x)
 
     hash_codes = encoder.compute_hash(x)
@@ -262,3 +375,31 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor:
     print(
         f"hash_codes[1].item()={hash_codes[1].item()}, 8-bit binary form:{hash_codes[1].item():08b}"
     )
+
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        print("test cuda triton and torch hash code functions...")
+        x = torch.randn((1024, 512), device="cuda:0", dtype=torch.bfloat16)
+        code = torch.randn((512, 512), device="cuda:0", dtype=torch.bfloat16)
+        pack_weight = torch.tensor(
+            [128, 64, 32, 16, 8, 4, 2, 1], device="cuda:0", dtype=torch.uint8
+        )
+
+        torch_output = torch_hash_code(x, code, pack_weight)
+        triton_output = triton_hash_code(x, code, pack_weight)
+        assert torch_output.shape == triton_output.shape
+        print(f"x_shape: {x.shape} code_shape: {code.shape}")
+        print("torch_output", torch_output)
+        print("triton_output", triton_output)
+        print(
+            "The maximum difference between Torch and Triton is"
+            f" {torch.max(torch.abs(torch_output.to(torch.int32) - triton_output.to(torch.int32)))}"
+        )
+        # benchmark (do_bench reports a latency estimate in milliseconds)
+        print(
+            "torch:",
+            triton.testing.do_bench(lambda: torch_hash_code(x, code, pack_weight)),
+        )
+        print(
+            "triton:",
+            triton.testing.do_bench(lambda: triton_hash_code(x, code, pack_weight)),
+        )
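+        # Timings depend on the GPU, the shapes above, and do_bench's
+        # warmup/rep defaults; treat them as a rough comparison only.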