Skip to content

Commit fd0c122

Browse files
authored
feat(storage): add vector and analytics buckets support (#1318)
1 parent adffee4 commit fd0c122

File tree

12 files changed

+794
-15
lines changed

12 files changed

+794
-15
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ help::
3232

3333
mypy: $(call FORALL_PKGS,mypy)
3434
help::
35-
@echo " mypy -- Run mypy on all files"
35+
@echo " mypy -- Run mypy on all files"
3636

3737
ruff:
3838
@uv run ruff check --fix

src/storage/run-unasync.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@
55
paths = Path("src/storage3").glob("**/*.py")
66
tests = Path("tests").glob("**/*.py")
77

8-
rules = (unasync._DEFAULT_RULE,)
8+
rules = (
9+
unasync.Rule(
10+
fromdir="/_async/",
11+
todir="/_sync/",
12+
additional_replacements={"AsyncClient": "Client"},
13+
),
14+
unasync._DEFAULT_RULE,
15+
)
16+
917

1018
files = [str(p) for p in list(paths) + list(tests)]
1119

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from typing import List, Optional
2+
3+
from httpx import QueryParams
4+
5+
from ..types import (
6+
AnalyticsBucket,
7+
AnalyticsBucketDeleteResponse,
8+
AnalyticsBucketsParser,
9+
SortColumn,
10+
SortOrder,
11+
)
12+
from .request import AsyncRequestBuilder
13+
14+
15+
class AsyncStorageAnalyticsClient:
16+
def __init__(self, request: AsyncRequestBuilder) -> None:
17+
self._request = request
18+
19+
async def create(self, bucket_name: str) -> AnalyticsBucket:
20+
body = {"name": bucket_name}
21+
data = await self._request.send(http_method="POST", path=["bucket"], body=body)
22+
return AnalyticsBucket.model_validate_json(data.content)
23+
24+
async def list(
25+
self,
26+
limit: Optional[int] = None,
27+
offset: Optional[int] = None,
28+
sort_column: Optional[SortColumn] = None,
29+
sort_order: Optional[SortOrder] = None,
30+
search: Optional[str] = None,
31+
) -> List[AnalyticsBucket]:
32+
params = dict(
33+
limit=limit,
34+
offset=offset,
35+
sort_column=sort_column,
36+
sort_order=sort_order,
37+
search=search,
38+
)
39+
filtered_params = QueryParams(
40+
**{k: v for k, v in params.items() if v is not None}
41+
)
42+
data = await self._request.send(
43+
http_method="GET", path=["bucket"], query_params=filtered_params
44+
)
45+
return AnalyticsBucketsParser.validate_json(data.content)
46+
47+
async def delete(self, bucket_name: str) -> AnalyticsBucketDeleteResponse:
48+
data = await self._request.send(
49+
http_method="DELETE", path=["bucket", bucket_name]
50+
)
51+
return AnalyticsBucketDeleteResponse.model_validate_json(data.content)

src/storage/src/storage3/_async/client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88
from storage3.constants import DEFAULT_TIMEOUT
99

1010
from ..version import __version__
11+
from .analytics import AsyncStorageAnalyticsClient
1112
from .bucket import AsyncStorageBucketAPI
1213
from .file_api import AsyncBucketProxy
14+
from .request import AsyncRequestBuilder
15+
from .vectors import AsyncStorageVectorsClient
1316

1417
__all__ = [
1518
"AsyncStorageClient",
@@ -80,3 +83,18 @@ def from_(self, id: str) -> AsyncBucketProxy:
8083
The unique identifier of the bucket
8184
"""
8285
return AsyncBucketProxy(id, self._base_url, self._headers, self._client)
86+
87+
def vectors(self) -> AsyncStorageVectorsClient:
88+
return AsyncStorageVectorsClient(
89+
url=self._base_url.joinpath("v1", "vector"),
90+
headers=self._headers,
91+
session=self.session,
92+
)
93+
94+
def analytics(self) -> AsyncStorageAnalyticsClient:
95+
request = AsyncRequestBuilder(
96+
session=self.session,
97+
headers=self._headers,
98+
base_url=self._base_url.joinpath("v1", "iceberg"),
99+
)
100+
return AsyncStorageAnalyticsClient(request=request)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Optional
2+
3+
from httpx import AsyncClient, Headers, HTTPStatusError, QueryParams, Response
4+
from pydantic import ValidationError
5+
from yarl import URL
6+
7+
from ..exceptions import StorageApiError, VectorBucketErrorMessage
8+
from ..types import JSON, RequestMethod
9+
10+
11+
class AsyncRequestBuilder:
12+
def __init__(self, session: AsyncClient, base_url: URL, headers: Headers) -> None:
13+
self._session = session
14+
self._base_url = base_url
15+
self.headers = headers
16+
17+
async def send(
18+
self,
19+
http_method: RequestMethod,
20+
path: list[str],
21+
body: JSON = None,
22+
query_params: Optional[QueryParams] = None,
23+
) -> Response:
24+
response = await self._session.request(
25+
method=http_method,
26+
json=body,
27+
url=str(self._base_url.joinpath(*path)),
28+
headers=self.headers,
29+
params=query_params or QueryParams(),
30+
)
31+
try:
32+
response.raise_for_status()
33+
return response
34+
except HTTPStatusError as exc:
35+
try:
36+
error = VectorBucketErrorMessage.model_validate_json(response.content)
37+
raise StorageApiError(
38+
message=error.message,
39+
code=error.code or "400",
40+
status=error.statusCode,
41+
) from exc
42+
except ValidationError as exc:
43+
raise StorageApiError(
44+
message=f"The request failed, but could not parse error message response:'{response.text}'",
45+
code="LibraryError",
46+
status=response.status_code,
47+
) from exc
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
from __future__ import annotations
2+
3+
from typing import List, Optional
4+
5+
from httpx import AsyncClient, Headers
6+
from yarl import URL
7+
8+
from ..exceptions import StorageApiError, VectorBucketException
9+
from ..types import (
10+
JSON,
11+
DistanceMetric,
12+
GetVectorBucketResponse,
13+
GetVectorIndexResponse,
14+
GetVectorsResponse,
15+
ListVectorBucketsResponse,
16+
ListVectorIndexesResponse,
17+
ListVectorsResponse,
18+
MetadataConfiguration,
19+
QueryVectorsResponse,
20+
VectorData,
21+
VectorFilter,
22+
VectorObject,
23+
)
24+
from .request import AsyncRequestBuilder
25+
26+
27+
# used to not send non-required values as `null`
28+
# for they cannot be null
29+
def remove_none(**kwargs: JSON) -> JSON:
30+
return {key: val for key, val in kwargs.items() if val is not None}
31+
32+
33+
class AsyncVectorBucketScope:
34+
def __init__(self, request: AsyncRequestBuilder, bucket_name: str) -> None:
35+
self._request = request
36+
self._bucket_name = bucket_name
37+
38+
def with_metadata(self, **data: JSON) -> JSON:
39+
return remove_none(vectorBucketName=self._bucket_name, **data)
40+
41+
async def create_index(
42+
self,
43+
index_name: str,
44+
dimension: int,
45+
distance_metric: DistanceMetric,
46+
data_type: str,
47+
metadata: Optional[MetadataConfiguration] = None,
48+
) -> None:
49+
body = self.with_metadata(
50+
indexName=index_name,
51+
dimension=dimension,
52+
distanceMetric=distance_metric,
53+
dataType=data_type,
54+
metadataConfiguration=metadata.model_dump(by_alias=True)
55+
if metadata
56+
else None,
57+
)
58+
await self._request.send(http_method="POST", path=["CreateIndex"], body=body)
59+
60+
async def get_index(self, index_name: str) -> Optional[GetVectorIndexResponse]:
61+
body = self.with_metadata(indexName=index_name)
62+
try:
63+
data = await self._request.send(
64+
http_method="POST", path=["GetIndex"], body=body
65+
)
66+
return GetVectorIndexResponse.model_validate_json(data.content)
67+
except StorageApiError:
68+
return None
69+
70+
async def list_indexes(
71+
self,
72+
next_token: Optional[str] = None,
73+
max_results: Optional[int] = None,
74+
prefix: Optional[str] = None,
75+
) -> ListVectorIndexesResponse:
76+
body = self.with_metadata(
77+
next_token=next_token, max_results=max_results, prefix=prefix
78+
)
79+
data = await self._request.send(
80+
http_method="POST", path=["ListIndexes"], body=body
81+
)
82+
return ListVectorIndexesResponse.model_validate_json(data.content)
83+
84+
async def delete_index(self, index_name: str) -> None:
85+
body = self.with_metadata(indexName=index_name)
86+
await self._request.send(http_method="POST", path=["DeleteIndex"], body=body)
87+
88+
def index(self, index_name: str) -> AsyncVectorIndexScope:
89+
return AsyncVectorIndexScope(self._request, self._bucket_name, index_name)
90+
91+
92+
class AsyncVectorIndexScope:
93+
def __init__(
94+
self, request: AsyncRequestBuilder, bucket_name: str, index_name: str
95+
) -> None:
96+
self._request = request
97+
self._bucket_name = bucket_name
98+
self._index_name = index_name
99+
100+
def with_metadata(self, **data: JSON) -> JSON:
101+
return remove_none(
102+
vectorBucketName=self._bucket_name,
103+
indexName=self._index_name,
104+
**data,
105+
)
106+
107+
async def put(self, vectors: List[VectorObject]) -> None:
108+
body = self.with_metadata(
109+
vectors=[v.model_dump(exclude_none=True) for v in vectors]
110+
)
111+
await self._request.send(http_method="POST", path=["PutVectors"], body=body)
112+
113+
async def get(
114+
self, *keys: str, return_data: bool = True, return_metadata: bool = True
115+
) -> GetVectorsResponse:
116+
body = self.with_metadata(
117+
keys=keys, returnData=return_data, returnMetadata=return_metadata
118+
)
119+
data = await self._request.send(
120+
http_method="POST", path=["GetVectors"], body=body
121+
)
122+
return GetVectorsResponse.model_validate_json(data.content)
123+
124+
async def list(
125+
self,
126+
max_results: Optional[int] = None,
127+
next_token: Optional[str] = None,
128+
return_data: bool = True,
129+
return_metadata: bool = True,
130+
segment_count: Optional[int] = None,
131+
segment_index: Optional[int] = None,
132+
) -> ListVectorsResponse:
133+
body = self.with_metadata(
134+
maxResults=max_results,
135+
nextToken=next_token,
136+
returnData=return_data,
137+
returnMetadata=return_metadata,
138+
segmentCount=segment_count,
139+
segmentIndex=segment_index,
140+
)
141+
data = await self._request.send(
142+
http_method="POST", path=["ListVectors"], body=body
143+
)
144+
return ListVectorsResponse.model_validate_json(data.content)
145+
146+
async def query(
147+
self,
148+
query_vector: VectorData,
149+
topK: Optional[int] = None,
150+
filter: Optional[VectorFilter] = None,
151+
return_distance: bool = True,
152+
return_metadata: bool = True,
153+
) -> QueryVectorsResponse:
154+
body = self.with_metadata(
155+
queryVector=dict(query_vector),
156+
topK=topK,
157+
filter=filter,
158+
returnDistance=return_distance,
159+
returnMetadata=return_metadata,
160+
)
161+
data = await self._request.send(
162+
http_method="POST", path=["QueryVectors"], body=body
163+
)
164+
return QueryVectorsResponse.model_validate_json(data.content)
165+
166+
async def delete(self, keys: List[str]) -> None:
167+
if len(keys) < 1 or len(keys) > 500:
168+
raise VectorBucketException("Keys batch size must be between 1 and 500.")
169+
body = self.with_metadata(keys=keys)
170+
await self._request.send(http_method="POST", path=["DeleteVectors"], body=body)
171+
172+
173+
class AsyncStorageVectorsClient:
174+
def __init__(self, url: URL, headers: Headers, session: AsyncClient) -> None:
175+
self._request = AsyncRequestBuilder(session, base_url=URL(url), headers=headers)
176+
177+
def from_(self, bucket_name: str) -> AsyncVectorBucketScope:
178+
return AsyncVectorBucketScope(self._request, bucket_name)
179+
180+
async def create_bucket(self, bucket_name: str) -> None:
181+
body = {"vectorBucketName": bucket_name}
182+
await self._request.send(
183+
http_method="POST", path=["CreateVectorBucket"], body=body
184+
)
185+
186+
async def get_bucket(self, bucket_name: str) -> Optional[GetVectorBucketResponse]:
187+
body = {"vectorBucketName": bucket_name}
188+
try:
189+
data = await self._request.send(
190+
http_method="POST", path=["GetVectorBucket"], body=body
191+
)
192+
return GetVectorBucketResponse.model_validate_json(data.content)
193+
except StorageApiError:
194+
return None
195+
196+
async def list_buckets(
197+
self,
198+
prefix: Optional[str] = None,
199+
max_results: Optional[int] = None,
200+
next_token: Optional[str] = None,
201+
) -> ListVectorBucketsResponse:
202+
body = remove_none(prefix=prefix, maxResults=max_results, nextToken=next_token)
203+
data = await self._request.send(
204+
http_method="POST", path=["ListVectorBuckets"], body=body
205+
)
206+
return ListVectorBucketsResponse.model_validate_json(data.content)
207+
208+
async def delete_bucket(self, bucket_name: str) -> None:
209+
body = {"vectorBucketName": bucket_name}
210+
await self._request.send(
211+
http_method="POST", path=["DeleteVectorBucket"], body=body
212+
)

0 commit comments

Comments
 (0)