@@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, const Allocator& allocator,
+      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : // `vocab_size == 0` means it is for Vit part, VitAttention is still
         // MHA and does not use an external KV cache.
@@ -74,13 +75,16 @@ struct AttentionActivations {
                            allocator)),
         att_sums(
             MatFactory("att_sums", batch_size, config.model_dim, allocator)),
+        softmax_state(MatFactory("softmax_state", batch_size,
+                                 layer_config.heads, allocator)),
 
         inv_timescale(
             CreateInvTimescale(allocator, layer_config.qkv_dim,
                                layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
             allocator, layer_config.qkv_dim,
-            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0))
+  {
     // Batch size can be 0 in experimental code so do not assert.
     if (batch_size == 0) {
       static std::atomic_flag warned = ATOMIC_FLAG_INIT;
@@ -108,6 +112,7 @@ struct AttentionActivations {
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
+    softmax_state.OverrideRows(batch_size);
 
     // `inv_timescale*` are not batched.
   }
@@ -121,6 +126,7 @@ struct AttentionActivations {
   MatStorageT<float> att_out;  // attention output
   // Accumulation of attention outputs over heads
   MatStorageT<BF16> att_sums;
+  MatStorageT<OnlineSoftmaxState> softmax_state;
 
   // Rope
   MatStorageT<float> inv_timescale;
@@ -145,6 +151,7 @@ struct AttentionActivationsPtrs {
     att = activations.att;
     att_out = activations.att_out;
     att_sums = activations.att_sums;
+    softmax_state = activations.softmax_state;
     inv_timescale = activations.inv_timescale;
     inv_timescale_global = activations.inv_timescale_global;
   }
@@ -157,6 +164,7 @@ struct AttentionActivationsPtrs {
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
+    softmax_state.OverrideRows(batch_size);
     // `inv_timescale*` are not batched.
   }
 
@@ -182,6 +190,8 @@ struct AttentionActivationsPtrs {
   // Accumulation of attention outputs over heads, size batch_size x
   // model_dim.
   MatPtrT<BF16> att_sums;
+  // State for online softmax computation, size batch_size x q_heads.
+  MatPtrT<OnlineSoftmaxState> softmax_state;
   // Inverse timescales for RoPE computation.
   MatPtrT<float> inv_timescale;
   // Inverse timescales for global RoPE computation.
@@ -217,7 +227,8 @@ struct Activations {
 
         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          ctx.allocator, row_ptrs),
+                          runtime_config.attention_impl, ctx.allocator,
+                          row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
 
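
Not shown in this diff is the definition of OnlineSoftmaxState, the element type stored in the new softmax_state tensor (one entry per batch row and query head). A minimal sketch of the usual running-max / rescaled-sum formulation of online softmax follows; only the struct name comes from the diff, while the field names and the Update helper are hypothetical:

#include <algorithm>
#include <cmath>
#include <limits>

// Illustrative sketch only; the actual OnlineSoftmaxState may differ.
struct OnlineSoftmaxState {
  float max = -std::numeric_limits<float>::infinity();  // running max logit
  float sum = 0.0f;  // running sum of exp(logit - max)

  // Folds one logit into the state, rescaling the accumulated sum whenever
  // the running maximum increases, so logits can be consumed in a single
  // streaming pass.
  void Update(float logit) {
    const float new_max = std::max(max, logit);
    sum = sum * std::exp(max - new_max) + std::exp(logit - new_max);
    max = new_max;
  }
};

Keeping one such state per (row, head), sized batch_size x layer_config.heads as in the constructor above, would let each query head accumulate attention logits chunk by chunk instead of materializing a full softmax row at once.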