fix(qwen3): avoid out-of-memory crash on long-context embeddings #68

Open
contrapuntal wants to merge 1 commit into Blaizzy:main from contrapuntal:fix/qwen3-embedding-causal-mask-oom

contrapuntal commented Jun 20, 2026


What's wrong

Embedding a long batch with a Qwen3 model can crash before any real work happens. Qwen3Model.__call__ builds a dense (batch, 1, seq, seq) attention mask on every forward pass, and that tensor grows with the square of the sequence length. At Qwen3's full 32k context, a batch of 32 needs a single ~128 GiB allocation — more than Metal will hand out (~80.6 GiB on this machine) — so POST /v1/embeddings fails before the model computes anything:

[metal::malloc] Attempting to allocate 137438953472 bytes which is greater than the maximum allowed buffer size of 86586540032 bytes.
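
The number in that error follows directly from the mask shape. Here is a quick back-of-the-envelope check, assuming a float32 mask (4 bytes per element). `dense_mask_bytes` is a hypothetical helper for illustration only, not something in the package:

```python
def dense_mask_bytes(batch: int, seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed for a dense (batch, 1, seq_len, seq_len) mask."""
    return batch * 1 * seq_len * seq_len * bytes_per_elem

# Limit reported in the error message above (specific to this machine).
METAL_MAX_BUFFER = 86586540032

needed = dense_mask_bytes(batch=32, seq_len=32768)
print(needed)                      # 137438953472
print(needed / 2**30)              # 128.0 (GiB)
print(needed > METAL_MAX_BUFFER)   # True

# A single 16384-token sequence already needs 1 GiB for the mask alone.
print(dense_mask_bytes(1, 16384) / 2**20)  # 1024.0 (MiB)
```

Because the size scales with seq_len squared, halving the batch only halves the allocation, while halving the sequence length cuts it by four.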

The key point is that it's the mask that blows up, not the attention math. mx.fast.scaled_dot_product_attention already knows how to apply a causal mask cheaply: pass it mask="causal" and it never materializes a dense tensor at all. The trouble is that qwen3.py builds the full (B,1,S,S) array by hand and hands that to SDPA, so it never reaches the cheap path. (The top-level Model.__call__ fills in an all-ones mask whenever the caller doesn't supply one, so even a single unpadded sequence ends up on the expensive branch.)

This has been the behavior since the file was first added in #34.

The fix

When there's no padding to worry about, pass mask="causal" and let the fused kernel do the work. When the batch genuinely has padding, build the dense mask exactly as before — that path is left untouched.

Concretely, Qwen3Model.__call__ now looks at the 2D mask:

  • no mask, or an all-ones mask (nothing is padded) → "causal", the cheap fused path;
  • any zeros (real padding) → the original dense mask.

This is already how the rest of the package handles causal masking — lfm2.py, gemma3_text.py, siglip.py, and others all pass "causal" to the fused kernel. The only thing changing here is which mask value qwen3.py hands over; the attention call itself is the same, and there's no new dependency.
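
The branching described above can be sketched roughly as follows. This is not the code in the PR: it uses NumPy as a stand-in for mlx.core so it runs anywhere, and `choose_mask_path` is a hypothetical name. The try/except mirrors the idea that a value-dependent read may fail under tracing, in which case the dense path is the safe default (NumPy itself never raises here):

```python
import numpy as np

def choose_mask_path(attention_mask):
    """Return "causal" when nothing is padded, otherwise "dense".

    attention_mask is None or a 0/1 array of shape (batch, seq),
    where 1 = attend and 0 = padded.
    """
    if attention_mask is None:
        return "causal"
    try:
        # Value-dependent read: per the PR, this is what is not allowed
        # under mx.compile / mx.vmap tracing.
        has_padding = bool(np.any(attention_mask == 0))
    except Exception:
        # Assumed fallback: if the values cannot be read, take the
        # always-correct dense path.
        has_padding = True
    return "dense" if has_padding else "causal"

print(choose_mask_path(None))                                    # causal
print(choose_mask_path(np.ones((2, 4), dtype=np.int32)))         # causal
print(choose_mask_path(np.array([[0, 0, 1, 1], [1, 1, 1, 1]])))  # dense
```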

For the unpadded case the result is identical to the old code. The model only ever does square self-attention (query and key lengths are always equal), and for that a causal mask and a lower-triangular mask are the same thing — so the output matches the old dense path bit-for-bit (max|delta| = 0.0) on the versions I tested. It isn't an approximation.
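
The equivalence is easy to check on a small example. The sketch below assumes the lower-right alignment that the commit message refers to (query i may attend key j when j <= i + (T_kv - T_q)) and uses NumPy rather than MLX. `lower_right_causal` is a hypothetical helper:

```python
import numpy as np

def lower_right_causal(t_q: int, t_kv: int) -> np.ndarray:
    """Boolean mask (True = may attend), aligned to the lower-right corner."""
    i = np.arange(t_q)[:, None]   # query positions
    j = np.arange(t_kv)[None, :]  # key positions
    return j <= i + (t_kv - t_q)

S = 6
tril = np.tril(np.ones((S, S), dtype=bool))

# Square case: the offset is zero, so the two masks coincide.
print(np.array_equal(lower_right_causal(S, S), tril))  # True

# Non-square case (not used by this model): the conventions differ.
rect_tril = np.tril(np.ones((2, 4), dtype=bool))
print(np.array_equal(lower_right_causal(2, 4), rect_tril))  # False
```

Only the square case matters here, since query and key lengths are always equal in this model.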

Two supporting details worth calling out:

  • The all-ones check reads the mask's values, which isn't allowed under mx.compile / mx.vmap. The embedding forward isn't compiled today, but I wrapped the check in a try/except so a future compiled caller falls back to the (always correct) dense path instead of crashing. I kept the check inside the model rather than asking callers to pass None, because the real entry point, processor.encode(...) → model(**inputs), hands the model an all-ones mask, not None. Moving the decision to the caller would quietly miss the common case and bring the OOM straight back.
  • The manual-attention fallback (only used if the fused kernel ever raises) now resolves "causal" to an explicit lower-triangular mask, and zeroes any NaN rows after softmax. A fully-masked row — say a left-padding position whose only visible keys are themselves padding — would otherwise produce NaNs that leak into later layers. The fused path already handles this; now the fallback matches it.
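
To see why a fully-masked row produces NaNs, and what zeroing them looks like, here is a small NumPy sketch. It illustrates the idea only and is not the fallback code in qwen3.py. `masked_softmax` is a hypothetical helper:

```python
import numpy as np

def masked_softmax(scores: np.ndarray, additive_mask: np.ndarray) -> np.ndarray:
    """Softmax over the last axis; fully masked rows come back as zeros."""
    x = scores + additive_mask
    with np.errstate(invalid="ignore"):
        # For a row that is entirely -inf, (-inf) - (-inf) is NaN,
        # so the whole row of probabilities becomes NaN.
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        probs = e / e.sum(axis=-1, keepdims=True)
    # Replace NaN entries (only the fully masked rows) with zeros.
    return np.where(np.isnan(probs), 0.0, probs)

# Left-padded sequence of length 3: position 0 is padding.
S = 3
causal = np.where(np.tril(np.ones((S, S), dtype=bool)), 0.0, -np.inf)
padding = np.where(np.array([0, 1, 1]) == 0, -np.inf, 0.0)[None, :]

probs = masked_softmax(np.zeros((S, S)), causal + padding)
print(probs[0])  # query 0 sees only padding: all zeros instead of NaN
print(probs[1])  # query 1 sees only key 1
print(probs[2])  # query 2 splits evenly between keys 1 and 2
```

Without the final `np.where`, row 0 would be NaN and any later matrix multiply would spread it to the other positions.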

Verifying it

On mlx 0.31.2 (Apple Silicon):

  • Unpadded output is identical to the old dense mask (max|delta| = 0.0).
  • The manual fallback matches the fused path (mx.allclose, atol 1e-4) when SDPA is forced to raise.
  • A long single input at S=16384 drops peak memory from ~4970 MiB to ~2538 MiB, and the S=32768 / batch-32 case never builds the dense tensor at all.
  • Padded batches are byte-for-byte unchanged.
  • Qwen3Model wrapped in mx.compile now runs (falling back to the dense path) instead of crashing.
  • python -m pytest mlx_embeddings/tests/test_models.py::TestQwen3CausalMask → 4 passed.

Self-contained reproducer (no weights, no network)

import mlx.core as mx

B, S = 32, 32768
print("dense (B,1,S,S) f32 mask bytes =", B * S * S * 4)   # 137438953472

# BUGGY PATH: exactly what Qwen3Model.__call__ builds for an all-ones 2D mask.
am = mx.ones((B, S), dtype=mx.int32)
padding = mx.where(am[:, None, None, :] == 0, -mx.inf, 0.0).astype(mx.float32)            # (B,1,1,S)
causal = mx.where(mx.tril(mx.ones((S, S), dtype=mx.bool_)), 0.0, -mx.inf).astype(mx.float32)[None, None]  # (1,1,S,S)
try:
    mx.eval(causal + padding)                        # (B,1,S,S) -> OOM
    print("allocated dense mask (unexpected)")
except Exception as e:
    print("FAILED as expected:", e)

# FIXED PATH: the fused string mask the fix passes instead; no dense tensor is built.
q = k = v = mx.zeros((1, 4, S, 128), dtype=mx.float16)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask="causal")
mx.eval(out)
print("fused mask='causal' OK:", out.shape)

Verified output:

dense (B,1,S,S) f32 mask bytes = 137438953472
FAILED as expected: [metal::malloc] Attempting to allocate 137438953472 bytes which is greater than the maximum allowed buffer size of 86586540032 bytes.
fused mask='causal' OK: (1, 4, 32768, 128)

What this doesn't cover

The fast path kicks in only when there's no padding. A 2D mask with any zero — a genuinely padded, unequal-length batch — still builds the dense mask, so padded callers see no change. Note that this 2D mask is a 0/1 padding mask (1 = attend, 0 = padded), not an additive mask.
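
For readers who want to see the 0/1 versus additive distinction concretely, here is a small-scale NumPy version of the dense path, modeled on the reproducer above. The helper names are hypothetical and this is a sketch, not the package's `_create_causal_mask`:

```python
import numpy as np

def padding_to_additive(attention_mask: np.ndarray) -> np.ndarray:
    """0/1 (B, S) padding mask -> additive (B, 1, 1, S) mask (0 or -inf)."""
    blocked = attention_mask[:, None, None, :] == 0
    return np.where(blocked, -np.inf, 0.0).astype(np.float32)

def dense_causal_mask(attention_mask: np.ndarray) -> np.ndarray:
    """Additive (B, 1, S, S) mask combining causality and key padding."""
    _, S = attention_mask.shape
    allowed = np.tril(np.ones((S, S), dtype=bool))
    causal = np.where(allowed, 0.0, -np.inf).astype(np.float32)
    return causal[None, None] + padding_to_additive(attention_mask)

# Two sequences of unequal length, right-padded to S = 4.
am = np.array([[1, 1, 1, 1],
               [1, 1, 0, 0]], dtype=np.int32)

dense = dense_causal_mask(am)
print(dense.shape)  # (2, 1, 4, 4)
print(dense[1, 0])  # keys 2 and 3 are -inf for every query in sequence 1
```

The first row of `am` has no zeros, so a batch made only of such rows would qualify for the fast path. The second row is what forces the dense mask.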

The bidirectional encoders (gemma3_text, llama_bidirec, llama_nemotron_vl, modernbert) and qwen3_vl build the same dense mask and likely hit the same wall at long context, but they need a different fix — a causal string would change their semantics — so I've left them out of this PR. Happy to follow up on them separately.

Background

I couldn't find an existing issue or PR for this OOM (searched all issues and PRs). #34, which introduced the dense mask, advertised that Qwen3 is "not limited to 512 tokens" — this is the long-context use that change was meant to enable. #66 (open) adds a separate bidirectional encoder and leaves qwen3.py alone.

contrapuntal force-pushed the fix/qwen3-embedding-causal-mask-oom branch from b32ba59 to 6b3eff6 on June 20, 2026 23:36
contrapuntal changed the title on Jun 20, 2026: fix(qwen3): avoid dense (B,1,S,S) mask OOM via fused mask="causal" → fix(qwen3): avoid dense attention-mask OOM at long context
contrapuntal changed the title on Jun 20, 2026: fix(qwen3): avoid dense attention-mask OOM at long context → fix(qwen3): avoid out-of-memory crash on long-context embeddings
contrapuntal force-pushed the fix/qwen3-embedding-causal-mask-oom branch from 6b3eff6 to 3816f9f on June 20, 2026 23:37
contrapuntal marked this pull request as ready for review on June 20, 2026 23:38

Qwen3Model.__call__ unconditionally builds a dense (batch, 1, seq, seq)
additive attention mask on every forward. At Qwen3's native 32768-token
context a 32-item batch needs one 32*32768*32768*4 = 137438953472-byte
(128 GiB) allocation, which overflows Metal's max buffer (~80.6 GiB) and
500s POST /v1/embeddings with:

  [metal::malloc] Attempting to allocate 137438953472 bytes which is greater
  than the maximum allowed buffer size of 86586540032 bytes.

It is the mask *tensor* that overflows, not the attention scores: the fused
mx.fast.scaled_dot_product_attention causal path holds the mask in O(1)
memory. This routes the common unpadded case through mask="causal" instead
of materializing the dense tensor. Genuinely padded batches are untouched.

For square self-attention (T_q == T_kv, always true here) lower-right causal
equals lower-triangular, so the unpadded fast-path output is numerically
identical to the old dense mask -- bit-identical on the MLX/Metal versions
tested, not an approximation.

Changes:

- models/qwen3.py: Qwen3Model.__call__ branches on padding. attention_mask
  is None (direct call) or an all-ones 2D mask (no padded positions) ->
  mask="causal"; a 2D mask with real zeros still builds the dense additive
  mask exactly as before. The all-ones read is value-dependent, so it is
  wrapped in try/except: under mx.compile/mx.vmap tracing it degrades to the
  (always correct) dense path instead of crashing. The public caller
  (processor.encode -> model(**inputs)) supplies an all-ones 2D mask rather
  than None, so this in-model check -- not caller-side routing -- is what
  covers the common path.

- models/qwen3.py: Qwen3Attention.__call__'s manual-attention fallback
  resolves "causal" into an explicit tril additive mask (so it cannot crash
  on the new string contract) and raises ValueError on any other string. It
  also zeroes NaN rows after softmax: a fully-masked left-padding query
  position would otherwise leak NaN into later layers through the causal
  residual stream. The fused SDPA path already returns finite rows here.

- tests/test_models.py: TestQwen3CausalMask (4 tests) -- fallback matches the
  fast path for mask="causal" on a GQA config; fallback rejects an unknown
  string mask; _create_causal_mask is skipped for all-ones / None and still
  called for a genuinely padded mask; the forced-fallback left-padding path
  leaks no NaN.

Verified (mlx 0.31.2):

  dense (B,1,S,S) f32 mask bytes = 137438953472  -> OOM as in the traceback
  fused mask="causal"                            -> OK, no dense tensor built
  long single input S=16384 peak mem ~4970 MiB   -> ~2538 MiB
  Qwen3Model under mx.compile                     -> runs (dense fallback), no crash
  TestQwen3CausalMask                            -> 4 passed

Out of scope: the bidirectional encoders (gemma3_text, llama_bidirec,
llama_nemotron_vl, modernbert) and qwen3_vl share the dense-mask pattern but
need a different change (a causal string would alter their semantics).

contrapuntal force-pushed the fix/qwen3-embedding-causal-mask-oom branch from 3816f9f to 95daf5e on June 21, 2026 00:02