4
4
"""
5
5
import asyncio
6
6
import logging
7
- from typing import Any , Literal
7
+ from typing import Literal
8
8
9
9
import numpy as np
10
10
@@ -85,23 +85,6 @@ def __call__(
85
85
return super ().__call__ (input , * args , ** kwargs )
86
86
87
87
88
- def _split_batched_kwargs (
89
- kwargs : dict [str , list [Any ]]
90
- ) -> tuple [dict [str , Any ], dict [str , list [Any ]]]:
91
- constant_kwargs = {}
92
- per_row_kwargs = {}
93
-
94
- if kwargs :
95
- for key , values in kwargs .items ():
96
- v = values [0 ]
97
- if all (value == v for value in values ):
98
- constant_kwargs [key ] = v
99
- else :
100
- per_row_kwargs [key ] = values
101
-
102
- return constant_kwargs , per_row_kwargs
103
-
104
-
105
88
class OpenAIEmbedder (BaseEmbedder ):
106
89
"""Pathway wrapper for OpenAI Embedding services.
107
90
@@ -130,8 +113,6 @@ class OpenAIEmbedder(BaseEmbedder):
130
113
Can be ``"start"``, ``"end"`` or ``None``. ``"start"`` will keep the first part of the text
131
114
and remove the rest. ``"end"`` will keep the last part of the text.
132
115
If `None`, no truncation will be applied to any of the documents, this may cause API exceptions.
133
- batch_size: maximum size of a single batch to be sent to the embedder. Bigger
134
- batches may reduce the time needed for embedding.
135
116
encoding_format: The format to return the embeddings in. Can be either `float` or
136
117
`base64 <https://pypi.org/project/pybase64/>`_.
137
118
user: A unique identifier representing your end-user, which can help OpenAI to monitor
@@ -176,7 +157,6 @@ def __init__(
176
157
cache_strategy : udfs .CacheStrategy | None = None ,
177
158
model : str | None = "text-embedding-3-small" ,
178
159
truncation_keep_strategy : Literal ["start" , "end" ] | None = "start" ,
179
- batch_size : int = 128 ,
180
160
** openai_kwargs ,
181
161
):
182
162
with optional_imports ("xpack-llm" ):
@@ -185,7 +165,8 @@ def __init__(
185
165
_monkeypatch_openai_async ()
186
166
executor = udfs .async_executor (capacity = capacity , retry_strategy = retry_strategy )
187
167
super ().__init__ (
188
- executor = executor , cache_strategy = cache_strategy , max_batch_size = batch_size
168
+ executor = executor ,
169
+ cache_strategy = cache_strategy ,
189
170
)
190
171
self .truncation_keep_strategy = truncation_keep_strategy
191
172
self .kwargs = dict (openai_kwargs )
@@ -194,64 +175,32 @@ def __init__(
194
175
if model is not None :
195
176
self .kwargs ["model" ] = model
196
177
197
- async def __wrapped__ (self , inputs : list [ str ] , ** kwargs ) -> list [ np .ndarray ] :
178
+ async def __wrapped__ (self , input , ** kwargs ) -> np .ndarray :
198
179
"""Embed the documents
199
180
200
181
Args:
201
- inputs : mandatory, the strings to embed.
182
+ input : mandatory, the string to embed.
202
183
**kwargs: optional parameters, if unset defaults from the constructor
203
184
will be taken.
204
185
"""
186
+ input = input or "."
205
187
188
+ kwargs = {** self .kwargs , ** kwargs }
206
189
kwargs = _extract_value_inside_dict (kwargs )
207
190
208
- if kwargs .get ("model" ) is None and self . kwargs . get ( "model" ) is None :
191
+ if kwargs .get ("model" ) is None :
209
192
raise ValueError (
210
193
"`model` parameter is missing in `OpenAIEmbedder`. "
211
194
"Please provide the model name either in the constructor or in the function call."
212
195
)
213
196
214
- constant_kwargs , per_row_kwargs = _split_batched_kwargs (kwargs )
215
- constant_kwargs = {** self .kwargs , ** constant_kwargs }
216
-
217
197
if self .truncation_keep_strategy :
218
- if "model" in per_row_kwargs :
219
- inputs = [
220
- self .truncate_context (model , input , self .truncation_keep_strategy )
221
- for (model , input ) in zip (per_row_kwargs ["model" ], inputs )
222
- ]
223
- else :
224
- inputs = [
225
- self .truncate_context (
226
- constant_kwargs ["model" ], input , self .truncation_keep_strategy
227
- )
228
- for input in inputs
229
- ]
230
-
231
- # if kwargs are not the same for every input we cannot batch them
232
- if per_row_kwargs :
233
-
234
- async def embed_single (input , kwargs ) -> np .ndarray :
235
- kwargs = {** constant_kwargs , ** kwargs }
236
- ret = await self .client .embeddings .create (input = [input ], ** kwargs )
237
- return np .array (ret .data [0 ].embedding )
238
-
239
- list_of_per_row_kwargs = [
240
- dict (zip (per_row_kwargs , values ))
241
- for values in zip (* per_row_kwargs .values ())
242
- ]
243
- async with asyncio .TaskGroup () as tg :
244
- tasks = [
245
- tg .create_task (embed_single (input , kwargs ))
246
- for input , kwargs in zip (inputs , list_of_per_row_kwargs )
247
- ]
248
-
249
- result_list = [task .result () for task in tasks ]
250
- return result_list
198
+ input = self .truncate_context (
199
+ kwargs ["model" ], input , self .truncation_keep_strategy
200
+ )
251
201
252
- else :
253
- ret = await self .client .embeddings .create (input = inputs , ** constant_kwargs )
254
- return [np .array (datum .embedding ) for datum in ret .data ]
202
+ ret = await self .client .embeddings .create (input = [input ], ** kwargs )
203
+ return np .array (ret .data [0 ].embedding )
255
204
256
205
@staticmethod
257
206
def truncate_context (
@@ -298,25 +247,6 @@ def truncate_context(
298
247
299
248
return tokenizer .decode (tokens )
300
249
301
- @staticmethod
302
- def _count_tokens (text : str , model : str ) -> int :
303
- with optional_imports ("xpack-llm" ):
304
- import tiktoken
305
-
306
- tokenizer = tiktoken .encoding_for_model (model )
307
- tokens = tokenizer .encode (text )
308
- return len (tokens )
309
-
310
- def get_embedding_dimension (self , ** kwargs ):
311
- """Computes number of embedder's dimensions by asking the embedder to embed ``"."``.
312
-
313
- Args:
314
- **kwargs: parameters of the embedder, if unset defaults from the constructor
315
- will be taken.
316
- """
317
- kwargs_as_list = {k : [v ] for k , v in kwargs .items ()}
318
- return len (_coerce_sync (self .__wrapped__ )(["." ], ** kwargs_as_list )[0 ])
319
-
320
250
321
251
class LiteLLMEmbedder (BaseEmbedder ):
322
252
"""Pathway wrapper for `litellm.embedding`.
@@ -470,7 +400,16 @@ def __wrapped__(self, input: list[str], **kwargs) -> list[np.ndarray]:
470
400
""" # noqa: E501
471
401
472
402
kwargs = _extract_value_inside_dict (kwargs )
473
- constant_kwargs , per_row_kwargs = _split_batched_kwargs (kwargs )
403
+ constant_kwargs = {}
404
+ per_row_kwargs = {}
405
+
406
+ if kwargs :
407
+ for key , values in kwargs .items ():
408
+ v = values [0 ]
409
+ if all (value == v for value in values ):
410
+ constant_kwargs [key ] = v
411
+ else :
412
+ per_row_kwargs [key ] = values
474
413
475
414
# if kwargs are not the same for every input we cannot batch them
476
415
if per_row_kwargs :
0 commit comments