@@ -38,7 +38,7 @@ class ConvDescriptor(Structure):
3838infiniopConvDescriptor_t = POINTER (ConvDescriptor )
3939
4040
41- def conv (x , w , stride , padding , dilation ):
41+ def conv (x , w , b , stride , padding , dilation ):
4242 ndim = len (x .shape ) - 2
4343 conv_func_map = {
4444 1 : F .conv1d ,
@@ -54,10 +54,10 @@ def conv(x, w, stride, padding, dilation):
5454 conv_func = conv_func_map [ndim ]
5555
5656 if PROFILE :
57- ans = conv_func (x , w , stride = stride , padding = padding , dilation = dilation )
57+ ans = conv_func (x , w , b , stride = stride , padding = padding , dilation = dilation )
5858 torch .cuda .synchronize ()
5959 return ans
60- return conv_func (x , w , stride = stride , padding = padding , dilation = dilation )
60+ return conv_func (x , w , b , stride = stride , padding = padding , dilation = dilation )
6161
6262
6363# infer the shape of the output given the inputs for a N-ary convolution
@@ -98,30 +98,33 @@ def test(
9898 pads ,
9999 strides ,
100100 dilations ,
101- tensor_stride = None ,
101+ add_bias ,
102102 tensor_dtype = torch .float16 ,
103103):
104104 assert len (pads ) == len (strides ) == len (dilations )
105105 print (
106- f"Testing Conv on { torch_device } with x_shape: { x_shape } , w_shape: { w_shape } , b_shape: { w_shape [0 ]} , pads: { pads } , strides: { strides } , dilations: { dilations } , x_stride: { tensor_stride } dtype:{ tensor_dtype } "
106+ f"Testing Conv on { torch_device } with x_shape: { x_shape } , w_shape: { w_shape } , add_bias: { add_bias } , "
107+ f"b_shape: { w_shape [0 ]} , pads: { pads } , strides: { strides } , dilations: { dilations } , dtype:{ tensor_dtype } "
107108 )
108109 x = torch .rand (x_shape , dtype = tensor_dtype ).to (torch_device )
109110 w = torch .rand (w_shape , dtype = tensor_dtype ).to (torch_device )
111+ b = torch .round ((torch .rand (w_shape [0 ], dtype = tensor_dtype ).to (torch_device ) * 2 - 1 ) * 1000 ) / 1000 if add_bias else None
110112 y = torch .zeros (
111113 inferShape (x .shape , w .shape , pads , strides , dilations ), dtype = tensor_dtype
112114 ).to (torch_device )
113115
114116 for i in range (NUM_PRERUN if PROFILE else 1 ):
115- ans = conv (x , w , strides , pads , dilations )
117+ ans = conv (x , w , b , strides , pads , dilations )
116118 if PROFILE :
117119 start_time = time .time ()
118120 for i in range (NUM_ITERATIONS ):
119- _ = conv (x , w , strides , pads , dilations )
121+ _ = conv (x , w , b , strides , pads , dilations )
120122 elapsed = (time .time () - start_time ) / NUM_ITERATIONS
121123 print (f"pytorch time: { elapsed :6f} " )
122124
123125 x_tensor = to_tensor (x , lib )
124126 w_tensor = to_tensor (w , lib )
127+ b_tensor = to_tensor (b , lib ) if b is not None else None
125128 y_tensor = to_tensor (y , lib )
126129 descriptor = infiniopConvDescriptor_t ()
127130
@@ -132,6 +135,7 @@ def test(
132135 y_tensor .descriptor ,
133136 x_tensor .descriptor ,
134137 w_tensor .descriptor ,
138+ b_tensor .descriptor if b_tensor else None ,
135139 tuple_to_void_p (pads ),
136140 tuple_to_void_p (strides ),
137141 tuple_to_void_p (dilations ),
@@ -154,6 +158,7 @@ def test(
154158 y_tensor .data ,
155159 x_tensor .data ,
156160 w_tensor .data ,
161+ b_tensor .data if b_tensor else None ,
157162 None ,
158163 )
159164 )
@@ -168,6 +173,10 @@ def test(
168173 y_tensor .data ,
169174 x_tensor .data ,
170175 w_tensor .data ,
178+ b_tensor .data if b_tensor else None ,
171180 None ,
172181 )
173182 )
@@ -184,18 +193,18 @@ def test(
def test_cpu(lib, test_cases):
    """Run every convolution test case against the CPU device.

    Each entry of ``test_cases`` is unpacked as
    ``(x_shape, w_shape, pads, strides, dilations, add_bias)`` and handed to
    ``test`` twice: once with ``torch.float16`` and once with
    ``torch.float32``.

    Args:
        lib: Operator library object forwarded to ``create_handle``, ``test``
            and ``destroy_handle``.
        test_cases: Iterable of 6-tuples in the order listed above.
    """
    handle = create_handle(lib, DeviceEnum.DEVICE_CPU)
    try:
        for x_shape, w_shape, pads, strides, dilations, add_bias in test_cases:
            for dtype in (torch.float16, torch.float32):
                test(
                    lib,
                    handle,
                    "cpu",
                    x_shape,
                    w_shape,
                    pads,
                    strides,
                    dilations,
                    add_bias,
                    tensor_dtype=dtype,
                )
    finally:
        # A failing case (e.g. an assertion inside ``test``) must not leak the
        # handle, so it is destroyed on every exit path.
        destroy_handle(lib, handle)
191200
192201
def test_cuda(lib, test_cases):
    """Run every convolution test case against the CUDA device.

    Each entry of ``test_cases`` is unpacked as
    ``(x_shape, w_shape, pads, strides, dilations, add_bias)`` and handed to
    ``test`` twice: once with ``torch.float16`` and once with
    ``torch.float32``.

    Args:
        lib: Operator library object forwarded to ``create_handle``, ``test``
            and ``destroy_handle``.
        test_cases: Iterable of 6-tuples in the order listed above.
    """
    handle = create_handle(lib, DeviceEnum.DEVICE_CUDA)
    try:
        for x_shape, w_shape, pads, strides, dilations, add_bias in test_cases:
            for dtype in (torch.float16, torch.float32):
                test(
                    lib,
                    handle,
                    "cuda",
                    x_shape,
                    w_shape,
                    pads,
                    strides,
                    dilations,
                    add_bias,
                    tensor_dtype=dtype,
                )
    finally:
        # A failing case (e.g. an assertion inside ``test``) must not leak the
        # handle, so it is destroyed on every exit path.
        destroy_handle(lib, handle)
200209
201210
@@ -204,54 +213,62 @@ def test_bang(lib, test_cases):
204213
205214 device = DeviceEnum .DEVICE_BANG
206215 handle = create_handle (lib , device )
207- for x_shape , w_shape , pads , strides , dilations , x_strides in test_cases :
208- test (lib , handle , "mlu" , x_shape , w_shape , pads , strides , dilations , x_strides , tensor_dtype = torch .float16 )
209- test (lib , handle , "mlu" , x_shape , w_shape , pads , strides , dilations , x_strides , tensor_dtype = torch .float32 )
216+ for x_shape , w_shape , pads , strides , dilations , add_bias in test_cases :
217+ test (lib , handle , "mlu" , x_shape , w_shape , pads , strides , dilations , add_bias , tensor_dtype = torch .float16 )
218+ test (lib , handle , "mlu" , x_shape , w_shape , pads , strides , dilations , add_bias , tensor_dtype = torch .float32 )
210219 destroy_handle (lib , handle )
211220
212221
213222if __name__ == "__main__" :
214223 test_cases = [
215- # x_shape, w_shape, pads, strides, dilations, x_strides
224+ # x_shape, w_shape, pads, strides, dilations, add_bias
216225 (
217226 (32 , 3 , 4 ),
218227 (32 , 3 , 5 ),
219228 (1 ,),
220229 (1 ,),
221230 (1 ,),
222- None ,
231+ False ,
232+ ),
233+ (
234+ (3 , 7 , 4 ),
235+ (3 , 7 , 5 ),
236+ (1 ,),
237+ (1 ,),
238+ (1 ,),
239+ True ,
223240 ),
224241 (
225242 (1 , 3 , 4 , 4 ),
226243 (2 , 3 , 3 , 3 ),
227244 (1 , 1 ),
228245 (1 , 2 ),
229246 (2 , 1 ),
230- None ,
247+ True ,
231248 ),
232249 (
233250 (32 , 3 , 128 , 128 ),
234251 (64 , 3 , 5 , 5 ),
235252 (2 , 2 ),
236253 (2 , 2 ),
237254 (1 , 1 ),
238- None ,
255+ False ,
239256 ),
240257 (
241258 (1 , 1 , 4 , 4 , 4 ),
242259 (1 , 1 , 5 , 5 , 5 ),
243260 (1 , 1 , 1 ),
244261 (1 , 1 , 1 ),
245262 (1 , 1 , 1 ),
246- None ,
263+ True ,
247264 ),
248265 (
249266 (32 , 3 , 32 , 32 , 32 ),
250267 (64 , 3 , 5 , 5 , 5 ),
251268 (3 , 2 , 2 ),
252269 (4 , 3 , 3 ),
253270 (2 , 2 , 1 ),
254- None ,
271+ False ,
255272 ),
256273 ]
257274 args = get_args ()
@@ -263,6 +280,7 @@ def test_bang(lib, test_cases):
263280 infiniopTensorDescriptor_t ,
264281 infiniopTensorDescriptor_t ,
265282 infiniopTensorDescriptor_t ,
283+ infiniopTensorDescriptor_t ,
266284 c_void_p ,
267285 c_void_p ,
268286 c_void_p ,
@@ -277,6 +295,7 @@ def test_bang(lib, test_cases):
277295 c_void_p ,
278296 c_void_p ,
279297 c_void_p ,
298+ c_void_p ,
280299 ]
281300 lib .infiniopDestroyConvDescriptor .restype = c_int32
282301 lib .infiniopDestroyConvDescriptor .argtypes = [
0 commit comments