
Commit a501a8b

stollem authored and copybara-github committed
Internal change
PiperOrigin-RevId: 834255021
1 parent 49d420a

File tree

1 file changed (+6, -3 lines)


gemma/activations.h

Lines changed: 6 additions & 3 deletions
@@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, const Allocator& allocator,
+      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : // `vocab_size == 0` means it is for Vit part, VitAttention is still
         // MHA and does not use an external KV cache.
@@ -80,7 +81,8 @@ struct AttentionActivations {
             layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
             allocator, layer_config.qkv_dim,
-            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0))
+  {
     // Batch size can be 0 in experimental code so do not assert.
     if (batch_size == 0) {
       static std::atomic_flag warned = ATOMIC_FLAG_INIT;
@@ -217,7 +219,8 @@ struct Activations {
 
         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          ctx.allocator, row_ptrs),
+                          runtime_config.attention_impl, ctx.allocator,
+                          row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
 
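In summary, this commit threads the runtime AttentionImpl selection into AttentionActivations at construction time, so the activation storage can be set up for the chosen attention implementation rather than reading the choice back from its owner. Below is a minimal, self-contained C++ sketch of that pattern, not the real gemma.cpp code: ModelConfig, LayerConfig, Allocator, RuntimeConfig, and the AttentionImpl enumerator names are stubs standing in for the types defined in the repository.

#include <cstddef>
#include <iostream>

// Stub standing in for gemma.cpp's AttentionImpl; the enumerator names here
// are illustrative, not the real ones.
enum class AttentionImpl { kDefault, kAlternative };

// Stub configuration/allocator types (the real ones live elsewhere in gemma/).
struct ModelConfig {};
struct LayerConfig {};
struct Allocator {};

struct RuntimeConfig {
  AttentionImpl attention_impl = AttentionImpl::kDefault;
};

// Mirrors the new constructor shape from the first hunk: attention_impl is
// now a parameter, available while the storage is being constructed.
struct AttentionActivations {
  AttentionActivations(const ModelConfig& /*config*/,
                       const LayerConfig& /*layer_config*/,
                       std::size_t /*batch_size*/, std::size_t /*seq_len*/,
                       AttentionImpl attention_impl,
                       const Allocator& /*allocator*/)
      : attention_impl(attention_impl) {
    // A real implementation could branch on attention_impl when sizing
    // buffers; this sketch only records it.
  }
  AttentionImpl attention_impl;
};

// Mirrors the third hunk: Activations forwards runtime_config.attention_impl
// into the storage constructor in addition to keeping its own copy.
struct Activations {
  Activations(const ModelConfig& config, const LayerConfig& layer_config,
              std::size_t batch_size, std::size_t seq_len,
              const RuntimeConfig& runtime_config, const Allocator& allocator)
      : attention_impl(runtime_config.attention_impl),
        attention_storage(config, layer_config, batch_size, seq_len,
                          runtime_config.attention_impl, allocator) {}
  AttentionImpl attention_impl;
  AttentionActivations attention_storage;
};

int main() {
  ModelConfig config;
  LayerConfig layer_config;
  Allocator allocator;
  RuntimeConfig runtime_config;
  Activations activations(config, layer_config, /*batch_size=*/1,
                          /*seq_len=*/128, runtime_config, allocator);
  std::cout << "storage sees the same impl: "
            << (activations.attention_storage.attention_impl ==
                activations.attention_impl)
            << "\n";  // prints 1
}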
0 commit comments