Skip to content

Commit 1c50bb0

Browse files
author
lshAlgorithm
committed
vectorization on sum of sa
Signed-off-by: lshAlgorithm <[email protected]>
1 parent 493682d commit 1c50bb0

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

rwkv_operators_wkv_v7.inc

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,15 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
8989
memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
9090
}
9191

92-
// auto sa_vec = ZEROS();
93-
// for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
94-
// sa_vec = ADD(sa_vec, MULTIPLY(
95-
// LOAD(&a[t_h_offset + j]),
96-
// LOAD(&state_in[h_2d_i_offset + j])
97-
// )
98-
// );
99-
// }
100-
// float sa = horizontal_sum(sa_vec);
101-
float sa = .0;
102-
for (size_t j = 0; j < C / H; j++) {
103-
sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
92+
auto sa_vec = ZEROS();
93+
for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
94+
sa_vec = ADD(sa_vec, MULTIPLY(
95+
LOAD(&a[t_h_offset + j]),
96+
LOAD(&state_in[h_2d_i_offset + j])
97+
)
98+
);
10499
}
100+
float sa = horizontal_sum(sa_vec);
105101

106102
auto v_vec = SET1(v[t_h_i_offset]);
107103
sa_vec = SET1(sa);

0 commit comments

Comments
 (0)