Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 62 additions & 19 deletions mlx_embeddings/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def __call__(

Args:
hidden_states: Input hidden states, shape (batch_size, seq_len, hidden_size)
attention_mask: Attention mask, shape (batch_size, 1, seq_len, seq_len)
attention_mask: Either a dense additive mask of shape
(batch_size, 1, seq_len, seq_len) or the string ``"causal"``
(passed through to ``mx.fast.scaled_dot_product_attention``).

Returns:
Attention output, shape (batch_size, seq_len, hidden_size)
Expand Down Expand Up @@ -257,10 +259,32 @@ def __call__(

attn_weights = (query_states @ key_states.transpose(0, 1, 3, 2)) * scale

if attention_mask is not None:
attn_weights = attn_weights + attention_mask
mask = attention_mask
if isinstance(mask, str):
# Fast SDPA accepts mask="causal"; the manual path needs an
# explicit additive mask. For square self-attention
# (T_q == T_kv) lower-right causal equals lower-triangular.
if mask == "causal":
t_q = query_states.shape[-2]
tri = mx.tril(mx.ones((t_q, t_q), dtype=mx.bool_))
mask = mx.where(tri, 0.0, -mx.inf).astype(attn_weights.dtype)
else:
# Plain raise (no `from`): implicit __context__ chaining keeps the original
# fast-SDPA error visible ("During handling of...") without
# claiming it *caused* this unsupported-mask error.
raise ValueError(f"Unsupported string attention mask: {mask!r}")

if mask is not None:
attn_weights = attn_weights + mask

attn_weights = mx.softmax(attn_weights, axis=-1)
# Fully-masked rows (e.g. a left-padding query position whose only
# causal keys are themselves padded) softmax to NaN. The fused SDPA
# path returns finite rows for this case, so match it: zero the NaNs
# to stop them propagating into later layers through the causal
# residual stream. Such rows are padding positions discarded by
# last_token_pool, so zeroing them does not affect any real output.
attn_weights = mx.where(mx.isnan(attn_weights), 0.0, attn_weights)
attn_output = attn_weights @ value_states

# Reshape back to (batch_size, seq_len, hidden_size)
Expand Down Expand Up @@ -392,7 +416,9 @@ def __call__(

Args:
input_ids: Input token IDs, shape (batch_size, seq_len)
attention_mask: Attention mask, shape (batch_size, seq_len) or (batch_size, 1, seq_len, seq_len)
attention_mask: Attention mask, shape (batch_size, seq_len) or
(batch_size, 1, seq_len, seq_len). The 2D form must use 0/1
integer values (1 = attended, 0 = padded).

Returns:
Hidden states, shape (batch_size, seq_len, hidden_size)
Expand All @@ -402,23 +428,40 @@ def __call__(
# Get token embeddings
hidden_states = self.embed_tokens(input_ids)

# Create or process attention mask
# Create or process attention mask.
#
# When there are no padded positions (single input, or an equal-length
# batch) use MLX's fused "causal" mask instead of materializing a dense
# (batch, 1, seq, seq) additive mask. That dense mask is O(batch * seq^2)
# and overflows Metal's max buffer at long context (e.g. batch=32,
# seq=32768 -> 128 GiB). The fused kernel uses O(1) mask memory and is
# numerically identical for square self-attention (T_q == T_kv) -- and
# bit-identical on the MLX/Metal versions tested.
if attention_mask is None:
# Create causal mask for autoregressive generation
attention_mask = self._create_causal_mask(seq_length, hidden_states.dtype)
# Direct-call path: the top-level Model.__call__ converts None into
# an all-ones 2D mask before reaching here, so this branch is hit
# only when Qwen3Model is called directly.
attention_mask = "causal"
elif attention_mask.ndim == 2:
# Convert padding mask to additive mask and combine with causal mask
# attention_mask shape: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
padding_mask = attention_mask[:, None, None, :]
padding_mask = mx.where(padding_mask == 0, -mx.inf, 0.0).astype(
hidden_states.dtype
)

# Create causal mask
causal_mask = self._create_causal_mask(seq_length, hidden_states.dtype)

# Combine masks (broadcast padding mask to match causal mask shape)
attention_mask = causal_mask + padding_mask
# The 2D mask is a 0/1 padding mask (1 = attend, 0 = padded), not an
# additive mask. `bool(...)` forces a host eval, which is illegal
# under mx.compile / mx.vmap; the embedding forward is not compiled,
# but guard the read so a future compiled caller degrades to the
# (always correct) dense path instead of crashing on it.
try:
unpadded = bool((attention_mask == 1).all())
except ValueError:
unpadded = False # tracing (compile/vmap): take the dense path
if unpadded:
attention_mask = "causal"
else:
# Padded batch: build the dense additive mask (O(batch * seq^2)).
padding_mask = attention_mask[:, None, None, :]
padding_mask = mx.where(padding_mask == 0, -mx.inf, 0.0).astype(
hidden_states.dtype
)
causal_mask = self._create_causal_mask(seq_length, hidden_states.dtype)
attention_mask = causal_mask + padding_mask

# Apply transformer layers
for layer in self.layers:
Expand Down
135 changes: 135 additions & 0 deletions mlx_embeddings/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,5 +725,140 @@ def prepare_model_inputs(self, inputs, **kwargs):
self.assertTrue(mx.all(scores <= 1.0).item())


class TestQwen3CausalMask(unittest.TestCase):
    """Tests for Qwen3's ``"causal"`` string mask and its manual fallback."""

    # Dotted path of the fused attention kernel that the fallback tests
    # replace with a raising mock.
    _SDPA_TARGET = "mlx.core.fast.scaled_dot_product_attention"

    def _small_config(self):
        """Return a small ``qwen3.ModelArgs`` used by every test here."""
        from mlx_embeddings.models import qwen3

        tiny = dict(
            hidden_size=64,
            num_hidden_layers=2,
            intermediate_size=128,
            num_attention_heads=4,
            num_key_value_heads=2,  # GQA: exercises mx.repeat path
            head_dim=16,
            vocab_size=100,
            rms_norm_eps=1e-6,
        )
        return qwen3.ModelArgs(**tiny)

    def _build(self, component):
        """Seed the RNG, then construct and evaluate ``qwen3.<component>``.

        Args:
            component: Attribute name on the ``qwen3`` module to instantiate
                with the small config (e.g. ``"Model"``).

        Returns:
            A ``(config, module)`` tuple.
        """
        from mlx_embeddings.models import qwen3

        mx.random.seed(0)
        config = self._small_config()
        module = getattr(qwen3, component)(config)
        mx.eval(module.parameters())
        return config, module

    def _forced_fallback(self):
        """Return a patcher that makes the fused SDPA call raise."""
        return patch(
            self._SDPA_TARGET,
            side_effect=RuntimeError("forced fallback"),
        )

    def test_attention_fallback_handles_causal_string(self):
        """Output with SDPA forced to raise must match the unpatched output."""
        config, attention = self._build("Qwen3Attention")
        hidden = mx.random.normal((2, 5, config.hidden_size))

        fused = attention(hidden, attention_mask="causal")
        mx.eval(fused)

        # Force the manual fallback by making fast SDPA raise.
        with self._forced_fallback() as sdpa_mock:
            manual = attention(hidden, attention_mask="causal")
            mx.eval(manual)
            sdpa_mock.assert_called()

        outputs_agree = mx.allclose(fused, manual, atol=1e-4).item()
        self.assertTrue(
            outputs_agree,
            "manual fallback must match fast SDPA for a causal mask",
        )

    def test_fallback_left_padding_no_nan(self):
        """Left-padded batches must not produce NaN embeddings in the fallback.

        Left padding makes the leading causal query rows fully masked; in the
        manual fallback those rows softmax to NaN. This checks that the
        fallback zeroes them so NaN does not leak into the pooled embedding
        (regression for the forced-fallback left-padding path the OOM fix
        touches).
        """
        config, model = self._build("Model")

        token_ids = mx.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
        left_padded = mx.array(
            [[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]], dtype=mx.int32
        )

        with self._forced_fallback() as sdpa_mock:
            pooled = model(token_ids, left_padded).text_embeds
            mx.eval(pooled)
            sdpa_mock.assert_called()

        has_nan = bool(mx.isnan(pooled).any().item())
        self.assertFalse(
            has_nan,
            "manual fallback must not leak NaN from fully-masked left-padding rows",
        )
        self.assertEqual(pooled.shape, (2, config.hidden_size))

    def test_attention_fallback_rejects_unknown_string_mask(self):
        """A string mask other than ``"causal"`` must raise ``ValueError``."""
        config, attention = self._build("Qwen3Attention")
        hidden = mx.random.normal((2, 5, config.hidden_size))

        with self._forced_fallback(), self.assertRaises(ValueError):
            mx.eval(attention(hidden, attention_mask="full"))

    def test_model_skips_dense_mask_when_unpadded(self):
        """The dense mask builder runs only for padded input."""
        from mlx_embeddings.models import qwen3

        config, model = self._build("Model")
        token_ids = mx.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])

        # Record every call to the dense-mask builder. A plain function
        # assigned on the class is looked up as a bound method, so the
        # instance still arrives as the first positional argument.
        seen_lengths = []
        real_builder = qwen3.Qwen3Model._create_causal_mask

        def recording_builder(instance, seq_length, dtype):
            seen_lengths.append(seq_length)
            return real_builder(instance, seq_length, dtype)

        qwen3.Qwen3Model._create_causal_mask = recording_builder
        try:
            # All-ones mask (no padding) -> fused "causal", no dense mask built.
            all_ones = mx.ones((2, 5), dtype=mx.int32)
            embeds_unpadded = model(token_ids, all_ones).text_embeds
            mx.eval(embeds_unpadded)
            self.assertEqual(
                seen_lengths, [], "unpadded input must not build a dense mask"
            )

            # Padded mask -> dense path unchanged, dense mask IS built.
            seen_lengths.clear()
            padded = mx.array(
                [[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]], dtype=mx.int32
            )
            embeds_padded = model(token_ids, padded).text_embeds
            mx.eval(embeds_padded)
            self.assertGreater(
                len(seen_lengths), 0, "padded input must still build a dense mask"
            )

            # Core forward with attention_mask=None -> fused "causal".
            seen_lengths.clear()
            core_hidden = model.model(token_ids, attention_mask=None)
            mx.eval(core_hidden)
            self.assertEqual(
                seen_lengths, [], "None mask must not build a dense mask"
            )
        finally:
            # Always put the real builder back, even if an assertion failed.
            qwen3.Qwen3Model._create_causal_mask = real_builder

        self.assertEqual(embeds_unpadded.shape, (2, config.hidden_size))
        self.assertEqual(embeds_padded.shape, (2, config.hidden_size))
        self.assertEqual(core_hidden.shape, (2, 5, config.hidden_size))


# Run this module's unittest test cases when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()