
Commit 5560c4c

voodoo11 (Manul from Pathway) authored and committed
Revert OpenAI batch embedder (#9361)
GitOrigin-RevId: 41aaaaebba111607c0e9c090422b9c4200540d6f
1 parent 8c0c862 commit 5560c4c

File tree: 2 files changed (+28, -167 lines)

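In effect, the revert takes `OpenAIEmbedder.__wrapped__` from a batched signature (a list of strings in, a list of arrays out) back to a per-input one, dropping the `batch_size` constructor argument, the `_split_batched_kwargs` helper, and the `get_embedding_dimension`/`_count_tokens` methods. A minimal sketch of the restored call shape, assuming an `AsyncOpenAI`-style client like the `self.client` used in the diff (the standalone `embed_one` helper is illustrative, not code from the commit):

```python
import numpy as np
from openai import AsyncOpenAI


async def embed_one(client: AsyncOpenAI, text: str, model: str) -> np.ndarray:
    # Post-revert shape: one request per input, with the single string
    # wrapped in a one-element list for the embeddings endpoint.
    text = text or "."  # the reverted code substitutes "." for empty input
    ret = await client.embeddings.create(input=[text], model=model)
    return np.array(ret.data[0].embedding)
```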

integration_tests/xpack/test_embedders.py

Lines changed: 5 additions & 83 deletions
@@ -1,7 +1,6 @@
 import pytest
 
 import pathway as pw
-from pathway.internals.udfs.retries import ExponentialBackoffRetryStrategy
 from pathway.internals.udfs.utils import _coerce_sync
 from pathway.xpacks.llm import embedders
 
@@ -17,9 +16,6 @@
 )
 @pytest.mark.parametrize("strategy", ["start", "end"])
 def test_openai_embedder(text: str, model: str, strategy: str):
-    table = pw.debug.table_from_rows(
-        schema=pw.schema_from_types(text=str), rows=[(text,)]
-    )
     if model is None:
         embedder = embedders.OpenAIEmbedder(
             truncation_keep_strategy=strategy,
@@ -32,102 +28,28 @@ def test_openai_embedder(text: str, model: str, strategy: str):
             retry_strategy=pw.udfs.ExponentialBackoffRetryStrategy(),
         )
 
-    table = table.select(embedding=embedder(pw.this.text))
+    sync_embedder = _coerce_sync(embedder.func)
 
-    result = pw.debug.table_to_pandas(table).to_dict("records")
+    embedding = sync_embedder(text)
 
-    assert len(result) == 1
-    assert isinstance(result[0]["embedding"][0], float)
-    assert len(result[0]["embedding"]) > 1500
+    assert len(embedding) > 1500
 
 
 @pytest.mark.parametrize("model", ["text-embedding-ada-002", "text-embedding-3-small"])
 def test_openai_embedder_fails_no_truncation(model: str):
     truncation_keep_strategy = None
     embedder = embedders.OpenAIEmbedder(
-        model=model,
-        truncation_keep_strategy=truncation_keep_strategy,
-        retry_strategy=ExponentialBackoffRetryStrategy(),
+        model=model, truncation_keep_strategy=truncation_keep_strategy
     )
 
     sync_embedder = _coerce_sync(embedder.func)
 
     with pytest.raises(Exception) as exc:
-        sync_embedder([LONG_TEXT])
+        sync_embedder(LONG_TEXT)
 
     assert "maximum context length" in str(exc)
 
 
-def test_openai_embedder_with_common_parameter():
-    table = pw.debug.table_from_rows(
-        schema=pw.schema_from_types(text=str), rows=[("aaa",), ("bbb",)]
-    )
-
-    embedder = embedders.OpenAIEmbedder(
-        model="text-embedding-3-small",
-        retry_strategy=ExponentialBackoffRetryStrategy(),
-    )
-
-    table = table.select(embedding=embedder(pw.this.text, dimensions=700))
-
-    result = pw.debug.table_to_pandas(table).to_dict("records")
-
-    assert len(result) == 2
-    assert isinstance(result[0]["embedding"][0], float)
-    assert len(result[0]["embedding"]) == 700
-    assert isinstance(result[1]["embedding"][0], float)
-    assert len(result[1]["embedding"]) == 700
-
-
-def test_openai_embedder_with_different_parameter():
-    table = pw.debug.table_from_rows(
-        schema=pw.schema_from_types(text=str, dimensions=int),
-        rows=[("aaa", 300), ("bbb", 800)],
-    )
-
-    embedder = embedders.OpenAIEmbedder(
-        model="text-embedding-3-small",
-        retry_strategy=ExponentialBackoffRetryStrategy(),
-    )
-
-    table = table.select(
-        text=pw.this.text,
-        embedding=embedder(pw.this.text, dimensions=pw.this.dimensions),
-    )
-
-    result = pw.debug.table_to_pandas(table).to_dict("records")
-
-    assert len(result) == 2
-    assert isinstance(result[0]["embedding"][0], float)
-    assert isinstance(result[1]["embedding"][0], float)
-    if result[0]["text"] == "aaa":
-        assert len(result[0]["embedding"]) == 300
-    else:
-        assert len(result[1]["embedding"]) == 300
-    if result[0]["text"] == "bbb":
-        assert len(result[0]["embedding"]) == 800
-    else:
-        assert len(result[1]["embedding"]) == 800
-
-
-def test_openai_embedder_input_as_kwarg():
-    table = pw.debug.table_from_rows(
-        schema=pw.schema_from_types(text=str), rows=[("foo",)]
-    )
-    embedder = embedders.OpenAIEmbedder(
-        model="text-embedding-3-small",
-        retry_strategy=pw.udfs.ExponentialBackoffRetryStrategy(),
-    )
-
-    table = table.select(embedding=embedder(input=pw.this.text))
-
-    result = pw.debug.table_to_pandas(table).to_dict("records")
-
-    assert len(result) == 1
-    assert isinstance(result[0]["embedding"][0], float)
-    assert len(result[0]["embedding"]) > 1500
-
-
 def test_sentence_transformer_embedder():
     table = pw.debug.table_from_rows(
         schema=pw.schema_from_types(text=str), rows=[("aaa",), ("bbb",)]
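The reverted tests call the UDF body directly instead of running a Pathway graph: `_coerce_sync` wraps the embedder's async function so a single string can be embedded synchronously. A sketch of that pattern, following the test code above (assumes an OpenAI API key is configured; the dimension check mirrors the tests, since `text-embedding-3-small` returns 1536-dimensional vectors):

```python
import pathway as pw
from pathway.internals.udfs.utils import _coerce_sync
from pathway.xpacks.llm import embedders

embedder = embedders.OpenAIEmbedder(
    model="text-embedding-3-small",
    retry_strategy=pw.udfs.ExponentialBackoffRetryStrategy(),
)

# _coerce_sync turns the embedder's async UDF body into a plain callable,
# so one string can be embedded without building a dataflow graph.
sync_embedder = _coerce_sync(embedder.func)
embedding = sync_embedder("Pathway is a data processing framework.")

assert len(embedding) > 1500
```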

python/pathway/xpacks/llm/embedders.py

Lines changed: 23 additions & 84 deletions
@@ -4,7 +4,7 @@
 """
 import asyncio
 import logging
-from typing import Any, Literal
+from typing import Literal
 
 import numpy as np
 
@@ -85,23 +85,6 @@ def __call__(
         return super().__call__(input, *args, **kwargs)
 
 
-def _split_batched_kwargs(
-    kwargs: dict[str, list[Any]]
-) -> tuple[dict[str, Any], dict[str, list[Any]]]:
-    constant_kwargs = {}
-    per_row_kwargs = {}
-
-    if kwargs:
-        for key, values in kwargs.items():
-            v = values[0]
-            if all(value == v for value in values):
-                constant_kwargs[key] = v
-            else:
-                per_row_kwargs[key] = values
-
-    return constant_kwargs, per_row_kwargs
-
-
 class OpenAIEmbedder(BaseEmbedder):
     """Pathway wrapper for OpenAI Embedding services.
 
@@ -130,8 +113,6 @@ class OpenAIEmbedder(BaseEmbedder):
            Can be ``"start"``, ``"end"`` or ``None``. ``"start"`` will keep the first part of the text
            and remove the rest. ``"end"`` will keep the last part of the text.
            If `None`, no truncation will be applied to any of the documents, this may cause API exceptions.
-        batch_size: maximum size of a single batch to be sent to the embedder. Bigger
-            batches may reduce the time needed for embedding.
         encoding_format: The format to return the embeddings in. Can be either `float` or
            `base64 <https://pypi.org/project/pybase64/>`_.
         user: A unique identifier representing your end-user, which can help OpenAI to monitor
@@ -176,7 +157,6 @@ def __init__(
         cache_strategy: udfs.CacheStrategy | None = None,
         model: str | None = "text-embedding-3-small",
         truncation_keep_strategy: Literal["start", "end"] | None = "start",
-        batch_size: int = 128,
         **openai_kwargs,
     ):
         with optional_imports("xpack-llm"):
@@ -185,7 +165,8 @@ def __init__(
         _monkeypatch_openai_async()
         executor = udfs.async_executor(capacity=capacity, retry_strategy=retry_strategy)
         super().__init__(
-            executor=executor, cache_strategy=cache_strategy, max_batch_size=batch_size
+            executor=executor,
+            cache_strategy=cache_strategy,
         )
         self.truncation_keep_strategy = truncation_keep_strategy
         self.kwargs = dict(openai_kwargs)
@@ -194,64 +175,32 @@ def __init__(
         if model is not None:
             self.kwargs["model"] = model
 
-    async def __wrapped__(self, inputs: list[str], **kwargs) -> list[np.ndarray]:
+    async def __wrapped__(self, input, **kwargs) -> np.ndarray:
         """Embed the documents
 
         Args:
-            inputs: mandatory, the strings to embed.
+            input: mandatory, the string to embed.
             **kwargs: optional parameters, if unset defaults from the constructor
                 will be taken.
         """
+        input = input or "."
 
+        kwargs = {**self.kwargs, **kwargs}
         kwargs = _extract_value_inside_dict(kwargs)
 
-        if kwargs.get("model") is None and self.kwargs.get("model") is None:
+        if kwargs.get("model") is None:
            raise ValueError(
                "`model` parameter is missing in `OpenAIEmbedder`. "
                "Please provide the model name either in the constructor or in the function call."
            )
 
-        constant_kwargs, per_row_kwargs = _split_batched_kwargs(kwargs)
-        constant_kwargs = {**self.kwargs, **constant_kwargs}
-
         if self.truncation_keep_strategy:
-            if "model" in per_row_kwargs:
-                inputs = [
-                    self.truncate_context(model, input, self.truncation_keep_strategy)
-                    for (model, input) in zip(per_row_kwargs["model"], inputs)
-                ]
-            else:
-                inputs = [
-                    self.truncate_context(
-                        constant_kwargs["model"], input, self.truncation_keep_strategy
-                    )
-                    for input in inputs
-                ]
-
-        # if kwargs are not the same for every input we cannot batch them
-        if per_row_kwargs:
-
-            async def embed_single(input, kwargs) -> np.ndarray:
-                kwargs = {**constant_kwargs, **kwargs}
-                ret = await self.client.embeddings.create(input=[input], **kwargs)
-                return np.array(ret.data[0].embedding)
-
-            list_of_per_row_kwargs = [
-                dict(zip(per_row_kwargs, values))
-                for values in zip(*per_row_kwargs.values())
-            ]
-            async with asyncio.TaskGroup() as tg:
-                tasks = [
-                    tg.create_task(embed_single(input, kwargs))
-                    for input, kwargs in zip(inputs, list_of_per_row_kwargs)
-                ]
-
-            result_list = [task.result() for task in tasks]
-            return result_list
+            input = self.truncate_context(
+                kwargs["model"], input, self.truncation_keep_strategy
+            )
 
-        else:
-            ret = await self.client.embeddings.create(input=inputs, **constant_kwargs)
-            return [np.array(datum.embedding) for datum in ret.data]
+        ret = await self.client.embeddings.create(input=[input], **kwargs)
+        return np.array(ret.data[0].embedding)
 
     @staticmethod
     def truncate_context(
@@ -298,25 +247,6 @@ def truncate_context(
 
         return tokenizer.decode(tokens)
 
-    @staticmethod
-    def _count_tokens(text: str, model: str) -> int:
-        with optional_imports("xpack-llm"):
-            import tiktoken
-
-        tokenizer = tiktoken.encoding_for_model(model)
-        tokens = tokenizer.encode(text)
-        return len(tokens)
-
-    def get_embedding_dimension(self, **kwargs):
-        """Computes number of embedder's dimensions by asking the embedder to embed ``"."``.
-
-        Args:
-            **kwargs: parameters of the embedder, if unset defaults from the constructor
-                will be taken.
-        """
-        kwargs_as_list = {k: [v] for k, v in kwargs.items()}
-        return len(_coerce_sync(self.__wrapped__)(["."], **kwargs_as_list)[0])
-
 
 class LiteLLMEmbedder(BaseEmbedder):
     """Pathway wrapper for `litellm.embedding`.
@@ -470,7 +400,16 @@ def __wrapped__(self, input: list[str], **kwargs) -> list[np.ndarray]:
         """ # noqa: E501
 
         kwargs = _extract_value_inside_dict(kwargs)
-        constant_kwargs, per_row_kwargs = _split_batched_kwargs(kwargs)
+        constant_kwargs = {}
+        per_row_kwargs = {}
+
+        if kwargs:
+            for key, values in kwargs.items():
+                v = values[0]
+                if all(value == v for value in values):
+                    constant_kwargs[key] = v
+                else:
+                    per_row_kwargs[key] = values
 
         # if kwargs are not the same for every input we cannot batch them
         if per_row_kwargs:
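With the shared helper gone, `LiteLLMEmbedder.__wrapped__` carries its own copy of the kwargs partition: a column-shaped kwarg whose value is identical across the batch can be sent once for the whole batch, while any kwarg that varies forces per-row requests. A self-contained sketch of the same partition (the standalone `split_batched_kwargs` name is illustrative; the reverted code inlines this logic):

```python
from typing import Any


def split_batched_kwargs(
    kwargs: dict[str, list[Any]]
) -> tuple[dict[str, Any], dict[str, list[Any]]]:
    # Kwargs identical for every row can be sent once for the whole batch;
    # any kwarg that differs between rows is kept as a per-row list.
    constant_kwargs: dict[str, Any] = {}
    per_row_kwargs: dict[str, list[Any]] = {}
    for key, values in kwargs.items():
        v = values[0]
        if all(value == v for value in values):
            constant_kwargs[key] = v
        else:
            per_row_kwargs[key] = values
    return constant_kwargs, per_row_kwargs


constant, per_row = split_batched_kwargs(
    {"model": ["m1", "m1"], "dimensions": [300, 800]}
)
assert constant == {"model": "m1"}
assert per_row == {"dimensions": [300, 800]}
```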
