fix(qwen3): avoid out-of-memory crash on long-context embeddings#68
Open
contrapuntal wants to merge 1 commit into
Open
fix(qwen3): avoid out-of-memory crash on long-context embeddings#68contrapuntal wants to merge 1 commit into
contrapuntal wants to merge 1 commit into
Conversation
b32ba59 to
6b3eff6
Compare
6b3eff6 to
3816f9f
Compare
Qwen3Model.__call__ unconditionally builds a dense (batch, 1, seq, seq) additive attention mask on every forward. At Qwen3's native 32768-token context a 32-item batch needs one 32*32768*32768*4 = 137438953472-byte (128 GiB) allocation, which overflows Metal's max buffer (~80.6 GiB) and 500s POST /v1/embeddings with: [metal::malloc] Attempting to allocate 137438953472 bytes which is greater than the maximum allowed buffer size of 86586540032 bytes. It is the mask *tensor* that overflows, not the attention scores: the fused mx.fast.scaled_dot_product_attention causal path holds the mask in O(1) memory. This routes the common unpadded case through mask="causal" instead of materializing the dense tensor. Genuinely padded batches are untouched. For square self-attention (T_q == T_kv, always true here) lower-right causal equals lower-triangular, so the unpadded fast-path output is numerically identical to the old dense mask -- bit-identical on the MLX/Metal versions tested, not an approximation. Changes: - models/qwen3.py: Qwen3Model.__call__ branches on padding. attention_mask is None (direct call) or an all-ones 2D mask (no padded positions) -> mask="causal"; a 2D mask with real zeros still builds the dense additive mask exactly as before. The all-ones read is value-dependent, so it is wrapped in try/except: under mx.compile/mx.vmap tracing it degrades to the (always correct) dense path instead of crashing. The public caller (processor.encode -> model(**inputs)) supplies an all-ones 2D mask rather than None, so this in-model check -- not caller-side routing -- is what covers the common path. - models/qwen3.py: Qwen3Attention.__call__'s manual-attention fallback resolves "causal" into an explicit tril additive mask (so it cannot crash on the new string contract) and raises ValueError on any other string. It also zeroes NaN rows after softmax: a fully-masked left-padding query position would otherwise leak NaN into later layers through the causal residual stream. The fused SDPA path already returns finite rows here. - tests/test_models.py: TestQwen3CausalMask (4 tests) -- fallback matches the fast path for mask="causal" on a GQA config; fallback rejects an unknown string mask; _create_causal_mask is skipped for all-ones / None and still called for a genuinely padded mask; the forced-fallback left-padding path leaks no NaN. Verified (mlx 0.31.2): dense (B,1,S,S) f32 mask bytes = 137438953472 -> OOM as in the traceback fused mask="causal" -> OK, no dense tensor built long single input S=16384 peak mem ~4970 MiB -> ~2538 MiB Qwen3Model under mx.compile -> runs (dense fallback), no crash TestQwen3CausalMask -> 4 passed Out of scope: the bidirectional encoders (gemma3_text, llama_bidirec, llama_nemotron_vl, modernbert) and qwen3_vl share the dense-mask pattern but need a different change (a causal string would alter their semantics).
3816f9f to
95daf5e
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What's wrong
Embedding a long batch with a Qwen3 model can crash before any real work happens.
Qwen3Model.__call__builds a dense(batch, 1, seq, seq)attention mask on every forward pass, and that tensor grows with the square of the sequence length. At Qwen3's full 32k context, a batch of 32 needs a single ~128 GiB allocation — more than Metal will hand out (~80.6 GiB on this machine) — soPOST /v1/embeddingsfails before the model computes anything:The key point is that it's the mask that blows up, not the attention math.
mx.fast.scaled_dot_product_attentionalready knows how to apply a causal mask cheaply: pass itmask="causal"and it never materializes a dense tensor at all. The trouble is thatqwen3.pybuilds the full(B,1,S,S)array by hand and hands that to SDPA, so it never reaches the cheap path. (The top-levelModel.__call__fills in an all-ones mask whenever the caller doesn't supply one, so even a single unpadded sequence ends up on the expensive branch.)This has been the behavior since the file was first added in #34.
The fix
When there's no padding to worry about, pass
mask="causal"and let the fused kernel do the work. When the batch genuinely has padding, build the dense mask exactly as before — that path is left untouched.Concretely,
Qwen3Model.__call__now looks at the 2D mask:"causal", the cheap fused path;This is already how the rest of the package handles causal masking —
lfm2.py,gemma3_text.py,siglip.py, and others all pass"causal"to the fused kernel. The only thing changing here is which mask valueqwen3.pyhands over; the attention call itself is the same, and there's no new dependency.For the unpadded case the result is identical to the old code. The model only ever does square self-attention (query and key lengths are always equal), and for that a causal mask and a lower-triangular mask are the same thing — so the output matches the old dense path bit-for-bit (
max|delta| = 0.0) on the versions I tested. It isn't an approximation.Two supporting details worth calling out:
mx.compile/mx.vmap. The embedding forward isn't compiled today, but I wrapped the check in atry/exceptso a future compiled caller falls back to the (always correct) dense path instead of crashing. I kept the check inside the model rather than asking callers to passNone, because the real entry point —processor.encode(...)→model(**inputs)— hands the model an all-ones mask, notNone. Moving the decision to the caller would quietly miss the common case and bring the OOM straight back."causal"to an explicit lower-triangular mask, and zeroes any NaN rows after softmax. A fully-masked row — say a left-padding position whose only visible keys are themselves padding — would otherwise produce NaNs that leak into later layers. The fused path already handles this; now the fallback matches it.Verifying it
On mlx 0.31.2 (Apple Silicon):
max|delta| = 0.0).mx.allclose, atol 1e-4) when SDPA is forced to raise.Qwen3Modelwrapped inmx.compilenow runs (falling back to the dense path) instead of crashing.python -m pytest mlx_embeddings/tests/test_models.py::TestQwen3CausalMask→ 4 passed.Self-contained reproducer (no weights, no network)
Verified output:
What this doesn't cover
The fast path kicks in only when there's no padding. A 2D mask with any zero — a genuinely padded, unequal-length batch — still builds the dense mask, so padded callers see no change. Note that this 2D mask is a 0/1 padding mask (1 = attend, 0 = padded), not an additive mask.
The bidirectional encoders (
gemma3_text,llama_bidirec,llama_nemotron_vl,modernbert) andqwen3_vlbuild the same dense mask and likely hit the same wall at long context, but they need a different fix — a causal string would change their semantics — so I've left them out of this PR. Happy to follow up on them separately.Background
I couldn't find an existing issue or PR for this OOM (searched all issues and PRs). #34, which introduced the dense mask, advertised that Qwen3 is "not limited to 512 tokens" — this is the long-context use that change was meant to enable. #66 (open) adds a separate bidirectional encoder and leaves
qwen3.pyalone.