@@ -38,7 +38,7 @@ class ConvDescriptor(Structure):
 infiniopConvDescriptor_t = POINTER(ConvDescriptor)
 
 
-def conv(x, w, stride, padding, dilation):
+def conv(x, w, b, stride, padding, dilation):
     ndim = len(x.shape) - 2
     conv_func_map = {
         1: F.conv1d,
@@ -54,10 +54,10 @@ def conv(x, w, stride, padding, dilation):
     conv_func = conv_func_map[ndim]
 
     if PROFILE:
-        ans = conv_func(x, w, stride=stride, padding=padding, dilation=dilation)
+        ans = conv_func(x, w, b, stride=stride, padding=padding, dilation=dilation)
         torch.cuda.synchronize()
         return ans
-    return conv_func(x, w, stride=stride, padding=padding, dilation=dilation)
+    return conv_func(x, w, b, stride=stride, padding=padding, dilation=dilation)
 
 
 # infer the shape of the output given the inputs for an N-ary convolution
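
For reference, `bias` is the third positional parameter of `torch.nn.functional.conv1d`/`conv2d`/`conv3d` and must have shape `(out_channels,)`, so threading `b` through positionally matches the PyTorch signature, and passing `b = None` reproduces the old bias-free behavior. A standalone sanity check, independent of this harness:

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 3, 8)   # (batch, in_channels, length)
    w = torch.rand(2, 3, 5)   # (out_channels, in_channels, kernel)
    b = torch.rand(2)         # (out_channels,): one bias per output channel
    y = F.conv1d(x, w, b, stride=1, padding=1, dilation=1)
    assert y.shape == (1, 2, 6)   # (8 + 2*1 - 1*(5-1) - 1) // 1 + 1 == 6
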
@@ -98,31 +98,34 @@ def test(
     pads,
     strides,
     dilations,
-    tensor_stride=None,
+    add_bias,
     tensor_dtype=torch.float16,
 ):
     assert len(pads) == len(strides) == len(dilations)
     print(
-        f"Testing Conv on {torch_device} with x_shape: {x_shape}, w_shape: {w_shape}, b_shape: {w_shape[0]}, pads: {pads}, strides: {strides}, dilations: {dilations}, x_stride: {tensor_stride} dtype:{tensor_dtype}"
+        f"Testing Conv on {torch_device} with x_shape: {x_shape}, w_shape: {w_shape}, add_bias: {add_bias}, "
+        f"b_shape: {w_shape[0]}, pads: {pads}, strides: {strides}, dilations: {dilations}, dtype: {tensor_dtype}"
     )
     x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
     w = torch.rand(w_shape, dtype=tensor_dtype).to(torch_device)
+    b = torch.round((torch.rand(w_shape[0], dtype=tensor_dtype).to(torch_device) * 2 - 1) * 1000) / 1000 if add_bias else None
     y = torch.zeros(
         inferShape(x.shape, w.shape, pads, strides, dilations), dtype=tensor_dtype
     ).to(torch_device)
 
     for i in range(NUM_PRERUN if PROFILE else 1):
-        ans = conv(x, w, strides, pads, dilations)
+        ans = conv(x, w, b, strides, pads, dilations)
     if PROFILE:
         start_time = time.time()
         for i in range(NUM_ITERATIONS):
-            _ = conv(x, w, strides, pads, dilations)
+            _ = conv(x, w, b, strides, pads, dilations)
         elapsed = (time.time() - start_time) / NUM_ITERATIONS
         print(f"pytorch time: {elapsed:6f}")
 
 
     x_tensor = to_tensor(x, lib)
     w_tensor = to_tensor(w, lib)
+    b_tensor = to_tensor(b, lib) if b is not None else None
     y_tensor = to_tensor(y, lib)
     descriptor = infiniopConvDescriptor_t()
 
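
The bias is drawn uniformly from (-1, 1) and snapped to roughly three decimal places, presumably so the fp16 and fp32 runs see comparable, well-behaved values; note the rounding is applied after casting to `tensor_dtype`, so fp16 quantization still happens first. The generator in isolation (shapes and dtype here are just examples):

    import torch

    out_channels = 4
    raw = torch.rand(out_channels, dtype=torch.float32) * 2 - 1   # uniform in (-1, 1)
    b = torch.round(raw * 1000) / 1000                            # keep ~3 decimals
    assert b.abs().max() <= 1.0
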
@@ -133,6 +136,7 @@ def test(
             y_tensor.descriptor,
             x_tensor.descriptor,
             w_tensor.descriptor,
+            b_tensor.descriptor if b_tensor else None,
             tuple_to_void_p(pads),
             tuple_to_void_p(strides),
             tuple_to_void_p(dilations),
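
Passing `None` where a ctypes function expects a pointer is the standard idiom for a NULL argument, so `b_tensor.descriptor if b_tensor else None` hands the C side a NULL bias descriptor when `add_bias` is False, just like the existing `None` stream argument. A minimal demonstration:

    import ctypes
    p = ctypes.c_void_p(None)
    assert p.value is None   # ctypes models a NULL pointer as None
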
@@ -147,27 +151,33 @@ def test(
     workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))
 
     for i in range(NUM_PRERUN if PROFILE else 1):
-        lib.infiniopConv(
-            descriptor,
-            workspace_ptr,
-            workspaceSize,
-            y_tensor.data,
-            x_tensor.data,
-            w_tensor.data,
-            None,
-        )
-    if PROFILE:
-        start_time = time.time()
-        for i in range(NUM_ITERATIONS):
+        check_error(
             lib.infiniopConv(
                 descriptor,
                 workspace_ptr,
                 workspaceSize,
                 y_tensor.data,
                 x_tensor.data,
                 w_tensor.data,
+                b_tensor.data if b_tensor else None,
                 None,
             )
+        )
+    if PROFILE:
+        start_time = time.time()
+        for i in range(NUM_ITERATIONS):
+            check_error(
+                lib.infiniopConv(
+                    descriptor,
+                    workspace_ptr,
+                    workspaceSize,
+                    y_tensor.data,
+                    x_tensor.data,
+                    w_tensor.data,
+                    b_tensor.data if b_tensor else None,
+                    None,
+                )
+            )
         elapsed = (time.time() - start_time) / NUM_ITERATIONS
         print(f"    lib time: {elapsed:6f}")
 
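
The calls are now wrapped in `check_error`, the harness's status-checking helper (defined alongside `create_handle`/`to_tensor`; its definition is outside this diff). The intent is simply to fail fast on a nonzero status instead of silently discarding it. A minimal sketch, assuming infiniop calls return 0 on success:

    def check_error(status):
        if status != 0:
            raise RuntimeError(f"infiniop call failed with status {status}")
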
@@ -181,18 +191,18 @@ def test(
 def test_cpu(lib, test_cases):
     device = DeviceEnum.DEVICE_CPU
     handle = create_handle(lib, device)
-    for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
-        test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
-        test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
+    for x_shape, w_shape, pads, strides, dilations, add_bias in test_cases:
+        test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float16)
+        test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float32)
     destroy_handle(lib, handle)
 
 
 def test_cuda(lib, test_cases):
     device = DeviceEnum.DEVICE_CUDA
     handle = create_handle(lib, device)
-    for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
-        test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
-        test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
+    for x_shape, w_shape, pads, strides, dilations, add_bias in test_cases:
+        test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float16)
+        test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float32)
     destroy_handle(lib, handle)
 
 
@@ -201,54 +211,62 @@ def test_bang(lib, test_cases):
 
     device = DeviceEnum.DEVICE_BANG
     handle = create_handle(lib, device)
-    for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
-        test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
-        test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
+    for x_shape, w_shape, pads, strides, dilations, add_bias in test_cases:
+        test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float16)
+        test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float32)
     destroy_handle(lib, handle)
 
 
 if __name__ == "__main__":
     test_cases = [
-        # x_shape, w_shape, pads, strides, dilations, x_strides
+        # x_shape, w_shape, pads, strides, dilations, add_bias
         (
             (32, 3, 4),
             (32, 3, 5),
             (1,),
             (1,),
             (1,),
-            None,
+            False,
+        ),
+        (
+            (3, 7, 4),
+            (3, 7, 5),
+            (1,),
+            (1,),
+            (1,),
+            True,
         ),
         (
             (1, 3, 4, 4),
             (2, 3, 3, 3),
             (1, 1),
             (1, 2),
             (2, 1),
-            None,
+            True,
         ),
         (
             (32, 3, 128, 128),
             (64, 3, 5, 5),
             (2, 2),
             (2, 2),
             (1, 1),
-            None,
+            False,
         ),
         (
             (1, 1, 4, 4, 4),
             (1, 1, 5, 5, 5),
             (1, 1, 1),
             (1, 1, 1),
             (1, 1, 1),
-            None,
+            True,
         ),
         (
             (32, 3, 32, 32, 32),
             (64, 3, 5, 5, 5),
             (3, 2, 2),
             (4, 3, 3),
             (2, 2, 1),
-            None,
+            False,
         ),
     ]
     args = get_args()
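
Assuming `inferShape` applies the standard convolution arithmetic per spatial dimension, out = floor((in + 2*pad - dilation*(kernel-1) - 1) / stride) + 1, the expected output shape of a case can be checked by hand. For the last 3-D case above:

    # x (32, 3, 32, 32, 32) conv w (64, 3, 5, 5, 5),
    # pads (3, 2, 2), strides (4, 3, 3), dilations (2, 2, 1)
    def out_dim(i, p, d, k, s):
        return (i + 2 * p - d * (k - 1) - 1) // s + 1

    dims = [out_dim(32, p, d, 5, s) for p, s, d in zip((3, 2, 2), (4, 3, 3), (2, 2, 1))]
    assert dims == [8, 10, 11]   # so y has shape (32, 64, 8, 10, 11)
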
@@ -260,6 +278,7 @@ def test_bang(lib, test_cases):
         infiniopTensorDescriptor_t,
         infiniopTensorDescriptor_t,
         infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
         c_void_p,
         c_void_p,
         c_void_p,
@@ -274,6 +293,7 @@ def test_bang(lib, test_cases):
         c_void_p,
         c_void_p,
         c_void_p,
+        c_void_p,
     ]
 lib.infiniopDestroyConvDescriptor.restype = c_int32
 lib.infiniopDestroyConvDescriptor.argtypes = [
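
The two argtypes hunks mirror the Python call sites above: `infiniopCreateConvDescriptor` gains a fourth `infiniopTensorDescriptor_t` for the bias, and `infiniopConv` gains a `c_void_p` for the bias data. Pieced together from the `lib.infiniopConv(...)` call in `test`, the full list presumably reads as below; entries outside the shown hunks, including the workspace-size type, are assumptions:

    # sketch of the assumed full list; only the added c_void_p is from this diff
    lib.infiniopConv.restype = c_int32
    lib.infiniopConv.argtypes = [
        infiniopConvDescriptor_t,  # descriptor
        c_void_p,                  # workspace
        c_uint64,                  # workspaceSize (type assumed)
        c_void_p,                  # y
        c_void_p,                  # x
        c_void_p,                  # w
        c_void_p,                  # b: NULL when add_bias is False (new)
        c_void_p,                  # stream
    ]
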