Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlx_embeddings/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ def convert(
for file in files:
shutil.copy(file, mlx_path)

src_pooling = model_path / "1_Pooling"
if src_pooling.is_dir():
shutil.copytree(src_pooling, mlx_path / "1_Pooling", dirs_exist_ok=True)

tokenizer.save_pretrained(mlx_path)

save_config(config, config_path=mlx_path / "config.json")
Expand Down
10 changes: 0 additions & 10 deletions mlx_embeddings/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,6 @@ class ViTModelOutput:
vision_model_output: Optional[mx.array] = None


def mean_pooling(token_embeddings: mx.array, attention_mask: mx.array):
    """Average token embeddings along axis 1, weighting each token by its mask.

    Args:
        token_embeddings: Per-token embeddings; the tests in this project call
            it with a (batch, seq_len, hidden_dim) array.
        attention_mask: Mask that, once given a trailing axis, broadcasts to
            the shape of ``token_embeddings``.

    Returns:
        The mask-weighted sum over axis 1 divided by the mask total. The mask
        is cast to float32 before use.
    """
    weights = mx.broadcast_to(
        mx.expand_dims(attention_mask, -1), token_embeddings.shape
    ).astype(mx.float32)
    weighted_total = mx.sum(token_embeddings * weights, axis=1)
    weight_total = mx.sum(weights, axis=1)
    # Clamp the denominator so a fully masked row does not divide by zero.
    return weighted_total / mx.maximum(weight_total, 1e-9)


def normalize_embeddings(embeddings, p=2, axis=-1, keepdims=True, eps=1e-9):
return embeddings / mx.maximum(
mx.linalg.norm(embeddings, ord=p, axis=axis, keepdims=keepdims), eps
Expand Down
11 changes: 7 additions & 4 deletions mlx_embeddings/models/bert.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import math
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings
from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings
from .pooling import pool_by_config


@dataclass
Expand All @@ -22,6 +23,7 @@ class ModelArgs(BaseModelArgs):
initializer_range: float = 0.02
layer_norm_eps: float = 1e-12
vocab_size: int = 30522
pooling_config: dict = field(default_factory=lambda: {"pooling_mode": "mean"})


class BertEmbeddings(nn.Module):
Expand Down Expand Up @@ -224,8 +226,9 @@ def __call__(self, input_ids, token_type_ids=None, attention_mask=None):
sequence_output = encoder_outputs
pooled_output = self.pooler(sequence_output)

# normalized features
text_embeds = mean_pooling(sequence_output, attention_mask)
text_embeds = pool_by_config(
sequence_output, attention_mask, self.config.pooling_config
)
text_embeds = normalize_embeddings(text_embeds)

return BaseModelOutput(
Expand Down
3 changes: 2 additions & 1 deletion mlx_embeddings/models/gemma3_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.gemma3_text import ModelArgs, RMSNorm, TransformerBlock

from .base import BaseModelOutput, mean_pooling, normalize_embeddings
from .base import BaseModelOutput, normalize_embeddings
from .pooling import mean_pooling


class Gemma3Model(nn.Module):
Expand Down
3 changes: 2 additions & 1 deletion mlx_embeddings/models/lfm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from mlx_lm.models.lfm2 import Lfm2DecoderLayer
from mlx_lm.models.lfm2 import ModelArgs as Lfm2ModelArgs

from .base import BaseModelOutput, mean_pooling, normalize_embeddings
from .base import BaseModelOutput, normalize_embeddings
from .pooling import mean_pooling


@dataclass
Expand Down
3 changes: 2 additions & 1 deletion mlx_embeddings/models/llama_bidirec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import mlx.nn as nn
from mlx_lm.models.llama import TransformerBlock

from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings
from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings
from .pooling import mean_pooling


@dataclass
Expand Down
3 changes: 2 additions & 1 deletion mlx_embeddings/models/llama_nemotron_vl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import mlx.nn as nn
import numpy as np

from ..base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings
from ..base import BaseModelArgs, BaseModelOutput, normalize_embeddings
from ..llama_bidirec import LlamaBidirectionalModel
from ..llama_bidirec import ModelArgs as LlamaBidirectModelArgs
from ..pooling import mean_pooling
from ..siglip import SiglipVisionTransformer
from ..siglip import VisionConfig as SiglipVisionConfig

Expand Down
3 changes: 2 additions & 1 deletion mlx_embeddings/models/modernbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings
from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings
from .pooling import mean_pooling


@dataclass
Expand Down
119 changes: 119 additions & 0 deletions mlx_embeddings/models/pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from typing import Any, Dict

import mlx.core as mx


def mean_pooling(token_embeddings: mx.array, attention_mask: mx.array):
    """Return the attention-mask-weighted mean of the token embeddings.

    Args:
        token_embeddings: Per-token embeddings, reduced along axis 1.
        attention_mask: Mask that gets a trailing axis appended and is then
            broadcast to ``token_embeddings.shape``.

    Returns:
        Sum of the masked embeddings over axis 1 divided by the summed mask.
        The mask is converted to float32 first.
    """
    mask = mx.expand_dims(attention_mask, -1)
    mask = mx.broadcast_to(mask, token_embeddings.shape).astype(mx.float32)
    numerator = mx.sum(token_embeddings * mask, axis=1)
    # The floor of 1e-9 keeps the division finite when a row's mask sums to 0.
    denominator = mx.maximum(mx.sum(mask, axis=1), 1e-9)
    return numerator / denominator


def cls_pooling(token_embeddings: mx.array, attention_mask: mx.array) -> mx.array:
    """Pick one token embedding per row, at the argmax of the attention mask.

    Args:
        token_embeddings: Rank-3 array; axis 0 is the batch, axis 1 is indexed
            by the gathered position and the last axis is the hidden size.
        attention_mask: Rank-2 mask whose argmax along axis 1 gives the
            position to gather for each row.

    Returns:
        Array of shape (batch, hidden) holding the selected embeddings.
    """
    n_rows = token_embeddings.shape[0]
    width = token_embeddings.shape[-1]
    # NOTE(review): for a 0/1 mask this is presumably the first real token
    # (which also covers left padding) - assumes argmax returns the first
    # maximum, confirm against the MLX docs.
    first_positions = mx.argmax(attention_mask, axis=1)
    # take_along_axis needs an index with the same rank as the source array.
    index = mx.broadcast_to(first_positions[:, None, None], (n_rows, 1, width))
    selected = mx.take_along_axis(token_embeddings, index, axis=1)
    return mx.squeeze(selected, axis=1)


def max_pooling(token_embeddings: mx.array, attention_mask: mx.array) -> mx.array:
    """Take the element-wise maximum over axis 1, ignoring masked positions.

    Args:
        token_embeddings: Per-token embeddings, reduced along axis 1.
        attention_mask: Mask that gets a trailing axis appended and is then
            broadcast to ``token_embeddings.shape``; positions equal to 0 are
            excluded.

    Returns:
        The maximum over axis 1 after masked positions are replaced by
        negative infinity.
    """
    keep = mx.broadcast_to(
        mx.expand_dims(attention_mask, -1), token_embeddings.shape
    ).astype(token_embeddings.dtype)
    # Masked slots become -inf so they can never win the max.
    neg_inf = -float("inf")
    candidates = mx.where(keep == 0, neg_inf, token_embeddings)
    return mx.max(candidates, axis=1)


def lasttoken_pooling(token_embeddings: mx.array, attention_mask: mx.array) -> mx.array:
    """Pick one token embedding per row, located from the end of the mask.

    The mask is reversed along axis 1 and its argmax gives the offset from
    the end of the sequence. Rows whose mask maximum is 0 are pointed at
    index 0, and because the embeddings are multiplied by the mask before
    gathering, such rows come out as zeros.

    Args:
        token_embeddings: Array of shape (batch, seq_len, hidden_dim).
        attention_mask: Rank-2 mask indexed as (batch, seq_len).

    Returns:
        Array of shape (batch, hidden_dim) with the gathered embeddings.
    """
    batch_size, seq_len, hidden_dim = token_embeddings.shape

    reversed_mask = attention_mask[:, ::-1]
    offset_from_end = mx.argmax(reversed_mask, axis=1)
    row_max = mx.max(reversed_mask, axis=1)
    # For rows with no attended token, use offset seq_len - 1 (i.e. index 0).
    offset_from_end = mx.where(row_max == 0, seq_len - 1, offset_from_end)
    last_positions = seq_len - offset_from_end - 1

    index = mx.broadcast_to(
        last_positions[:, None, None], (batch_size, 1, hidden_dim)
    )
    keep = mx.broadcast_to(
        attention_mask[:, :, None], token_embeddings.shape
    ).astype(token_embeddings.dtype)
    # Zero out masked tokens first so a masked pick contributes nothing.
    picked = mx.take_along_axis(token_embeddings * keep, index, axis=1)
    return mx.squeeze(picked, axis=1)


# Maps per-mode boolean flag keys to the mode name stored under "pooling_mode".
# _normalize_pooling_config collapses these flags into a single "pooling_mode"
# entry and removes them from the returned config.
# NOTE(review): the key names look like sentence-transformers Pooling config
# fields - confirm against that library's config format.
_LEGACY_POOLING_MODE_KWARGS = {
"pooling_mode_cls_token": "cls",
"pooling_mode_max_tokens": "max",
"pooling_mode_mean_tokens": "mean",
"pooling_mode_mean_sqrt_len_tokens": "mean_sqrt_len_tokens",
"pooling_mode_weightedmean_tokens": "weightedmean",
"pooling_mode_lasttoken": "lasttoken",
}

# Mode names pool_by_config can dispatch; also listed in its error messages.
_SUPPORTED_POOL_MODES = {"cls", "mean", "max", "lasttoken"}
# Recognised mode names that pool_by_config rejects with NotImplementedError
# (any other unrecognised name gets a ValueError instead).
_KNOWN_UNSUPPORTED_POOL_MODES = {"weightedmean", "mean_sqrt_len_tokens"}


def _normalize_pooling_config(
    pooling_config: Dict[str, Any],
) -> Dict[str, Any]:
    """Return a copy of the config with legacy mode flags folded into one key.

    Flag keys listed in ``_LEGACY_POOLING_MODE_KWARGS`` are removed from the
    copy. If the config has such flags but no explicit ``"pooling_mode"``,
    one is derived from the truthy flags: a single mode name when exactly one
    flag is set, a tuple of names when several are set, and ``"mean"`` when
    none is set. An explicit ``"pooling_mode"`` is always kept as is.

    Args:
        pooling_config: Pooling options; the input mapping is not modified.

    Returns:
        A new dict. When no legacy flag is present it is a plain copy.
    """
    normalized = dict(pooling_config)
    legacy_keys = [key for key in _LEGACY_POOLING_MODE_KWARGS if key in normalized]

    if legacy_keys and "pooling_mode" not in normalized:
        enabled = tuple(
            _LEGACY_POOLING_MODE_KWARGS[key]
            for key in legacy_keys
            if normalized[key]
        )
        if not enabled:
            normalized["pooling_mode"] = "mean"
        elif len(enabled) == 1:
            normalized["pooling_mode"] = enabled[0]
        else:
            # Several flags at once: keep them all so the caller can decide.
            normalized["pooling_mode"] = enabled

    for key in legacy_keys:
        normalized.pop(key)
    return normalized


def pool_by_config(
    token_embeddings: mx.array,
    attention_mask: mx.array,
    pooling_config: Dict[str, Any],
) -> mx.array:
    """Pool token embeddings with the strategy named in ``pooling_config``.

    The config is first passed through ``_normalize_pooling_config`` so that
    legacy per-mode flags are folded into ``"pooling_mode"``.

    Args:
        token_embeddings: Per-token embeddings handed to the pooling function.
        attention_mask: Attention mask handed to the pooling function.
        pooling_config: Pooling options. ``None``, an empty mapping, or a
            mapping without a (non-null) ``"pooling_mode"`` selects mean
            pooling, which is also the default used by the model configs.

    Returns:
        The result of ``cls_pooling``, ``max_pooling``, ``lasttoken_pooling``
        or ``mean_pooling``, depending on the mode.

    Raises:
        NotImplementedError: If ``include_prompt`` is falsy, if the mode is a
            tuple/list (concatenated pooling), or if the mode is one of
            ``_KNOWN_UNSUPPORTED_POOL_MODES``.
        ValueError: If the mode is not recognised at all.
    """
    # Tolerate a missing config (e.g. "pooling_config": null in a config file)
    # instead of failing inside dict(None).
    cfg = _normalize_pooling_config(pooling_config or {})

    # Fall back to mean pooling rather than raising a bare KeyError when the
    # config carries other options but no mode.
    mode = cfg.get("pooling_mode")
    if mode is None:
        mode = "mean"

    if not cfg.get("include_prompt", True):
        raise NotImplementedError(
            "Prompt-aware pooling (include_prompt=False) is not supported. "
            "This affects INSTRUCTOR-style models."
        )
    if isinstance(mode, (tuple, list)):
        raise NotImplementedError(
            f"Concatenated pooling mode {mode!r} is not supported; "
            "only a single pooling mode is allowed."
        )
    if mode in _KNOWN_UNSUPPORTED_POOL_MODES:
        raise NotImplementedError(
            f"Pooling mode {mode!r} is not supported. "
            f"Supported modes: {sorted(_SUPPORTED_POOL_MODES)}."
        )

    poolers = {
        "cls": cls_pooling,
        "max": max_pooling,
        "lasttoken": lasttoken_pooling,
        "mean": mean_pooling,
    }
    pooler = poolers.get(mode)
    if pooler is None:
        raise ValueError(
            f"Unknown pooling mode {mode!r}. "
            f"Supported modes: {sorted(_SUPPORTED_POOL_MODES)}."
        )
    return pooler(token_embeddings, attention_mask)
12 changes: 7 additions & 5 deletions mlx_embeddings/models/xlm_roberta.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import math
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings
from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings
from .pooling import pool_by_config


@dataclass
Expand All @@ -25,7 +26,7 @@ class ModelArgs(BaseModelArgs):
output_past: bool = True
pad_token_id: int = 1
position_embedding_type: str = "absolute"
pooling_config: dict = None
pooling_config: dict = field(default_factory=lambda: {"pooling_mode": "mean"})


class XLMRobertaEmbeddings(nn.Module):
Expand Down Expand Up @@ -352,8 +353,9 @@ def __call__(
self.pooler(sequence_output) if self.pooler is not None else None
)

# normalized features
text_embeds = mean_pooling(sequence_output, attention_mask)
text_embeds = pool_by_config(
sequence_output, attention_mask, self.config.pooling_config
)
text_embeds = normalize_embeddings(text_embeds)

return BaseModelOutput(
Expand Down
32 changes: 0 additions & 32 deletions mlx_embeddings/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import mlx.core as mx
import numpy as np
import pytest

from mlx_embeddings.models.base import (
BaseModelArgs,
BaseModelOutput,
ViTModelOutput,
mean_pooling,
normalize_embeddings,
)
from mlx_embeddings.tokenizer_utils import TokenizerWrapper
Expand Down Expand Up @@ -100,36 +98,6 @@ def test_initialization(self):
assert output.vision_model_output is mock_array


class TestMeanPooling:
    """Checks mean_pooling against hand-computed averages."""

    def test_mean_pooling(self):
        # Random embeddings shared by both scenarios below.
        batch_size, seq_len, hidden_dim = 2, 3, 4
        embeddings = mx.random.normal((batch_size, seq_len, hidden_dim))

        # Scenario 1: every token attended, so the result is a plain mean
        # over the sequence axis.
        full_mask = mx.ones((batch_size, seq_len))
        pooled = mean_pooling(embeddings, full_mask)
        plain_mean = mx.mean(embeddings, axis=1)
        np.testing.assert_allclose(pooled.tolist(), plain_mean.tolist(), rtol=1e-5)

        # Scenario 2: row 0 keeps its first two tokens, row 1 only its first.
        partial_mask = mx.array(
            [
                [1, 1, 0],
                [1, 0, 0],
            ]
        )
        pooled = mean_pooling(embeddings, partial_mask)

        # Reference values computed by hand from the unmasked tokens.
        row0 = mx.sum(embeddings[0, :2], axis=0) / 2
        row1 = embeddings[1, 0]
        reference = mx.stack([row0, row1])
        np.testing.assert_allclose(pooled.tolist(), reference.tolist(), rtol=1e-5)


class TestNormalizeEmbeddings:
def test_normalize_embeddings(self):
# Test case 1: 2D array
Expand Down
20 changes: 12 additions & 8 deletions mlx_embeddings/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,23 +655,27 @@ def convert_tokens_to_ids(self, token):
dummy_tokenizer = DummyTokenizer()
dummy_image_processor = MagicMock()
dummy_image_processor.merge_size = 2
dummy_auto_tokenizer = MagicMock()
dummy_auto_tokenizer.from_pretrained.return_value = dummy_tokenizer
dummy_auto_image_processor = MagicMock()
dummy_auto_image_processor.from_pretrained.return_value = dummy_image_processor

with (
patch.object(
qwen3_vl.processor.AutoTokenizer,
"from_pretrained",
return_value=dummy_tokenizer,
qwen3_vl.processor,
"AutoTokenizer",
dummy_auto_tokenizer,
) as mock_tokenizer,
patch.object(
qwen3_vl.processor.AutoImageProcessor,
"from_pretrained",
return_value=dummy_image_processor,
qwen3_vl.processor,
"AutoImageProcessor",
dummy_auto_image_processor,
) as mock_image_processor,
):
processor = qwen3_vl.Processor.from_pretrained("dummy-model")

mock_tokenizer.assert_called_once()
mock_image_processor.assert_called_once()
mock_tokenizer.from_pretrained.assert_called_once()
mock_image_processor.from_pretrained.assert_called_once()
self.assertIs(processor.tokenizer, dummy_tokenizer)
self.assertIs(processor.image_processor, dummy_image_processor)
self.assertEqual(processor.processor.chat_template, "dummy-template")
Expand Down
Loading
Loading