@@ -863,16 +863,17 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s
863863 values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_1[None, :], other=0)
864864 _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _BLOCK_SIZE_1]), values, float('-inf'))
865865 local_amax = tl.max(_mask_to, 1)
866- mi = triton_helpers.maximum(mi_copy_0, local_amax)
867- v_1 = mi_copy_0 - mi
866+ v_0 = triton_helpers.maximum(mi_copy_0, local_amax)
867+ v_1 = mi_copy_0 - v_0
868868 v_2 = tl_math.exp(v_1)
869869 v_3 = di_copy_0 * v_2
870- subscript = mi [:, None]
870+ subscript = v_0 [:, None]
871871 v_4 = values - subscript
872872 v_5 = tl_math.exp(v_4)
873873 _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _BLOCK_SIZE_1]), v_5, 0)
874874 sum_1 = tl.sum(_mask_to_1, 1)
875875 di = v_3 + sum_1
876+ mi = v_0
876877 for offset_2 in range(0, n.to(tl.int32), _BLOCK_SIZE_1):
877878 indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
878879 mask_2 = indices_2 < n
@@ -945,16 +946,17 @@ def _softmax_two_pass_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1,
945946 values = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
946947 _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], values, float('-inf'))
947948 local_amax = tl.max(_mask_to, 1)
948- mi = triton_helpers.maximum(mi_copy_0, local_amax)
949- v_1 = mi_copy_0 - mi
949+ v_0 = triton_helpers.maximum(mi_copy_0, local_amax)
950+ v_1 = mi_copy_0 - v_0
950951 v_2 = tl_math.exp(v_1)
951952 v_3 = di_copy_0 * v_2
952- subscript = mi [:, None]
953+ subscript = v_0 [:, None]
953954 v_4 = values - subscript
954955 v_5 = tl_math.exp(v_4)
955956 _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, 0)
956957 sum_1 = tl.sum(_mask_to_1, 1)
957958 di = v_3 + sum_1
959+ mi = v_0
958960 for offset_2 in range(0, n.to(tl.int32), _BLOCK_SIZE_1):
959961 indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
960962 mi_copy_1 = mi
@@ -1148,21 +1150,22 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
11481150 amax = tl.max(qk, 2)
11491151 v_0 = 0.18033688
11501152 v_1 = amax * v_0
1151- m_i = triton_helpers.maximum(m_i_copy_0, v_1)
1153+ v_2 = triton_helpers.maximum(m_i_copy_0, v_1)
11521154 v_3 = 0.18033688
11531155 v_4 = qk * v_3
1154- subscript = m_i [:, :, None]
1156+ subscript = v_2 [:, :, None]
11551157 v_5 = v_4 - subscript
11561158 v_6 = libdevice.exp2(v_5)
11571159 l_ij = tl.sum(v_6, 2)
1158- v_7 = m_i_copy_0 - m_i
1160+ v_7 = m_i_copy_0 - v_2
11591161 v_8 = libdevice.exp2(v_7)
11601162 v_9 = l_i_copy_0 * v_8
11611163 l_i = v_9 + l_ij
11621164 subscript_1 = v_8[:, :, None]
11631165 v_11 = acc_copy_0 * subscript_1
11641166 v = tl.load(v_view + (indices_0[:, None, None] * 32768 + indices_2[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
11651167 acc = tl.reshape(tl.dot(tl.reshape(v_6, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
1168+ m_i = v_2
11661169 subscript_2 = l_i[:, :, None]
11671170 v_12 = acc / subscript_2
11681171 tl.store(out + (indices_0[:, None, None] * 32768 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), v_12, None)
@@ -1254,15 +1257,15 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
12541257 v_0 = tl.full([], 0.18033688, tl.float16)
12551258 v_1 = amax * v_0
12561259 v_2 = v_1.to(tl.float32)
1257- m_i = triton_helpers.maximum(m_i_copy_0, v_2)
1260+ v_3 = triton_helpers.maximum(m_i_copy_0, v_2)
12581261 v_4 = tl.full([], 0.18033688, tl.float16)
12591262 v_5 = qk * v_4
1260- subscript = m_i [:, :, None]
1263+ subscript = v_3 [:, :, None]
12611264 v_6 = v_5.to(tl.float32)
12621265 v_7 = v_6 - subscript
12631266 v_8 = libdevice.exp2(v_7)
12641267 l_ij = tl.sum(v_8, 2)
1265- v_9 = m_i_copy_0 - m_i
1268+ v_9 = m_i_copy_0 - v_3
12661269 v_10 = libdevice.exp2(v_9)
12671270 v_11 = l_i_copy_0 * v_10
12681271 l_i = v_11 + l_ij
@@ -1271,6 +1274,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
12711274 v = tl.load(tl.make_block_ptr(v_view, [64, 512, 64], [32768, 64, 1], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
12721275 v_14 = v_8.to(tl.float16)
12731276 acc = tl.reshape(tl.dot(tl.reshape(v_14, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_13, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
1277+ m_i = v_3
12741278 subscript_2 = l_i[:, :, None]
12751279 v_15 = acc / subscript_2
12761280 v_16 = v_15.to(tl.float16)
@@ -1366,22 +1370,23 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
13661370 amax = tl.max(_mask_to_2, 2)
13671371 v_0 = 0.18033688
13681372 v_1 = amax * v_0
1369- m_i = triton_helpers.maximum(m_i_copy_0, v_1)
1373+ v_2 = triton_helpers.maximum(m_i_copy_0, v_1)
13701374 v_3 = 0.18033688
13711375 v_4 = qk * v_3
1372- subscript = m_i [:, :, None]
1376+ subscript = v_2 [:, :, None]
13731377 v_5 = v_4 - subscript
13741378 v_6 = libdevice.exp2(v_5)
13751379 _mask_to_3 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), v_6, 0)
13761380 l_ij = tl.sum(_mask_to_3, 2)
1377- v_7 = m_i_copy_0 - m_i
1381+ v_7 = m_i_copy_0 - v_2
13781382 v_8 = libdevice.exp2(v_7)
13791383 v_9 = l_i_copy_0 * v_8
13801384 l_i = v_9 + l_ij
13811385 subscript_1 = v_8[:, :, None]
13821386 v_11 = acc_copy_0 * subscript_1
13831387 v = tl.load(tl.make_block_ptr(v_view, [v_view_size_0, v_view_size_1, 64], [v_view_stride_0, v_view_stride_1, v_view_stride_2], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
13841388 acc = tl.reshape(tl.dot(tl.reshape(_mask_to_3, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
1389+ m_i = v_2
13851390 subscript_2 = l_i[:, :, None]
13861391 v_12 = acc / subscript_2
13871392 tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1, 64], [out_stride_0, out_stride_1, out_stride_2], [offset_0, offset_1, 0], [1, _BLOCK_SIZE_1, 64], [2, 1, 0]), v_12, boundary_check=[0, 1, 2])
0 commit comments