
Commit c3e9473

stollem authored and copybara-github committed
Added access to softmax attention internals to regular attention
PiperOrigin-RevId: 833353546
1 parent 49d420a commit c3e9473

File tree

5 files changed: +104 -9 lines changed

gemma/activations.h

Lines changed: 26 additions & 3 deletions

@@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, const Allocator& allocator,
+      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : // `vocab_size == 0` means it is for Vit part, VitAttention is still
         // MHA and does not use an external KV cache.
@@ -72,6 +73,10 @@ struct AttentionActivations {
         att_out(MatFactory("att_out", batch_size,
                            layer_config.heads * layer_config.qkv_dim,
                            allocator)),
+        softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
+                               allocator)),
+        softmax_d(
+            MatFactory("softmax_d", batch_size, layer_config.heads, allocator)),
         att_sums(
             MatFactory("att_sums", batch_size, config.model_dim, allocator)),

@@ -80,7 +85,8 @@ struct AttentionActivations {
             layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
             allocator, layer_config.qkv_dim,
-            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0))
+      {
     // Batch size can be 0 in experimental code so do not assert.
     if (batch_size == 0) {
       static std::atomic_flag warned = ATOMIC_FLAG_INIT;
@@ -107,6 +113,8 @@ struct AttentionActivations {
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
+    softmax_max.OverrideRows(batch_size);
+    softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);

     // `inv_timescale*` are not batched.
@@ -119,6 +127,8 @@ struct AttentionActivations {
   MatStorageT<float> pre_att_rms_out;
   MatStorageT<float> att;      // attention vector
   MatStorageT<float> att_out;  // attention output
+  MatStorageT<float> softmax_max;  // see OnlineSoftmaxState
+  MatStorageT<float> softmax_d;    // see OnlineSoftmaxState
   // Accumulation of attention outputs over heads
   MatStorageT<BF16> att_sums;

@@ -144,6 +154,8 @@ struct AttentionActivationsPtrs {
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
     att_out = activations.att_out;
+    softmax_max = activations.softmax_max;
+    softmax_d = activations.softmax_d;
     att_sums = activations.att_sums;
     inv_timescale = activations.inv_timescale;
     inv_timescale_global = activations.inv_timescale_global;
@@ -156,6 +168,8 @@ struct AttentionActivationsPtrs {
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
+    softmax_max.OverrideRows(batch_size);
+    softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
     // `inv_timescale*` are not batched.
   }
@@ -179,6 +193,14 @@ struct AttentionActivationsPtrs {
   // Attention output computed from att * V, size batch_size x (q_heads *
   // qkv_dim).
   MatPtrT<float> att_out;
+  // The maximum logit value encountered when computing att_out from att,
+  // size batch_size x q_heads. See OnlineSoftmaxState for details.
+  // WARNING: Only filled in for AttentionImpl::kOld.
+  MatPtrT<float> softmax_max;
+  // The sum of scaled exponentials when computing att_out from att,
+  // size batch_size x q_heads. See OnlineSoftmaxState for details.
+  // WARNING: Only filled in for AttentionImpl::kOld.
+  MatPtrT<float> softmax_d;
   // Accumulation of attention outputs over heads, size batch_size x
   // model_dim.
   MatPtrT<BF16> att_sums;
@@ -217,7 +239,8 @@ struct Activations {

         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          ctx.allocator, row_ptrs),
+                          runtime_config.attention_impl, ctx.allocator,
+                          row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);

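The two new tensors are plain per-(token, head) float matrices, so after a forward pass with AttentionImpl::kOld they can be read back directly. The snippet below is only an illustrative sketch of that access pattern, mirroring the `.Row(tq_idx) + head` indexing used in gemma/attention.cc further down; the function name, loop bounds, and logging are hypothetical and not part of this commit.

#include <cstdio>

// Hypothetical sketch (not in this commit): dump the softmax internals that
// AttentionActivationsPtrs now exposes. Assumes `activations` was filled by
// DotSoftmaxWeightedSum with AttentionImpl::kOld, and that `num_tokens` and
// `heads` describe the batch that was just processed.
void DumpSoftmaxState(const AttentionActivationsPtrs& activations,
                      size_t num_tokens, size_t heads) {
  for (size_t tq_idx = 0; tq_idx < num_tokens; ++tq_idx) {
    for (size_t head = 0; head < heads; ++head) {
      // softmax_max and softmax_d store one float per head in each row.
      const float m = activations.softmax_max.Row(tq_idx)[head];
      const float d = activations.softmax_d.Row(tq_idx)[head];
      std::fprintf(stderr, "token %zu head %zu: max=%f d=%f\n",
                   tq_idx, head, m, d);
    }
  }
}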

gemma/attention.cc

Lines changed: 7 additions & 3 deletions

@@ -123,7 +123,8 @@ void SingleDotSoftmaxWeightedSum(
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
     const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
-    float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
+    float* HWY_RESTRICT att_out, float* HWY_RESTRICT softmax_max,
+    float* HWY_RESTRICT softmax_d, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   // --seq_len must be large enough to avoid wraparound.
@@ -146,7 +147,7 @@ void SingleDotSoftmaxWeightedSum(
   // SoftMax with optional SoftCap yields "probabilities" in att.
   const Logits logits(att, last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
-  Softmax(logits, ctx, worker, /*temperature=*/1.0f);
+  Softmax(logits, ctx, worker, /*temperature=*/1.0f, softmax_max, softmax_d);

   WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
                ctx, worker);
@@ -203,6 +204,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
     float* HWY_RESTRICT att_out =
         activations.att_out.Row(tq_idx) + head * qkv_dim;
+    float* HWY_RESTRICT softmax_max =
+        activations.softmax_max.Row(tq_idx) + head;
+    float* HWY_RESTRICT softmax_d = activations.softmax_d.Row(tq_idx) + head;

     // Make strided read-only views into the kv cache for
     // this query and head.
@@ -215,7 +219,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,

     SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
                                 query_norm_scale, layer_idx, activations, att,
-                                att_out, ctx, worker);
+                                att_out, softmax_max, softmax_d, ctx, worker);
   };

   {
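As a worked summary (not part of the diff): for one query token and head, let $s_i$ be the (optionally soft-capped) attention logits over positions $i \in [\mathrm{start\_pos}, \mathrm{last\_pos}]$ and $v_i$ the corresponding value rows. The Softmax and WeightedSumV calls above then produce

$$m = \max_i s_i, \qquad d = \sum_i e^{s_i - m}, \qquad \mathrm{att}_i = \frac{e^{s_i - m}}{d}, \qquad \mathrm{att\_out} = \sum_i \mathrm{att}_i \, v_i,$$

so $d \cdot \mathrm{att\_out} = \sum_i e^{s_i - m} v_i$. Together with att_out, the two scalars written to softmax_max and softmax_d are enough to recover the unnormalized accumulator and to renormalize it against results computed over other position ranges.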

gemma/flash_attention_test.cc

Lines changed: 2 additions & 1 deletion

@@ -124,7 +124,8 @@ void TestFlashAttention(size_t target_parallelism) {
   const size_t batch_size = kOuter;
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
   AttentionActivations attention_storage(config, layer_config, batch_size,
-                                         kOuter, ctx.allocator, row_ptrs);
+                                         kOuter, AttentionImpl::kFlash,
+                                         ctx.allocator, row_ptrs);
   AttentionActivationsPtrs attention(config, kOuter, attention_storage);
   const size_t qkv_dim = layer_config.qkv_dim;
   ASSERT_EQ(qkv_dim, kInner);

ops/ops-inl.h

Lines changed: 23 additions & 2 deletions

@@ -1125,9 +1125,26 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(

 // See below for a specialized version for top-1 sampling.
 // TODO: support bf16 logits using Decompress2.
+// Computes softmax probabilities for the given logits, normalizing in-place.
+// The calculation is numerically stable, using the max-subtraction trick to
+// compute exp(logits[i] - max(logits)) before normalizing by the sum.
+// If temperature is provided and not 1.0, each intermediate exp() result is
+// divided by temperature before normalization; however, this division by
+// temperature cancels out during the final normalization step, meaning
+// temperature currently has no effect on the output probabilities.
+// @param logits In-out: on input, contains logits; on output, overwritten with
+//   probabilities.
+// @param ctx Input: threading context for parallelism and profiling.
+// @param worker Input: worker thread index.
+// @param temperature Input: softmax temperature.
+// @param softmax_max_out Optional output: if not null, stores the max logit
+//   value.
+// @param softmax_d_out Optional output: if softmax_max_out is not null, this
+//   must not be null and stores the sum of exp(logit - max).
 static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
-                                 const size_t worker,
-                                 float temperature = 1.0f) {
+                                 const size_t worker, float temperature = 1.0f,
+                                 float* HWY_RESTRICT softmax_max_out = nullptr,
+                                 float* HWY_RESTRICT softmax_d_out = nullptr) {
   GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
   HWY_DASSERT(logits.size() != 0);

@@ -1171,6 +1188,10 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
   MulByConst(mul, logits.data(), logits.size());
+  if (softmax_max_out) {
+    (*softmax_max_out) = hn::GetLane(vmax);
+    (*softmax_d_out) = sum_exp;
+  }
 }

 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
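One reason to expose a (max, denominator) pair is that it is exactly the state needed to merge softmax results computed over separate chunks of positions, the same rescaling identity flash attention relies on and presumably related to the OnlineSoftmaxState referenced in the comments above. The commit itself only exposes the values; the code below is a self-contained illustration of the standard merge identity, with every name hypothetical and not taken from the repo.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical illustration (not in this commit): a partial softmax result for
// one chunk of positions, holding the values Softmax() can now return plus the
// chunk's normalized weighted output.
struct PartialSoftmax {
  float m;               // max logit within the chunk (cf. softmax_max)
  float d;               // sum of exp(logit - m) within the chunk (cf. softmax_d)
  std::vector<float> o;  // the chunk's softmax-weighted sum of values
};

// Rescales both chunks to a common max and recombines them; the result equals
// what a single softmax-weighted sum over the union of both chunks would give.
PartialSoftmax Merge(const PartialSoftmax& a, const PartialSoftmax& b) {
  PartialSoftmax out;
  out.m = std::max(a.m, b.m);
  const float wa = a.d * std::exp(a.m - out.m);
  const float wb = b.d * std::exp(b.m - out.m);
  out.d = wa + wb;
  out.o.resize(a.o.size());
  for (size_t i = 0; i < a.o.size(); ++i) {
    out.o[i] = (wa * a.o[i] + wb * b.o[i]) / out.d;
  }
  return out;
}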

ops/ops_test.cc

Lines changed: 46 additions & 0 deletions

@@ -346,6 +346,51 @@ void TestAllSoftmax() {
   hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
 }

+class TestSoftmaxState {
+ public:
+  template <class D>
+  void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+                  hwy::RandomState& rng) {
+    if (count == 0) return;  // *Softmax would assert
+    if (misalign_b == 0) return;
+    using T = hn::TFromD<D>;
+
+    hwy::AlignedFreeUniquePtr<T[]> px =
+        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+    hwy::AlignedFreeUniquePtr<T[]> pe =
+        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+    HWY_ASSERT(px && pe);
+
+    T* x = px.get() + misalign_a;
+    T* initial_logits = pe.get() + misalign_a;
+
+    for (size_t i = 0; i < count; ++i) {
+      x[i] = Random<T>(rng);
+      initial_logits[i] = x[i];
+    }
+
+    float softmax_max;
+    float softmax_d;
+    Softmax(Logits(x, count), Ctx(), /*worker=*/0, /*temperature=*/1.0f,
+            &softmax_max, &softmax_d);
+
+    const float maxval =
+        *std::max_element(initial_logits, initial_logits + count);
+
+    float sum_exp = 0.0f;
+    for (size_t i = 0; i < count; ++i) {
+      sum_exp += std::exp(initial_logits[i] - maxval);
+    }
+
+    ASSERT_NEAR(softmax_max, maxval, 1e-6);
+    ASSERT_NEAR(softmax_d, sum_exp, 1e-6);
+  }
+};
+
+void TestAllSoftmaxState() {
+  hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmaxState>>()(float());
+}
+
 template <size_t k>
 struct TestCreateDistribution {
   void operator()(hwy::RandomState& rng) {
@@ -769,6 +814,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);
