diff --git a/mlx_embeddings/models/qwen3.py b/mlx_embeddings/models/qwen3.py index e9cf61dab9..63204d7993 100644 --- a/mlx_embeddings/models/qwen3.py +++ b/mlx_embeddings/models/qwen3.py @@ -207,7 +207,9 @@ def __call__( Args: hidden_states: Input hidden states, shape (batch_size, seq_len, hidden_size) - attention_mask: Attention mask, shape (batch_size, 1, seq_len, seq_len) + attention_mask: Either a dense additive mask of shape + (batch_size, 1, seq_len, seq_len) or the string ``"causal"`` + (passed through to ``mx.fast.scaled_dot_product_attention``). Returns: Attention output, shape (batch_size, seq_len, hidden_size) @@ -257,10 +259,32 @@ def __call__( attn_weights = (query_states @ key_states.transpose(0, 1, 3, 2)) * scale - if attention_mask is not None: - attn_weights = attn_weights + attention_mask + mask = attention_mask + if isinstance(mask, str): + # Fast SDPA accepts mask="causal"; the manual path needs an + # explicit additive mask. For square self-attention + # (T_q == T_kv) lower-right causal equals lower-triangular. + if mask == "causal": + t_q = query_states.shape[-2] + tri = mx.tril(mx.ones((t_q, t_q), dtype=mx.bool_)) + mask = mx.where(tri, 0.0, -mx.inf).astype(attn_weights.dtype) + else: + # Bare raise: implicit __context__ chaining keeps the original + # fast-SDPA error visible ("During handling of...") without + # claiming it *caused* this unsupported-mask error. + raise ValueError(f"Unsupported string attention mask: {mask!r}") + + if mask is not None: + attn_weights = attn_weights + mask attn_weights = mx.softmax(attn_weights, axis=-1) + # Fully-masked rows (e.g. a left-padding query position whose only + # causal keys are themselves padded) softmax to NaN. The fused SDPA + # path returns finite rows for this case, so match it: zero the NaNs + # to stop them propagating into later layers through the causal + # residual stream. Such rows are padding positions discarded by + # last_token_pool, so zeroing them does not affect any real output. + attn_weights = mx.where(mx.isnan(attn_weights), 0.0, attn_weights) attn_output = attn_weights @ value_states # Reshape back to (batch_size, seq_len, hidden_size) @@ -392,7 +416,9 @@ def __call__( Args: input_ids: Input token IDs, shape (batch_size, seq_len) - attention_mask: Attention mask, shape (batch_size, seq_len) or (batch_size, 1, seq_len, seq_len) + attention_mask: Attention mask, shape (batch_size, seq_len) or + (batch_size, 1, seq_len, seq_len). The 2D form must use 0/1 + integer values (1 = attended, 0 = padded). Returns: Hidden states, shape (batch_size, seq_len, hidden_size) @@ -402,23 +428,40 @@ def __call__( # Get token embeddings hidden_states = self.embed_tokens(input_ids) - # Create or process attention mask + # Create or process attention mask. + # + # When there are no padded positions (single input, or an equal-length + # batch) use MLX's fused "causal" mask instead of materializing a dense + # (batch, 1, seq, seq) additive mask. That dense mask is O(batch * seq^2) + # and overflows Metal's max buffer at long context (e.g. batch=32, + # seq=32768 -> 128 GiB). The fused kernel uses O(1) mask memory and is + # numerically identical for square self-attention (T_q == T_kv) -- and + # bit-identical on the MLX/Metal versions tested. if attention_mask is None: - # Create causal mask for autoregressive generation - attention_mask = self._create_causal_mask(seq_length, hidden_states.dtype) + # Direct-call path: the top-level Model.__call__ converts None into + # an all-ones 2D mask before reaching here, so this branch is hit + # only when Qwen3Model is called directly. + attention_mask = "causal" elif attention_mask.ndim == 2: - # Convert padding mask to additive mask and combine with causal mask - # attention_mask shape: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) - padding_mask = attention_mask[:, None, None, :] - padding_mask = mx.where(padding_mask == 0, -mx.inf, 0.0).astype( - hidden_states.dtype - ) - - # Create causal mask - causal_mask = self._create_causal_mask(seq_length, hidden_states.dtype) - - # Combine masks (broadcast padding mask to match causal mask shape) - attention_mask = causal_mask + padding_mask + # The 2D mask is a 0/1 padding mask (1 = attend, 0 = padded), not an + # additive mask. `bool(...)` forces a host eval, which is illegal + # under mx.compile / mx.vmap; the embedding forward is not compiled, + # but guard the read so a future compiled caller degrades to the + # (always correct) dense path instead of crashing on it. + try: + unpadded = bool((attention_mask == 1).all()) + except ValueError: + unpadded = False # tracing (compile/vmap): take the dense path + if unpadded: + attention_mask = "causal" + else: + # Padded batch: build the dense additive mask (O(batch * seq^2)). + padding_mask = attention_mask[:, None, None, :] + padding_mask = mx.where(padding_mask == 0, -mx.inf, 0.0).astype( + hidden_states.dtype + ) + causal_mask = self._create_causal_mask(seq_length, hidden_states.dtype) + attention_mask = causal_mask + padding_mask # Apply transformer layers for layer in self.layers: diff --git a/mlx_embeddings/tests/test_models.py b/mlx_embeddings/tests/test_models.py index e94bd9857d..b9edfb659a 100644 --- a/mlx_embeddings/tests/test_models.py +++ b/mlx_embeddings/tests/test_models.py @@ -725,5 +725,140 @@ def prepare_model_inputs(self, inputs, **kwargs): self.assertTrue(mx.all(scores <= 1.0).item()) +class TestQwen3CausalMask(unittest.TestCase): + def _small_config(self): + from mlx_embeddings.models import qwen3 + + return qwen3.ModelArgs( + hidden_size=64, + num_hidden_layers=2, + intermediate_size=128, + num_attention_heads=4, + num_key_value_heads=2, # GQA: exercises mx.repeat path + head_dim=16, + vocab_size=100, + rms_norm_eps=1e-6, + ) + + def test_attention_fallback_handles_causal_string(self): + from mlx_embeddings.models import qwen3 + + mx.random.seed(0) + config = self._small_config() + attn = qwen3.Qwen3Attention(config) + mx.eval(attn.parameters()) + + h = mx.random.normal((2, 5, config.hidden_size)) + out_fast = attn(h, attention_mask="causal") + mx.eval(out_fast) + + # Force the manual fallback by making fast SDPA raise. + with patch( + "mlx.core.fast.scaled_dot_product_attention", + side_effect=RuntimeError("forced fallback"), + ) as mock_sdpa: + out_fallback = attn(h, attention_mask="causal") + mx.eval(out_fallback) + mock_sdpa.assert_called() + + self.assertTrue( + mx.allclose(out_fast, out_fallback, atol=1e-4).item(), + "manual fallback must match fast SDPA for a causal mask", + ) + + def test_fallback_left_padding_no_nan(self): + # Left padding makes the leading causal query rows fully masked; in the + # manual fallback those rows softmax to NaN. Verify the fallback zeroes + # them so NaN does not leak into the pooled embedding (regression for the + # forced-fallback left-padding path the OOM fix touches). + from mlx_embeddings.models import qwen3 + + mx.random.seed(0) + config = self._small_config() + model = qwen3.Model(config) + mx.eval(model.parameters()) + + ids = mx.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + am_left = mx.array([[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]], dtype=mx.int32) + + with patch( + "mlx.core.fast.scaled_dot_product_attention", + side_effect=RuntimeError("forced fallback"), + ) as mock_sdpa: + out = model(ids, am_left).text_embeds + mx.eval(out) + mock_sdpa.assert_called() + + self.assertFalse( + bool(mx.isnan(out).any().item()), + "manual fallback must not leak NaN from fully-masked left-padding rows", + ) + self.assertEqual(out.shape, (2, config.hidden_size)) + + def test_attention_fallback_rejects_unknown_string_mask(self): + from mlx_embeddings.models import qwen3 + + mx.random.seed(0) + config = self._small_config() + attn = qwen3.Qwen3Attention(config) + mx.eval(attn.parameters()) + + h = mx.random.normal((2, 5, config.hidden_size)) + with patch( + "mlx.core.fast.scaled_dot_product_attention", + side_effect=RuntimeError("forced fallback"), + ): + with self.assertRaises(ValueError): + mx.eval(attn(h, attention_mask="full")) + + def test_model_skips_dense_mask_when_unpadded(self): + from mlx_embeddings.models import qwen3 + + mx.random.seed(0) + config = self._small_config() + model = qwen3.Model(config) + mx.eval(model.parameters()) + + ids = mx.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + + # Spy on the dense-mask builder. Assigning a plain function to the class + # makes it a bound method, so `self` is passed normally. + calls = [] + original = qwen3.Qwen3Model._create_causal_mask + + def spy(self, seq_length, dtype): + calls.append(seq_length) + return original(self, seq_length, dtype) + + qwen3.Qwen3Model._create_causal_mask = spy + try: + # All-ones mask (no padding) -> fused "causal", no dense mask built. + am_ones = mx.ones((2, 5), dtype=mx.int32) + out_ones = model(ids, am_ones).text_embeds + mx.eval(out_ones) + self.assertEqual(calls, [], "unpadded input must not build a dense mask") + + # Padded mask -> dense path unchanged, dense mask IS built. + calls.clear() + am_pad = mx.array([[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]], dtype=mx.int32) + out_pad = model(ids, am_pad).text_embeds + mx.eval(out_pad) + self.assertGreater( + len(calls), 0, "padded input must still build a dense mask" + ) + + # Core forward with attention_mask=None -> fused "causal". + calls.clear() + hidden = model.model(ids, attention_mask=None) + mx.eval(hidden) + self.assertEqual(calls, [], "None mask must not build a dense mask") + finally: + qwen3.Qwen3Model._create_causal_mask = original + + self.assertEqual(out_ones.shape, (2, config.hidden_size)) + self.assertEqual(out_pad.shape, (2, config.hidden_size)) + self.assertEqual(hidden.shape, (2, 5, config.hidden_size)) + + if __name__ == "__main__": unittest.main()