29 changes: 26 additions & 3 deletions gemma/activations.h
@@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
struct AttentionActivations {
AttentionActivations(
const ModelConfig& config, const LayerConfig& layer_config,
size_t batch_size, size_t seq_len, const Allocator& allocator,
size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
// MHA and does not use an external KV cache.
@@ -72,6 +73,10 @@ struct AttentionActivations {
att_out(MatFactory("att_out", batch_size,
layer_config.heads * layer_config.qkv_dim,
allocator)),
softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
allocator)),
softmax_d(
MatFactory("softmax_d", batch_size, layer_config.heads, allocator)),
att_sums(
MatFactory("att_sums", batch_size, config.model_dim, allocator)),

@@ -80,7 +85,8 @@
layer_config.post_qk == PostQKType::HalfRope)),
inv_timescale_global(CreateInvTimescale(
allocator, layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
layer_config.post_qk == PostQKType::HalfRope, 1000000.0))
{
// Batch size can be 0 in experimental code so do not assert.
if (batch_size == 0) {
static std::atomic_flag warned = ATOMIC_FLAG_INIT;
@@ -107,6 +113,8 @@ struct AttentionActivations {
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
softmax_max.OverrideRows(batch_size);
softmax_d.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);

// `inv_timescale*` are not batched.
@@ -119,6 +127,8 @@ struct AttentionActivations {
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
MatStorageT<float> softmax_max; // see OnlineSoftmaxState
MatStorageT<float> softmax_d; // see OnlineSoftmaxState
// Accumulation of attention outputs over heads
MatStorageT<BF16> att_sums;

@@ -144,6 +154,8 @@ struct AttentionActivationsPtrs {
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
att_out = activations.att_out;
softmax_max = activations.softmax_max;
softmax_d = activations.softmax_d;
att_sums = activations.att_sums;
inv_timescale = activations.inv_timescale;
inv_timescale_global = activations.inv_timescale_global;
@@ -156,6 +168,8 @@ struct AttentionActivationsPtrs {
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
softmax_max.OverrideRows(batch_size);
softmax_d.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
// `inv_timescale*` are not batched.
}
@@ -179,6 +193,14 @@ struct AttentionActivationsPtrs {
// Attention output computed from att * V, size batch_size x (q_heads *
// qkv_dim).
MatPtrT<float> att_out;
// The maximum logit value encountered when computing att_out from att,
// size batch_size x q_heads. See OnlineSoftmaxState for details.
// WARNING: Only filled in for AttentionImpl::kOld.
MatPtrT<float> softmax_max;
// The sum of scaled exponentials when computing att_out from att,
// size batch_size x q_heads. See OnlineSoftmaxState for details.
// WARNING: Only filled in for AttentionImpl::kOld.
MatPtrT<float> softmax_d;
// Accumulation of attention outputs over heads, size batch_size x
// model_dim.
MatPtrT<BF16> att_sums;
@@ -217,7 +239,8 @@ struct Activations {

attention_impl(runtime_config.attention_impl),
attention_storage(config, layer_config, batch_size, seq_len,
ctx.allocator, row_ptrs),
runtime_config.attention_impl, ctx.allocator,
row_ptrs),
attention(config, seq_len, attention_storage) {
HWY_ASSERT(batch_size != 0);

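The new `softmax_max` / `softmax_d` pair is standard online-softmax bookkeeping: the maximum logit seen and the sum of exponentials rescaled to that maximum. The minimal sketch below shows why keeping this pair per (query, head) is useful, assuming the usual merge rule for partial softmax states; the struct and function names here are illustrative and are not the repository's actual `OnlineSoftmaxState`.

```cpp
#include <algorithm>
#include <cmath>

// Illustrative only: the repository's OnlineSoftmaxState may differ.
struct PartialSoftmax {
  float max;  // maximum logit seen so far (softmax_max)
  float d;    // sum of exp(logit - max) over the logits seen so far (softmax_d)
};

// Merges two partial states computed over disjoint ranges of logits.
// The combined state describes the softmax over the union of both ranges.
inline PartialSoftmax Merge(const PartialSoftmax& a, const PartialSoftmax& b) {
  const float m = std::max(a.max, b.max);
  // Rescale each partial denominator to the shared maximum m before summing.
  const float d = a.d * std::exp(a.max - m) + b.d * std::exp(b.max - m);
  return PartialSoftmax{m, d};
}
```

Because each partial denominator is rescaled to the shared maximum before being summed, disjoint chunks of an attention row can be processed independently and merged afterwards.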
10 changes: 7 additions & 3 deletions gemma/attention.cc
@@ -123,7 +123,8 @@ void SingleDotSoftmaxWeightedSum(
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const MatPtr& query_norm_scale, const size_t layer_idx,
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
float* HWY_RESTRICT att_out, float* HWY_RESTRICT softmax_max,
float* HWY_RESTRICT softmax_d, ThreadingContext& ctx, const size_t worker) {
const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale;
// --seq_len must be large enough to avoid wraparound.
@@ -146,7 +147,7 @@
// SoftMax with optional SoftCap yields "probabilities" in att.
const Logits logits(att, last_pos + 1);
MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
Softmax(logits, ctx, worker, /*temperature=*/1.0f);
Softmax(logits, ctx, worker, /*temperature=*/1.0f, softmax_max, softmax_d);

WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
ctx, worker);
@@ -203,6 +204,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
float* HWY_RESTRICT att_out =
activations.att_out.Row(tq_idx) + head * qkv_dim;
float* HWY_RESTRICT softmax_max =
activations.softmax_max.Row(tq_idx) + head;
float* HWY_RESTRICT softmax_d = activations.softmax_d.Row(tq_idx) + head;

// Make strided read-only views into the kv cache for
// this query and head.
@@ -215,7 +219,7 @@

SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
query_norm_scale, layer_idx, activations, att,
att_out, ctx, worker);
att_out, softmax_max, softmax_d, ctx, worker);
};

{
3 changes: 2 additions & 1 deletion gemma/flash_attention_test.cc
@@ -124,7 +124,8 @@ void TestFlashAttention(size_t target_parallelism) {
const size_t batch_size = kOuter;
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
AttentionActivations attention_storage(config, layer_config, batch_size,
kOuter, ctx.allocator, row_ptrs);
kOuter, AttentionImpl::kFlash,
ctx.allocator, row_ptrs);
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
const size_t qkv_dim = layer_config.qkv_dim;
ASSERT_EQ(qkv_dim, kInner);
25 changes: 23 additions & 2 deletions ops/ops-inl.h
@@ -1125,9 +1125,26 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(

// See below for a specialized version for top-1 sampling.
// TODO: support bf16 logits using Decompress2.
// Computes softmax probabilities for the given logits, normalizing in-place.
// The calculation is numerically stable, using the max-subtraction trick to
// compute exp(logits[i] - max(logits)) before normalizing by the sum.
// If temperature is provided and not 1.0, each intermediate exp() result is
// divided by temperature before normalization; however, this division by
// temperature cancels out during the final normalization step, meaning
// temperature currently has no effect on the output probabilities.
// @param logits In-out: on input, contains logits; on output, overwritten with
// probabilities.
// @param ctx Input: threading context for parallelism and profiling.
// @param worker Input: worker thread index.
// @param temperature Input: softmax temperature.
// @param softmax_max_out Optional output: if not null, stores the max logit
// value.
// @param softmax_d_out Optional output: if softmax_max_out is not null, this
// not be null and stores the sum of exp(logit - max).
static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
const size_t worker,
float temperature = 1.0f) {
const size_t worker, float temperature = 1.0f,
float* HWY_RESTRICT softmax_max_out = nullptr,
float* HWY_RESTRICT softmax_d_out = nullptr) {
GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
HWY_DASSERT(logits.size() != 0);

@@ -1171,6 +1188,10 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
// Double-precision reciprocal does not appear to affect the results.
const float mul = 1.0f / sum_exp;
MulByConst(mul, logits.data(), logits.size());
if (softmax_max_out) {
(*softmax_max_out) = hn::GetLane(vmax);
(*softmax_d_out) = sum_exp;
}
}

// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
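For reference, here is a scalar sketch of the contract documented above: overwrite `logits` with probabilities and, if the optional pointers are given, report the maximum logit and the exp-sum denominator. The helper name is made up for illustration; the vectorized `Softmax` above is the source of truth.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// Scalar reference for the documented behavior (assumes n != 0, matching the
// HWY_DASSERT in the real implementation). Illustrative only.
inline void SoftmaxReference(float* logits, size_t n,
                             float* softmax_max_out = nullptr,
                             float* softmax_d_out = nullptr) {
  // Max-subtraction trick for numerical stability.
  float max_logit = logits[0];
  for (size_t i = 1; i < n; ++i) max_logit = std::max(max_logit, logits[i]);

  // exp(logit - max) and its sum.
  float sum_exp = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    logits[i] = std::exp(logits[i] - max_logit);
    sum_exp += logits[i];
  }

  // Normalize to probabilities.
  const float mul = 1.0f / sum_exp;
  for (size_t i = 0; i < n; ++i) logits[i] *= mul;

  // Optionally expose the online-softmax state (max, d).
  if (softmax_max_out != nullptr) {
    *softmax_max_out = max_logit;
    if (softmax_d_out != nullptr) *softmax_d_out = sum_exp;
  }
}
```

The `TestSoftmaxState` test below checks the same two values against a scalar recomputation of the max and the exp-sum.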
46 changes: 46 additions & 0 deletions ops/ops_test.cc
@@ -346,6 +346,51 @@ void TestAllSoftmax() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
}

class TestSoftmaxState {
public:
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
if (count == 0) return; // *Softmax would assert
if (misalign_b == 0) return;
using T = hn::TFromD<D>;

hwy::AlignedFreeUniquePtr<T[]> px =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> pe =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
HWY_ASSERT(px && pe);

T* x = px.get() + misalign_a;
T* initial_logits = pe.get() + misalign_a;

for (size_t i = 0; i < count; ++i) {
x[i] = Random<T>(rng);
initial_logits[i] = x[i];
}

float softmax_max;
float softmax_d;
Softmax(Logits(x, count), Ctx(), /*worker=*/0, /*temperature=*/1.0f,
&softmax_max, &softmax_d);

const float maxval =
*std::max_element(initial_logits, initial_logits + count);

float sum_exp = 0.0f;
for (size_t i = 0; i < count; ++i) {
sum_exp += std::exp(initial_logits[i] - maxval);
}

ASSERT_NEAR(softmax_max, maxval, 1e-6);
ASSERT_NEAR(softmax_d, sum_exp, 1e-6);
}
};

void TestAllSoftmaxState() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmaxState>>()(float());
}

template <size_t k>
struct TestCreateDistribution {
void operator()(hwy::RandomState& rng) {
@@ -769,6 +814,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);