diff --git a/mlx_embeddings/convert.py b/mlx_embeddings/convert.py index 66339d0211..c6b224b580 100644 --- a/mlx_embeddings/convert.py +++ b/mlx_embeddings/convert.py @@ -233,6 +233,10 @@ def convert( for file in files: shutil.copy(file, mlx_path) + src_pooling = model_path / "1_Pooling" + if src_pooling.is_dir(): + shutil.copytree(src_pooling, mlx_path / "1_Pooling", dirs_exist_ok=True) + tokenizer.save_pretrained(mlx_path) save_config(config, config_path=mlx_path / "config.json") diff --git a/mlx_embeddings/models/base.py b/mlx_embeddings/models/base.py index f77f46750a..82df2343da 100644 --- a/mlx_embeddings/models/base.py +++ b/mlx_embeddings/models/base.py @@ -38,16 +38,6 @@ class ViTModelOutput: vision_model_output: Optional[mx.array] = None -def mean_pooling(token_embeddings: mx.array, attention_mask: mx.array): - input_mask_expanded = mx.expand_dims(attention_mask, -1) - input_mask_expanded = mx.broadcast_to( - input_mask_expanded, token_embeddings.shape - ).astype(mx.float32) - sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, axis=1) - sum_mask = mx.maximum(mx.sum(input_mask_expanded, axis=1), 1e-9) - return sum_embeddings / sum_mask - - def normalize_embeddings(embeddings, p=2, axis=-1, keepdims=True, eps=1e-9): return embeddings / mx.maximum( mx.linalg.norm(embeddings, ord=p, axis=axis, keepdims=keepdims), eps diff --git a/mlx_embeddings/models/bert.py b/mlx_embeddings/models/bert.py index 683c77a539..4c4b1c47c0 100644 --- a/mlx_embeddings/models/bert.py +++ b/mlx_embeddings/models/bert.py @@ -1,11 +1,12 @@ import math -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings +from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings +from .pooling import pool_by_config @dataclass @@ -22,6 +23,7 @@ class ModelArgs(BaseModelArgs): initializer_range: float = 0.02 layer_norm_eps: float = 1e-12 vocab_size: int = 30522 + pooling_config: dict = field(default_factory=lambda: {"pooling_mode": "mean"}) class BertEmbeddings(nn.Module): @@ -224,8 +226,9 @@ def __call__(self, input_ids, token_type_ids=None, attention_mask=None): sequence_output = encoder_outputs pooled_output = self.pooler(sequence_output) - # normalized features - text_embeds = mean_pooling(sequence_output, attention_mask) + text_embeds = pool_by_config( + sequence_output, attention_mask, self.config.pooling_config + ) text_embeds = normalize_embeddings(text_embeds) return BaseModelOutput( diff --git a/mlx_embeddings/models/gemma3_text.py b/mlx_embeddings/models/gemma3_text.py index bd12ffec19..1a9d4ba14f 100644 --- a/mlx_embeddings/models/gemma3_text.py +++ b/mlx_embeddings/models/gemma3_text.py @@ -6,7 +6,8 @@ from mlx_lm.models.base import create_attention_mask from mlx_lm.models.gemma3_text import ModelArgs, RMSNorm, TransformerBlock -from .base import BaseModelOutput, mean_pooling, normalize_embeddings +from .base import BaseModelOutput, normalize_embeddings +from .pooling import mean_pooling class Gemma3Model(nn.Module): diff --git a/mlx_embeddings/models/lfm2.py b/mlx_embeddings/models/lfm2.py index 61a097c8ad..339cd93959 100644 --- a/mlx_embeddings/models/lfm2.py +++ b/mlx_embeddings/models/lfm2.py @@ -9,7 +9,8 @@ from mlx_lm.models.lfm2 import Lfm2DecoderLayer from mlx_lm.models.lfm2 import ModelArgs as Lfm2ModelArgs -from .base import BaseModelOutput, mean_pooling, normalize_embeddings +from .base import BaseModelOutput, normalize_embeddings +from .pooling import mean_pooling @dataclass diff --git a/mlx_embeddings/models/llama_bidirec.py b/mlx_embeddings/models/llama_bidirec.py index f722f82477..cd9647db47 100644 --- a/mlx_embeddings/models/llama_bidirec.py +++ b/mlx_embeddings/models/llama_bidirec.py @@ -5,7 +5,8 @@ import mlx.nn as nn from mlx_lm.models.llama import TransformerBlock -from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings +from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings +from .pooling import mean_pooling @dataclass diff --git a/mlx_embeddings/models/llama_nemotron_vl/model.py b/mlx_embeddings/models/llama_nemotron_vl/model.py index b96661b378..2c595ff2f6 100644 --- a/mlx_embeddings/models/llama_nemotron_vl/model.py +++ b/mlx_embeddings/models/llama_nemotron_vl/model.py @@ -6,9 +6,10 @@ import mlx.nn as nn import numpy as np -from ..base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings +from ..base import BaseModelArgs, BaseModelOutput, normalize_embeddings from ..llama_bidirec import LlamaBidirectionalModel from ..llama_bidirec import ModelArgs as LlamaBidirectModelArgs +from ..pooling import mean_pooling from ..siglip import SiglipVisionTransformer from ..siglip import VisionConfig as SiglipVisionConfig diff --git a/mlx_embeddings/models/modernbert.py b/mlx_embeddings/models/modernbert.py index 827233a653..d38968b72b 100644 --- a/mlx_embeddings/models/modernbert.py +++ b/mlx_embeddings/models/modernbert.py @@ -5,7 +5,8 @@ import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings +from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings +from .pooling import mean_pooling @dataclass diff --git a/mlx_embeddings/models/pooling.py b/mlx_embeddings/models/pooling.py new file mode 100644 index 0000000000..2c2650e29d --- /dev/null +++ b/mlx_embeddings/models/pooling.py @@ -0,0 +1,119 @@ +from typing import Any, Dict + +import mlx.core as mx + + +def mean_pooling(token_embeddings: mx.array, attention_mask: mx.array): + input_mask_expanded = mx.expand_dims(attention_mask, -1) + input_mask_expanded = mx.broadcast_to( + input_mask_expanded, token_embeddings.shape + ).astype(mx.float32) + sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, axis=1) + sum_mask = mx.maximum(mx.sum(input_mask_expanded, axis=1), 1e-9) + return sum_embeddings / sum_mask + + +def cls_pooling(token_embeddings: mx.array, attention_mask: mx.array) -> mx.array: + first_indices = mx.argmax(attention_mask, axis=1) + batch_size = token_embeddings.shape[0] + hidden_dim = token_embeddings.shape[-1] + gather_idx = mx.broadcast_to( + first_indices[:, None, None], (batch_size, 1, hidden_dim) + ) + return mx.squeeze(mx.take_along_axis(token_embeddings, gather_idx, axis=1), axis=1) + + +def max_pooling(token_embeddings: mx.array, attention_mask: mx.array) -> mx.array: + mask = mx.expand_dims(attention_mask, -1) + mask = mx.broadcast_to(mask, token_embeddings.shape).astype(token_embeddings.dtype) + masked = mx.where(mask == 0, -float("inf"), token_embeddings) + return mx.max(masked, axis=1) + + +def lasttoken_pooling(token_embeddings: mx.array, attention_mask: mx.array) -> mx.array: + batch_size, seq_len, hidden_dim = token_embeddings.shape + flipped = attention_mask[:, ::-1] + flip_indices = mx.argmax(flipped, axis=1) + has_any_real = mx.max(flipped, axis=1) + flip_indices = mx.where(has_any_real == 0, seq_len - 1, flip_indices) + last_indices = seq_len - flip_indices - 1 + gather_idx = mx.broadcast_to( + last_indices[:, None, None], (batch_size, 1, hidden_dim) + ) + mask = mx.broadcast_to(attention_mask[:, :, None], token_embeddings.shape).astype( + token_embeddings.dtype + ) + return mx.squeeze( + mx.take_along_axis(token_embeddings * mask, gather_idx, axis=1), axis=1 + ) + + +_LEGACY_POOLING_MODE_KWARGS = { + "pooling_mode_cls_token": "cls", + "pooling_mode_max_tokens": "max", + "pooling_mode_mean_tokens": "mean", + "pooling_mode_mean_sqrt_len_tokens": "mean_sqrt_len_tokens", + "pooling_mode_weightedmean_tokens": "weightedmean", + "pooling_mode_lasttoken": "lasttoken", +} + +_SUPPORTED_POOL_MODES = {"cls", "mean", "max", "lasttoken"} +_KNOWN_UNSUPPORTED_POOL_MODES = {"weightedmean", "mean_sqrt_len_tokens"} + + +def _normalize_pooling_config( + pooling_config: Dict[str, Any], +) -> Dict[str, Any]: + cfg = dict(pooling_config) + found = [k for k in _LEGACY_POOLING_MODE_KWARGS if k in cfg] + if not found: + return cfg + if "pooling_mode" not in cfg: + active = tuple( + name + for key, name in _LEGACY_POOLING_MODE_KWARGS.items() + if cfg.get(key, False) + ) + if not active: + active = ("mean",) + cfg["pooling_mode"] = active[0] if len(active) == 1 else active + for k in found: + del cfg[k] + return cfg + + +def pool_by_config( + token_embeddings: mx.array, + attention_mask: mx.array, + pooling_config: Dict[str, Any], +) -> mx.array: + cfg = _normalize_pooling_config(pooling_config) + mode = cfg["pooling_mode"] + if not cfg.get("include_prompt", True): + raise NotImplementedError( + "Prompt-aware pooling (include_prompt=False) is not supported. " + "This affects INSTRUCTOR-style models." + ) + if isinstance(mode, (tuple, list)): + raise NotImplementedError( + f"Concatenated pooling mode {mode!r} is not supported; " + "only a single pooling mode is allowed." + ) + if mode in _KNOWN_UNSUPPORTED_POOL_MODES: + raise NotImplementedError( + f"Pooling mode {mode!r} is not supported. " + f"Supported modes: {sorted(_SUPPORTED_POOL_MODES)}." + ) + + if mode == "cls": + return cls_pooling(token_embeddings, attention_mask) + if mode == "max": + return max_pooling(token_embeddings, attention_mask) + if mode == "lasttoken": + return lasttoken_pooling(token_embeddings, attention_mask) + if mode == "mean": + return mean_pooling(token_embeddings, attention_mask) + raise ValueError( + f"Unknown pooling mode {mode!r}. " + f"Supported modes: {sorted(_SUPPORTED_POOL_MODES)}." + ) diff --git a/mlx_embeddings/models/xlm_roberta.py b/mlx_embeddings/models/xlm_roberta.py index 1259ea8705..32b563657e 100644 --- a/mlx_embeddings/models/xlm_roberta.py +++ b/mlx_embeddings/models/xlm_roberta.py @@ -1,11 +1,12 @@ import math -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings +from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings +from .pooling import pool_by_config @dataclass @@ -25,7 +26,7 @@ class ModelArgs(BaseModelArgs): output_past: bool = True pad_token_id: int = 1 position_embedding_type: str = "absolute" - pooling_config: dict = None + pooling_config: dict = field(default_factory=lambda: {"pooling_mode": "mean"}) class XLMRobertaEmbeddings(nn.Module): @@ -352,8 +353,9 @@ def __call__( self.pooler(sequence_output) if self.pooler is not None else None ) - # normalized features - text_embeds = mean_pooling(sequence_output, attention_mask) + text_embeds = pool_by_config( + sequence_output, attention_mask, self.config.pooling_config + ) text_embeds = normalize_embeddings(text_embeds) return BaseModelOutput( diff --git a/mlx_embeddings/tests/test_base.py b/mlx_embeddings/tests/test_base.py index 18fd7e33fc..784b3466d8 100644 --- a/mlx_embeddings/tests/test_base.py +++ b/mlx_embeddings/tests/test_base.py @@ -1,12 +1,10 @@ import mlx.core as mx import numpy as np -import pytest from mlx_embeddings.models.base import ( BaseModelArgs, BaseModelOutput, ViTModelOutput, - mean_pooling, normalize_embeddings, ) from mlx_embeddings.tokenizer_utils import TokenizerWrapper @@ -100,36 +98,6 @@ def test_initialization(self): assert output.vision_model_output is mock_array -class TestMeanPooling: - def test_mean_pooling(self): - # Create sample inputs - batch_size, seq_len, hidden_dim = 2, 3, 4 - token_embeddings = mx.random.normal((batch_size, seq_len, hidden_dim)) - - # Test case 1: No masking (all 1s) - attention_mask = mx.ones((batch_size, seq_len)) - result = mean_pooling(token_embeddings, attention_mask) - - # Expected result is the mean across sequence dimension - expected = mx.mean(token_embeddings, axis=1) - np.testing.assert_allclose(result.tolist(), expected.tolist(), rtol=1e-5) - - # Test case 2: With masking - attention_mask = mx.array( - [ - [1, 1, 0], # Only first two tokens are valid - [1, 0, 0], # Only first token is valid - ] - ) - result = mean_pooling(token_embeddings, attention_mask) - - # Manual calculation for verification - expected_0 = mx.sum(token_embeddings[0, :2], axis=0) / 2 - expected_1 = token_embeddings[1, 0] # Just the first embedding - expected = mx.stack([expected_0, expected_1]) - np.testing.assert_allclose(result.tolist(), expected.tolist(), rtol=1e-5) - - class TestNormalizeEmbeddings: def test_normalize_embeddings(self): # Test case 1: 2D array diff --git a/mlx_embeddings/tests/test_models.py b/mlx_embeddings/tests/test_models.py index efe2cf9604..e94bd9857d 100644 --- a/mlx_embeddings/tests/test_models.py +++ b/mlx_embeddings/tests/test_models.py @@ -655,23 +655,27 @@ def convert_tokens_to_ids(self, token): dummy_tokenizer = DummyTokenizer() dummy_image_processor = MagicMock() dummy_image_processor.merge_size = 2 + dummy_auto_tokenizer = MagicMock() + dummy_auto_tokenizer.from_pretrained.return_value = dummy_tokenizer + dummy_auto_image_processor = MagicMock() + dummy_auto_image_processor.from_pretrained.return_value = dummy_image_processor with ( patch.object( - qwen3_vl.processor.AutoTokenizer, - "from_pretrained", - return_value=dummy_tokenizer, + qwen3_vl.processor, + "AutoTokenizer", + dummy_auto_tokenizer, ) as mock_tokenizer, patch.object( - qwen3_vl.processor.AutoImageProcessor, - "from_pretrained", - return_value=dummy_image_processor, + qwen3_vl.processor, + "AutoImageProcessor", + dummy_auto_image_processor, ) as mock_image_processor, ): processor = qwen3_vl.Processor.from_pretrained("dummy-model") - mock_tokenizer.assert_called_once() - mock_image_processor.assert_called_once() + mock_tokenizer.from_pretrained.assert_called_once() + mock_image_processor.from_pretrained.assert_called_once() self.assertIs(processor.tokenizer, dummy_tokenizer) self.assertIs(processor.image_processor, dummy_image_processor) self.assertEqual(processor.processor.chat_template, "dummy-template") diff --git a/mlx_embeddings/tests/test_pooling.py b/mlx_embeddings/tests/test_pooling.py new file mode 100644 index 0000000000..b99cbc7a37 --- /dev/null +++ b/mlx_embeddings/tests/test_pooling.py @@ -0,0 +1,207 @@ +import mlx.core as mx +import numpy as np +import pytest + +from mlx_embeddings.models.pooling import ( + _SUPPORTED_POOL_MODES, + _normalize_pooling_config, + cls_pooling, + lasttoken_pooling, + max_pooling, + mean_pooling, + pool_by_config, +) + + +class TestMeanPooling: + def test_mean_pooling(self): + # Create sample inputs + batch_size, seq_len, hidden_dim = 2, 3, 4 + token_embeddings = mx.random.normal((batch_size, seq_len, hidden_dim)) + + # Test case 1: No masking (all 1s) + attention_mask = mx.ones((batch_size, seq_len)) + result = mean_pooling(token_embeddings, attention_mask) + + # Expected result is the mean across sequence dimension + expected = mx.mean(token_embeddings, axis=1) + np.testing.assert_allclose(result.tolist(), expected.tolist(), rtol=1e-5) + + # Test case 2: With masking + attention_mask = mx.array( + [ + [1, 1, 0], # Only first two tokens are valid + [1, 0, 0], # Only first token is valid + ] + ) + result = mean_pooling(token_embeddings, attention_mask) + + # Manual calculation for verification + expected_0 = mx.sum(token_embeddings[0, :2], axis=0) / 2 + expected_1 = token_embeddings[1, 0] # Just the first embedding + expected = mx.stack([expected_0, expected_1]) + np.testing.assert_allclose(result.tolist(), expected.tolist(), rtol=1e-5) + + +class TestClsPooling: + def test_right_padded_uses_position_zero(self): + """For right-padded inputs, the first real token is at position 0.""" + token_embeddings = mx.array( + [[[1.0], [2.0], [3.0], [4.0]], [[5.0], [6.0], [7.0], [8.0]]] + ) + attention_mask = mx.array([[1, 1, 1, 0], [1, 1, 0, 0]]) + result = cls_pooling(token_embeddings, attention_mask) + np.testing.assert_allclose(result.tolist(), [[1.0], [5.0]]) + + def test_left_padded_finds_first_real_token(self): + """For left-padded inputs (decoder-style models), the first real token is the first 1 in the + attention mask, not position 0.""" + token_embeddings = mx.array( + [[[1.0], [2.0], [3.0], [4.0]], [[5.0], [6.0], [7.0], [8.0]]] + ) + attention_mask = mx.array([[0, 0, 1, 1], [0, 1, 1, 1]]) + result = cls_pooling(token_embeddings, attention_mask) + np.testing.assert_allclose(result.tolist(), [[3.0], [6.0]]) + + +class TestMaxPooling: + def test_respects_attention_mask(self): + # Last position has the largest value but is masked out; max should + # therefore come from the last unmasked token. + token_embeddings = mx.array([[[1.0], [3.0], [5.0], [10.0]]]) + attention_mask = mx.array([[1, 1, 1, 0]]) + result = max_pooling(token_embeddings, attention_mask) + assert result.shape == (1, 1) + np.testing.assert_allclose(result.tolist(), [[5.0]]) + + +class TestLastTokenPooling: + def test_finds_last_attended_token(self): + # Each row has a different pattern of attended tokens; the last + # attended position should be selected. + token_embeddings = mx.array( + [ + [[0.0], [1.0], [2.0], [3.0]], + [[5.0], [6.0], [7.0], [8.0]], + ] + ) + attention_mask = mx.array([[1, 1, 1, 0], [1, 1, 0, 0]]) + result = lasttoken_pooling(token_embeddings, attention_mask) + assert result.shape == (2, 1) + np.testing.assert_allclose(result.tolist(), [[2.0], [6.0]]) + + def test_all_padding_returns_zero_vector(self): + dim = 2 + token_embeddings = mx.ones((1, 4, dim)) + attention_mask = mx.zeros((1, 4), dtype=mx.int32) + result = lasttoken_pooling(token_embeddings, attention_mask) + assert result.shape == (1, dim) + np.testing.assert_allclose(result.tolist(), [[0.0, 0.0]]) + + +# Shared test fixtures: two sequences with different lengths and a mix of padding. +# seq 0: 3 real tokens + 1 pad, seq 1: 4 real tokens, no pad +_FIXTURE_TOKEN_EMBEDDINGS = mx.array( + [ + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [99.0, 99.0]], + [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0], [70.0, 80.0]], + ] +) +_FIXTURE_ATTENTION_MASK = mx.array([[1, 1, 1, 0], [1, 1, 1, 1]]) +_FIXTURE_EXPECTED_BY_MODE = { + "cls": [[1.0, 2.0], [10.0, 20.0]], + "max": [[5.0, 6.0], [70.0, 80.0]], + "mean": [[3.0, 4.0], [40.0, 50.0]], + "lasttoken": [[5.0, 6.0], [70.0, 80.0]], +} +_POOLING_FN_BY_MODE = { + "cls": cls_pooling, + "max": max_pooling, + "mean": mean_pooling, + "lasttoken": lasttoken_pooling, +} + + +class TestPoolingExactValues: + def test_exact_values(self): + """Verify each pooling mode produces the expected exact values.""" + for mode, expected in _FIXTURE_EXPECTED_BY_MODE.items(): + result = _POOLING_FN_BY_MODE[mode]( + _FIXTURE_TOKEN_EMBEDDINGS, _FIXTURE_ATTENTION_MASK + ) + assert result.shape == (2, 2), f"shape mismatch for mode={mode!r}" + np.testing.assert_allclose( + result.tolist(), + expected, + atol=1e-5, + err_msg=f"value mismatch for mode={mode!r}", + ) + + +class TestNormalizePoolingConfig: + def test_pooling_legacy_config_conversion(self): + """Verify that old-style saved configs are silently converted when loading.""" + old_config = { + "embedding_dimension": 384, + "pooling_mode_cls_token": False, + "pooling_mode_mean_tokens": True, + "pooling_mode_max_tokens": False, + "pooling_mode_mean_sqrt_len_tokens": False, + "pooling_mode_weightedmean_tokens": False, + "pooling_mode_lasttoken": False, + "include_prompt": True, + } + assert _normalize_pooling_config(old_config) == { + "embedding_dimension": 384, + "pooling_mode": "mean", + "include_prompt": True, + } + + def test_pooling_legacy_config_conversion_multi_mode(self): + """Verify legacy config with multiple active modes converts to a tuple.""" + old_config = { + "embedding_dimension": 384, + "pooling_mode_cls_token": True, + "pooling_mode_mean_tokens": True, + "pooling_mode_max_tokens": False, + "pooling_mode_mean_sqrt_len_tokens": False, + "pooling_mode_weightedmean_tokens": False, + "pooling_mode_lasttoken": False, + "include_prompt": True, + } + assert _normalize_pooling_config(old_config) == { + "embedding_dimension": 384, + "pooling_mode": ("cls", "mean"), + "include_prompt": True, + } + + +class TestPoolByConfig: + def test_forward_all_modes(self): + # Basic sanity check that all pooling strategies run and produce the + # expected sentence embedding shape. + embedding_dimension = 8 + batch_size, seq_len = 3, 5 + token_embeddings = mx.random.normal((batch_size, seq_len, embedding_dimension)) + + # Mix of left / right padding patterns, but always at least one non-pad token + attention_mask = mx.array( + [ + [1, 1, 1, 0, 0], + [0, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + ] + ) + for mode in sorted(_SUPPORTED_POOL_MODES): + result = pool_by_config( + token_embeddings, attention_mask, {"pooling_mode": mode} + ) + assert result.shape == (batch_size, embedding_dimension), f"mode={mode!r}" + + def test_invalid_mode_raises(self): + token_embeddings = mx.random.normal((1, 4, 4)) + attention_mask = mx.ones((1, 4)) + with pytest.raises(ValueError, match="Unknown pooling mode"): + pool_by_config( + token_embeddings, attention_mask, {"pooling_mode": "nonexistent"} + ) diff --git a/mlx_embeddings/utils.py b/mlx_embeddings/utils.py index e46fc82d0d..004246c5fe 100644 --- a/mlx_embeddings/utils.py +++ b/mlx_embeddings/utils.py @@ -110,6 +110,15 @@ def load_config(model_path: Path) -> dict: return config +def _read_pooling_config(model_path: Path) -> Optional[dict]: + """Return the parsed ``1_Pooling/config.json``, or None if absent.""" + pooling_cfg_path = model_path / "1_Pooling" / "config.json" + if not pooling_cfg_path.exists(): + return None + with open(pooling_cfg_path, "r") as f: + return json.load(f) + + def load_model( model_path: Path, lazy: bool = False, @@ -142,6 +151,11 @@ def load_model( config = load_config(model_path) config.update(model_config) + if "pooling_config" not in config: + pooling_cfg = _read_pooling_config(model_path) + if pooling_cfg is not None: + config["pooling_config"] = pooling_cfg + weight_files = glob.glob(str(model_path / "**/model*.safetensors"), recursive=True) if not weight_files: