@@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, const Allocator& allocator,
+      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : // `vocab_size == 0` means it is for Vit part, VitAttention is still
         // MHA and does not use an external KV cache.
@@ -80,7 +81,8 @@ struct AttentionActivations {
             layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
             allocator, layer_config.qkv_dim,
-            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0))
+  {
     // Batch size can be 0 in experimental code so do not assert.
     if (batch_size == 0) {
       static std::atomic_flag warned = ATOMIC_FLAG_INIT;
@@ -217,7 +219,8 @@ struct Activations {
 
         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          ctx.allocator, row_ptrs),
+                          runtime_config.attention_impl, ctx.allocator,
+                          row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
 
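
Taken together, these hunks thread the runtime-selected AttentionImpl from
Activations into AttentionActivations: the new constructor parameter sits
between seq_len and the Allocator. Below is a minimal caller-side sketch, not
part of this commit; it assumes the config, layer_config, runtime_config, and
ctx objects visible in the diff context, and uses placeholder batch size and
sequence length values.

  // Sketch only; batch_size and seq_len are hypothetical placeholder values.
  std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
  AttentionActivations attention_storage(
      config, layer_config, /*batch_size=*/4, /*seq_len=*/2048,
      runtime_config.attention_impl,  // parameter added by this change
      ctx.allocator, row_ptrs);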