
Commit 7bf21b7

stollem authored and copybara-github committed
Added access to flash attention internals to regular attention
PiperOrigin-RevId: 833353546
1 parent 49d420a commit 7bf21b7

File tree

7 files changed: +100 additions, −23 deletions

  gemma/activations.h
  gemma/attention.cc
  gemma/flash_attention_test.cc
  gemma/flash_structs.h
  ops/ops-inl.h
  ops/ops_test.cc
  util/basics.h

gemma/activations.h

Lines changed: 14 additions & 3 deletions
@@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, const Allocator& allocator,
+      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : // `vocab_size == 0` means it is for Vit part, VitAttention is still
         // MHA and does not use an external KV cache.
@@ -74,13 +75,16 @@ struct AttentionActivations {
             allocator)),
         att_sums(
             MatFactory("att_sums", batch_size, config.model_dim, allocator)),
+        softmax_state(MatFactory("softmax_state", batch_size,
+                                 layer_config.heads, allocator)),
 
         inv_timescale(
             CreateInvTimescale(allocator, layer_config.qkv_dim,
                                layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
             allocator, layer_config.qkv_dim,
-            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0))
+  {
     // Batch size can be 0 in experimental code so do not assert.
     if (batch_size == 0) {
       static std::atomic_flag warned = ATOMIC_FLAG_INIT;
@@ -108,6 +112,7 @@ struct AttentionActivations {
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
+    softmax_state.OverrideRows(batch_size);
 
     // `inv_timescale*` are not batched.
   }
@@ -121,6 +126,7 @@ struct AttentionActivations {
   MatStorageT<float> att_out;  // attention output
   // Accumulation of attention outputs over heads
   MatStorageT<BF16> att_sums;
+  MatStorageT<OnlineSoftmaxState> softmax_state;
 
   // Rope
   MatStorageT<float> inv_timescale;
@@ -145,6 +151,7 @@ struct AttentionActivationsPtrs {
     att = activations.att;
     att_out = activations.att_out;
     att_sums = activations.att_sums;
+    softmax_state = activations.softmax_state;
     inv_timescale = activations.inv_timescale;
     inv_timescale_global = activations.inv_timescale_global;
   }
@@ -157,6 +164,7 @@ struct AttentionActivationsPtrs {
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
+    softmax_state.OverrideRows(batch_size);
     // `inv_timescale*` are not batched.
   }
 
@@ -182,6 +190,8 @@ struct AttentionActivationsPtrs {
   // Accumulation of attention outputs over heads, size batch_size x
   // model_dim.
   MatPtrT<BF16> att_sums;
+  // State for online softmax computation, size batch_size x q_heads.
+  MatPtrT<OnlineSoftmaxState> softmax_state;
   // Inverse timescales for RoPE computation.
   MatPtrT<float> inv_timescale;
   // Inverse timescales for global RoPE computation.
@@ -217,7 +227,8 @@ struct Activations {
 
         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          ctx.allocator, row_ptrs),
+                          runtime_config.attention_impl, ctx.allocator,
+                          row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
 
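The new softmax_state tensor stores one OnlineSoftmaxState per (query row, head): batch_size rows of layer_config.heads entries. Below is a minimal standalone sketch of that layout, assuming row-major storage; the std::vector and the StateAt helper are hypothetical stand-ins for MatStorageT and Row().

#include <cstddef>
#include <vector>

// Stand-in for the struct this commit moves into util/basics.h.
struct OnlineSoftmaxState {
  float max;  // running maximum of the logits seen so far
  float d;    // sum of exp(logit - max)
};

// Hypothetical helper: row-major indexing equivalent to
// activations.softmax_state.Row(tq_idx) + head in gemma/attention.cc below.
OnlineSoftmaxState& StateAt(std::vector<OnlineSoftmaxState>& states,
                            size_t num_heads, size_t tq_idx, size_t head) {
  return states[tq_idx * num_heads + head];
}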

gemma/attention.cc

Lines changed: 6 additions & 3 deletions
@@ -123,7 +123,8 @@ void SingleDotSoftmaxWeightedSum(
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
     const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
-    float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
+    float* HWY_RESTRICT att_out, OnlineSoftmaxState* state_out,
+    ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   // --seq_len must be large enough to avoid wraparound.
@@ -146,7 +147,7 @@ void SingleDotSoftmaxWeightedSum(
   // SoftMax with optional SoftCap yields "probabilities" in att.
   const Logits logits(att, last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
-  Softmax(logits, ctx, worker, /*temperature=*/1.0f);
+  Softmax(logits, ctx, worker, /*temperature=*/1.0f, state_out);
 
   WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
                ctx, worker);
@@ -203,6 +204,8 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
     float* HWY_RESTRICT att_out =
         activations.att_out.Row(tq_idx) + head * qkv_dim;
+    OnlineSoftmaxState* state_out =
+        activations.softmax_state.Row(tq_idx) + head;
 
     // Make strided read-only views into the kv cache for
     // this query and head.
@@ -215,7 +218,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
 
     SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
                                 query_norm_scale, layer_idx, activations, att,
-                                att_out, ctx, worker);
+                                att_out, state_out, ctx, worker);
   };
 
   {
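With state_out plumbed through, regular (non-flash) attention now records, for every query row and head, the running maximum and softmax denominator that flash attention keeps internally. Below is a hedged sketch, not part of this commit, of the standard combine rule such state enables: merging attention outputs computed over two disjoint KV ranges. It assumes out_a and out_b are each chunk's already-normalized weighted sums of V (length qkv_dim), paired with their OnlineSoftmaxState.

#include <algorithm>
#include <cmath>
#include <cstddef>

struct OnlineSoftmaxState {  // mirrors util/basics.h
  float max;  // running maximum of the logits seen so far
  float d;    // sum of exp(logit - max)
};

// Hypothetical helper: rescale the two partial results to a common maximum
// and renormalize, as flash attention does when combining KV tiles.
void MergePartialAttention(const OnlineSoftmaxState& a, const float* out_a,
                           const OnlineSoftmaxState& b, const float* out_b,
                           size_t qkv_dim, float* out_merged) {
  const float m = std::max(a.max, b.max);
  const float wa = std::exp(a.max - m) * a.d;  // unnormalized weight of chunk a
  const float wb = std::exp(b.max - m) * b.d;  // unnormalized weight of chunk b
  const float denom = wa + wb;
  for (size_t i = 0; i < qkv_dim; ++i) {
    out_merged[i] = (wa * out_a[i] + wb * out_b[i]) / denom;
  }
}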

gemma/flash_attention_test.cc

Lines changed: 2 additions & 1 deletion
@@ -124,7 +124,8 @@ void TestFlashAttention(size_t target_parallelism) {
   const size_t batch_size = kOuter;
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
   AttentionActivations attention_storage(config, layer_config, batch_size,
-                                         kOuter, ctx.allocator, row_ptrs);
+                                         kOuter, AttentionImpl::kFlash,
+                                         ctx.allocator, row_ptrs);
   AttentionActivationsPtrs attention(config, kOuter, attention_storage);
   const size_t qkv_dim = layer_config.qkv_dim;
   ASSERT_EQ(qkv_dim, kInner);

gemma/flash_structs.h

Lines changed: 1 addition & 14 deletions
@@ -3,23 +3,10 @@
 
 #include <stddef.h>
 
-#include <limits>
+#include "util/basics.h"
 
 namespace gcpp {
 
-// State for computing softmax in a streaming ("online") manner,
-// avoiding large intermediate values by subtracting the running maximum.
-// For a sequence x_1, ..., x_n:
-// m_i = max(m_{i-1}, x_i)
-// d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
-// softmax_i = exp(x_i - m_i) / d_i
-struct OnlineSoftmaxState {
-  // Maximum logit value encountered so far.
-  float max = -std::numeric_limits<float>::max() / 2.0f;
-  // Sum of exponentials scaled by exp(-max).
-  float d = 0.0f;
-};
-
 static constexpr size_t kVTileSize4 = 4;
 
 struct Tile4FlashState {

ops/ops-inl.h

Lines changed: 20 additions & 2 deletions
@@ -1125,9 +1125,23 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
 
 // See below for a specialized version for top-1 sampling.
 // TODO: support bf16 logits using Decompress2.
+// Computes softmax probabilities for the given logits, normalizing in-place.
+// The calculation is numerically stable, using the max-subtraction trick to
+// compute exp(logits[i] - max(logits)) before normalizing by the sum.
+// If temperature is provided and not 1.0, each intermediate exp() result is
+// divided by temperature before normalization; however, this division by
+// temperature cancels out during the final normalization step, meaning
+// temperature currently has no effect on the output probabilities.
+// @param logits In-out: on input, contains logits; on output, overwritten with
+//   probabilities.
+// @param state Optional output: if not null, stores the max logit and sum of
+//   exp(logit - max) for use in online softmax computation.
+// @param ctx Input: threading context for parallelism and profiling.
+// @param worker Input: worker thread index.
+// @param temperature Input: softmax temperature.
 static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
-                                 const size_t worker,
-                                 float temperature = 1.0f) {
+                                 const size_t worker, float temperature = 1.0f,
+                                 OnlineSoftmaxState* state = nullptr) {
   GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
   HWY_DASSERT(logits.size() != 0);
 
@@ -1171,6 +1185,10 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
   MulByConst(mul, logits.data(), logits.size());
+  if (state) {
+    state->max = hn::GetLane(vmax);
+    state->d = sum_exp;
+  }
 }
 
 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
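For reference, here is a scalar sketch (not the SIMD path above) of what Softmax computes and what the optional state carries: after the call, every output probability equals exp(original_logit - state->max) / state->d, which is also what the new TestSoftmaxState test below checks.

#include <algorithm>
#include <cmath>
#include <vector>

struct OnlineSoftmaxState {  // mirrors util/basics.h
  float max;
  float d;
};

// Scalar reference; assumes logits is non-empty, as the real Softmax asserts.
void SoftmaxReference(std::vector<float>& logits, OnlineSoftmaxState* state) {
  const float m = *std::max_element(logits.begin(), logits.end());
  float d = 0.0f;
  for (float x : logits) d += std::exp(x - m);      // sum of exp(logit - max)
  for (float& x : logits) x = std::exp(x - m) / d;  // in-place probabilities
  if (state) {
    state->max = m;
    state->d = d;
  }
}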

ops/ops_test.cc

Lines changed: 44 additions & 0 deletions
@@ -346,6 +346,49 @@ void TestAllSoftmax() {
   hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
 }
 
+class TestSoftmaxState {
+ public:
+  template <class D>
+  void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+                  hwy::RandomState& rng) {
+    if (count == 0) return;  // *Softmax would assert
+    if (misalign_b == 0) return;
+    using T = hn::TFromD<D>;
+
+    hwy::AlignedFreeUniquePtr<T[]> px =
+        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+    hwy::AlignedFreeUniquePtr<T[]> pe =
+        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+    HWY_ASSERT(px && pe);
+
+    T* x = px.get() + misalign_a;
+    T* initial_logits = pe.get() + misalign_a;
+
+    for (size_t i = 0; i < count; ++i) {
+      x[i] = Random<T>(rng);
+      initial_logits[i] = x[i];
+    }
+
+    OnlineSoftmaxState state;
+    Softmax(Logits(x, count), Ctx(), /*worker=*/0, /*temperature=*/1.0f, &state);
+
+    const float maxval =
+        *std::max_element(initial_logits, initial_logits + count);
+
+    float sum_exp = 0.0f;
+    for (size_t i = 0; i < count; ++i) {
+      sum_exp += std::exp(initial_logits[i] - maxval);
+    }
+
+    ASSERT_NEAR(state.max, maxval, 1e-6);
+    ASSERT_NEAR(state.d, sum_exp, 1e-6);
+  }
+};
+
+void TestAllSoftmaxState() {
+  hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmaxState>>()(float());
+}
+
 template <size_t k>
 struct TestCreateDistribution {
   void operator()(hwy::RandomState& rng) {
@@ -769,6 +812,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);

util/basics.h

Lines changed: 13 additions & 0 deletions
@@ -89,6 +89,19 @@ struct TokenAndProb {
 };
 #pragma pack(pop)
 
+// State for computing softmax in a streaming ("online") manner,
+// avoiding large intermediate values by subtracting the running maximum.
+// For a sequence x_1, ..., x_n:
+// m_i = max(m_{i-1}, x_i)
+// d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
+// softmax_i = exp(x_i - m_i) / d_i
+struct OnlineSoftmaxState {
+  // Maximum logit value encountered so far.
+  float max = -std::numeric_limits<float>::max() / 2.0f;
+  // Sum of exponentials scaled by exp(-max).
+  float d = 0.0f;
+};
+
 // Entire size of a 2D array.
 struct Extents2D {
   constexpr Extents2D() : rows(0), cols(0) {}
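A minimal sketch, not in this commit, of the streaming recurrence documented in the comment above: OnlineSoftmaxUpdate is a hypothetical helper that folds in one logit at a time, keeping m_i and d_i exactly as described, so that exp(x_j - state.max) / state.d is the softmax probability of any previously seen logit x_j.

#include <algorithm>
#include <cmath>
#include <limits>

struct OnlineSoftmaxState {  // mirrors the struct added above
  float max = -std::numeric_limits<float>::max() / 2.0f;
  float d = 0.0f;
};

// Hypothetical helper implementing the documented recurrence:
//   m_i = max(m_{i-1}, x_i)
//   d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
void OnlineSoftmaxUpdate(OnlineSoftmaxState& s, float x) {
  const float new_max = std::max(s.max, x);
  s.d = s.d * std::exp(s.max - new_max) + std::exp(x - new_max);
  s.max = new_max;
}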
