From 95daf5e58c49ff0ab4ccda2de420b87d62e1df40 Mon Sep 17 00:00:00 2001 From: Aaron Yang <16892776+contrapuntal@users.noreply.github.com> Date: Sat, 20 Jun 2026 15:11:46 -0700 Subject: [PATCH] fix(qwen3): avoid out-of-memory crash on long-context embeddings Qwen3Model.__call__ unconditionally builds a dense (batch, 1, seq, seq) additive attention mask on every forward. At Qwen3's native 32768-token context a 32-item batch needs one 32*32768*32768*4 = 137438953472-byte (128 GiB) allocation, which overflows Metal's max buffer (~80.6 GiB) and 500s POST /v1/embeddings with: [metal::malloc] Attempting to allocate 137438953472 bytes which is greater than the maximum allowed buffer size of 86586540032 bytes. It is the mask *tensor* that overflows, not the attention scores: the fused mx.fast.scaled_dot_product_attention causal path holds the mask in O(1) memory. This routes the common unpadded case through mask="causal" instead of materializing the dense tensor. Genuinely padded batches are untouched. For square self-attention (T_q == T_kv, always true here) lower-right causal equals lower-triangular, so the unpadded fast-path output is numerically identical to the old dense mask -- bit-identical on the MLX/Metal versions tested, not an approximation. Changes: - models/qwen3.py: Qwen3Model.__call__ branches on padding. attention_mask is None (direct call) or an all-ones 2D mask (no padded positions) -> mask="causal"; a 2D mask with real zeros still builds the dense additive mask exactly as before. The all-ones read is value-dependent, so it is wrapped in try/except: under mx.compile/mx.vmap tracing it degrades to the (always correct) dense path instead of crashing. The public caller (processor.encode -> model(**inputs)) supplies an all-ones 2D mask rather than None, so this in-model check -- not caller-side routing -- is what covers the common path. - models/qwen3.py: Qwen3Attention.__call__'s manual-attention fallback resolves "causal" into an explicit tril additive mask (so it cannot crash on the new string contract) and raises ValueError on any other string. It also zeroes NaN rows after softmax: a fully-masked left-padding query position would otherwise leak NaN into later layers through the causal residual stream. The fused SDPA path already returns finite rows here. - tests/test_models.py: TestQwen3CausalMask (4 tests) -- fallback matches the fast path for mask="causal" on a GQA config; fallback rejects an unknown string mask; _create_causal_mask is skipped for all-ones / None and still called for a genuinely padded mask; the forced-fallback left-padding path leaks no NaN. Verified (mlx 0.31.2): dense (B,1,S,S) f32 mask bytes = 137438953472 -> OOM as in the traceback fused mask="causal" -> OK, no dense tensor built long single input S=16384 peak mem ~4970 MiB -> ~2538 MiB Qwen3Model under mx.compile -> runs (dense fallback), no crash TestQwen3CausalMask -> 4 passed Out of scope: the bidirectional encoders (gemma3_text, llama_bidirec, llama_nemotron_vl, modernbert) and qwen3_vl share the dense-mask pattern but need a different change (a causal string would alter their semantics). --- mlx_embeddings/models/qwen3.py | 81 +++++++++++++---- mlx_embeddings/tests/test_models.py | 135 ++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+), 19 deletions(-) diff --git a/mlx_embeddings/models/qwen3.py b/mlx_embeddings/models/qwen3.py index e9cf61dab9..63204d7993 100644 --- a/mlx_embeddings/models/qwen3.py +++ b/mlx_embeddings/models/qwen3.py @@ -207,7 +207,9 @@ def __call__( Args: hidden_states: Input hidden states, shape (batch_size, seq_len, hidden_size) - attention_mask: Attention mask, shape (batch_size, 1, seq_len, seq_len) + attention_mask: Either a dense additive mask of shape + (batch_size, 1, seq_len, seq_len) or the string ``"causal"`` + (passed through to ``mx.fast.scaled_dot_product_attention``). Returns: Attention output, shape (batch_size, seq_len, hidden_size) @@ -257,10 +259,32 @@ def __call__( attn_weights = (query_states @ key_states.transpose(0, 1, 3, 2)) * scale - if attention_mask is not None: - attn_weights = attn_weights + attention_mask + mask = attention_mask + if isinstance(mask, str): + # Fast SDPA accepts mask="causal"; the manual path needs an + # explicit additive mask. For square self-attention + # (T_q == T_kv) lower-right causal equals lower-triangular. + if mask == "causal": + t_q = query_states.shape[-2] + tri = mx.tril(mx.ones((t_q, t_q), dtype=mx.bool_)) + mask = mx.where(tri, 0.0, -mx.inf).astype(attn_weights.dtype) + else: + # Bare raise: implicit __context__ chaining keeps the original + # fast-SDPA error visible ("During handling of...") without + # claiming it *caused* this unsupported-mask error. + raise ValueError(f"Unsupported string attention mask: {mask!r}") + + if mask is not None: + attn_weights = attn_weights + mask attn_weights = mx.softmax(attn_weights, axis=-1) + # Fully-masked rows (e.g. a left-padding query position whose only + # causal keys are themselves padded) softmax to NaN. The fused SDPA + # path returns finite rows for this case, so match it: zero the NaNs + # to stop them propagating into later layers through the causal + # residual stream. Such rows are padding positions discarded by + # last_token_pool, so zeroing them does not affect any real output. + attn_weights = mx.where(mx.isnan(attn_weights), 0.0, attn_weights) attn_output = attn_weights @ value_states # Reshape back to (batch_size, seq_len, hidden_size) @@ -392,7 +416,9 @@ def __call__( Args: input_ids: Input token IDs, shape (batch_size, seq_len) - attention_mask: Attention mask, shape (batch_size, seq_len) or (batch_size, 1, seq_len, seq_len) + attention_mask: Attention mask, shape (batch_size, seq_len) or + (batch_size, 1, seq_len, seq_len). The 2D form must use 0/1 + integer values (1 = attended, 0 = padded). Returns: Hidden states, shape (batch_size, seq_len, hidden_size) @@ -402,23 +428,40 @@ def __call__( # Get token embeddings hidden_states = self.embed_tokens(input_ids) - # Create or process attention mask + # Create or process attention mask. + # + # When there are no padded positions (single input, or an equal-length + # batch) use MLX's fused "causal" mask instead of materializing a dense + # (batch, 1, seq, seq) additive mask. That dense mask is O(batch * seq^2) + # and overflows Metal's max buffer at long context (e.g. batch=32, + # seq=32768 -> 128 GiB). The fused kernel uses O(1) mask memory and is + # numerically identical for square self-attention (T_q == T_kv) -- and + # bit-identical on the MLX/Metal versions tested. if attention_mask is None: - # Create causal mask for autoregressive generation - attention_mask = self._create_causal_mask(seq_length, hidden_states.dtype) + # Direct-call path: the top-level Model.__call__ converts None into + # an all-ones 2D mask before reaching here, so this branch is hit + # only when Qwen3Model is called directly. + attention_mask = "causal" elif attention_mask.ndim == 2: - # Convert padding mask to additive mask and combine with causal mask - # attention_mask shape: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) - padding_mask = attention_mask[:, None, None, :] - padding_mask = mx.where(padding_mask == 0, -mx.inf, 0.0).astype( - hidden_states.dtype - ) - - # Create causal mask - causal_mask = self._create_causal_mask(seq_length, hidden_states.dtype) - - # Combine masks (broadcast padding mask to match causal mask shape) - attention_mask = causal_mask + padding_mask + # The 2D mask is a 0/1 padding mask (1 = attend, 0 = padded), not an + # additive mask. `bool(...)` forces a host eval, which is illegal + # under mx.compile / mx.vmap; the embedding forward is not compiled, + # but guard the read so a future compiled caller degrades to the + # (always correct) dense path instead of crashing on it. + try: + unpadded = bool((attention_mask == 1).all()) + except ValueError: + unpadded = False # tracing (compile/vmap): take the dense path + if unpadded: + attention_mask = "causal" + else: + # Padded batch: build the dense additive mask (O(batch * seq^2)). + padding_mask = attention_mask[:, None, None, :] + padding_mask = mx.where(padding_mask == 0, -mx.inf, 0.0).astype( + hidden_states.dtype + ) + causal_mask = self._create_causal_mask(seq_length, hidden_states.dtype) + attention_mask = causal_mask + padding_mask # Apply transformer layers for layer in self.layers: diff --git a/mlx_embeddings/tests/test_models.py b/mlx_embeddings/tests/test_models.py index e94bd9857d..b9edfb659a 100644 --- a/mlx_embeddings/tests/test_models.py +++ b/mlx_embeddings/tests/test_models.py @@ -725,5 +725,140 @@ def prepare_model_inputs(self, inputs, **kwargs): self.assertTrue(mx.all(scores <= 1.0).item()) +class TestQwen3CausalMask(unittest.TestCase): + def _small_config(self): + from mlx_embeddings.models import qwen3 + + return qwen3.ModelArgs( + hidden_size=64, + num_hidden_layers=2, + intermediate_size=128, + num_attention_heads=4, + num_key_value_heads=2, # GQA: exercises mx.repeat path + head_dim=16, + vocab_size=100, + rms_norm_eps=1e-6, + ) + + def test_attention_fallback_handles_causal_string(self): + from mlx_embeddings.models import qwen3 + + mx.random.seed(0) + config = self._small_config() + attn = qwen3.Qwen3Attention(config) + mx.eval(attn.parameters()) + + h = mx.random.normal((2, 5, config.hidden_size)) + out_fast = attn(h, attention_mask="causal") + mx.eval(out_fast) + + # Force the manual fallback by making fast SDPA raise. + with patch( + "mlx.core.fast.scaled_dot_product_attention", + side_effect=RuntimeError("forced fallback"), + ) as mock_sdpa: + out_fallback = attn(h, attention_mask="causal") + mx.eval(out_fallback) + mock_sdpa.assert_called() + + self.assertTrue( + mx.allclose(out_fast, out_fallback, atol=1e-4).item(), + "manual fallback must match fast SDPA for a causal mask", + ) + + def test_fallback_left_padding_no_nan(self): + # Left padding makes the leading causal query rows fully masked; in the + # manual fallback those rows softmax to NaN. Verify the fallback zeroes + # them so NaN does not leak into the pooled embedding (regression for the + # forced-fallback left-padding path the OOM fix touches). + from mlx_embeddings.models import qwen3 + + mx.random.seed(0) + config = self._small_config() + model = qwen3.Model(config) + mx.eval(model.parameters()) + + ids = mx.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + am_left = mx.array([[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]], dtype=mx.int32) + + with patch( + "mlx.core.fast.scaled_dot_product_attention", + side_effect=RuntimeError("forced fallback"), + ) as mock_sdpa: + out = model(ids, am_left).text_embeds + mx.eval(out) + mock_sdpa.assert_called() + + self.assertFalse( + bool(mx.isnan(out).any().item()), + "manual fallback must not leak NaN from fully-masked left-padding rows", + ) + self.assertEqual(out.shape, (2, config.hidden_size)) + + def test_attention_fallback_rejects_unknown_string_mask(self): + from mlx_embeddings.models import qwen3 + + mx.random.seed(0) + config = self._small_config() + attn = qwen3.Qwen3Attention(config) + mx.eval(attn.parameters()) + + h = mx.random.normal((2, 5, config.hidden_size)) + with patch( + "mlx.core.fast.scaled_dot_product_attention", + side_effect=RuntimeError("forced fallback"), + ): + with self.assertRaises(ValueError): + mx.eval(attn(h, attention_mask="full")) + + def test_model_skips_dense_mask_when_unpadded(self): + from mlx_embeddings.models import qwen3 + + mx.random.seed(0) + config = self._small_config() + model = qwen3.Model(config) + mx.eval(model.parameters()) + + ids = mx.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + + # Spy on the dense-mask builder. Assigning a plain function to the class + # makes it a bound method, so `self` is passed normally. + calls = [] + original = qwen3.Qwen3Model._create_causal_mask + + def spy(self, seq_length, dtype): + calls.append(seq_length) + return original(self, seq_length, dtype) + + qwen3.Qwen3Model._create_causal_mask = spy + try: + # All-ones mask (no padding) -> fused "causal", no dense mask built. + am_ones = mx.ones((2, 5), dtype=mx.int32) + out_ones = model(ids, am_ones).text_embeds + mx.eval(out_ones) + self.assertEqual(calls, [], "unpadded input must not build a dense mask") + + # Padded mask -> dense path unchanged, dense mask IS built. + calls.clear() + am_pad = mx.array([[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]], dtype=mx.int32) + out_pad = model(ids, am_pad).text_embeds + mx.eval(out_pad) + self.assertGreater( + len(calls), 0, "padded input must still build a dense mask" + ) + + # Core forward with attention_mask=None -> fused "causal". + calls.clear() + hidden = model.model(ids, attention_mask=None) + mx.eval(hidden) + self.assertEqual(calls, [], "None mask must not build a dense mask") + finally: + qwen3.Qwen3Model._create_causal_mask = original + + self.assertEqual(out_ones.shape, (2, config.hidden_size)) + self.assertEqual(out_pad.shape, (2, config.hidden_size)) + self.assertEqual(hidden.shape, (2, 5, config.hidden_size)) + + if __name__ == "__main__": unittest.main()