@@ -286,6 +286,20 @@ def _float8_linear_impl(
286286 f"Don't expect to reach here with an override other than weight currently, { type (input_tensor )} { type (weight_tensor )} "
287287 )
288288
289+ # TODO: make this better
290+ # During the backward pass, we transpose the weight tensor,
291+ # so if the weight tensor was originally rowwise quantized,
292+ # now it becomes colwise. In this case, simply dequantize
293+ # the tensor and do a bf16 matmul
294+ is_backward = (
295+ weight_tensor .block_size [0 ] == weight_tensor .shape [0 ] and
296+ weight_tensor .block_size [1 ] == 1
297+ )
298+ if is_backward :
299+ return torch .nn .functional .linear (
300+ input_tensor , weight_tensor .dequantize (), bias ,
301+ )
302+
289303 act_quant_kwargs = weight_tensor .act_quant_kwargs
290304 # quantizing activation, if `act_quant_kwargs` is specified
291305 if act_quant_kwargs is not None :
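As context for the hunk above, here is a minimal standalone sketch (not torchao code; names are illustrative) of the `block_size` heuristic it adds: a rowwise-quantized weight of shape `(N, K)` carries one scale per row, i.e. `block_size == (1, K)`, and transposing it in the backward pass gives shape `(K, N)` with `block_size == (K, 1)`, which is exactly what the `is_backward` check detects.

```python
# Standalone sketch of the block_size heuristic (illustrative, not torchao code).
# Rowwise quantization of an (N, K) weight uses block_size (1, K): one scale per
# row. After the backward-pass transpose, the shape is (K, N) and the block_size
# becomes (K, 1): one scale per column, which the check above treats as "backward".
def looks_colwise(shape, block_size):
    return block_size[0] == shape[0] and block_size[1] == 1

N, K = 4, 8
assert not looks_colwise((N, K), (1, K))  # forward: rowwise, keep the fp8 matmul
assert looks_colwise((K, N), (K, 1))      # backward: colwise, dequantize + bf16 fallback
```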
@@ -321,8 +335,7 @@ def _float8_linear_impl(
         wq = weight_tensor.qdata
         x_scale = input_tensor.scale
         w_scale = weight_tensor.scale
-        # TODO: fix this?
-        if True:  # _is_rowwise_scaled(weight_tensor):
+        if _is_rowwise_scaled(weight_tensor):
             assert _is_rowwise_scaled(input_tensor), (
                 "Input tensor must be rowwise block size"
             )
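The second hunk drops the hard-coded `if True:` in favor of the real `_is_rowwise_scaled` check. As a rough illustration only (an assumption about what such a check encodes, not torchao's actual helper), a rowwise-scaling test can be phrased purely in terms of shape and block size:

```python
# Hypothetical sketch of a rowwise-scaling check (assumption for illustration;
# not the actual _is_rowwise_scaled implementation). "Rowwise" here means every
# leading dimension has block size 1 and the last dimension is covered by a
# single block, i.e. one scale per row.
def rowwise_like(shape, block_size):
    return (
        len(shape) == len(block_size)
        and all(b == 1 for b in block_size[:-1])
        and block_size[-1] == shape[-1]
    )

assert rowwise_like((4, 8), (1, 8))      # rowwise-quantized weight
assert not rowwise_like((8, 4), (8, 1))  # colwise (transposed) weight
```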