Skip to content

Commit d038d95

Browse files
committed
feat(demohouse/shopping): mock vdb
1 parent bf2932c commit d038d95

File tree

4 files changed

+3043
-0
lines changed

4 files changed

+3043
-0
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import json
2+
import logging
3+
import os
4+
from typing import AsyncIterable
5+
from arkitect.launcher.local.serve import launch_serve
6+
from arkitect.telemetry.trace import task
7+
from arkitect.types.llm.model import ArkChatRequest, ArkChatParameters
8+
from volcenginesdkarkruntime.types.chat import (
9+
ChatCompletionChunk,
10+
ChatCompletionContentPartImageParam,
11+
)
12+
13+
from arkitect.core.component.context.context import Context
14+
15+
from arkitect.core.runtime import Response
16+
from volcenginesdkarkruntime.types.chat.chat_completion_chunk import Choice, ChoiceDelta
17+
18+
from arkitect.core.component.context.model import State
19+
20+
from arkitect.types.llm.model import ChatCompletionMessageToolCallParam
21+
from vdb import vector_search
22+
23+
# Logger named after this module.
logger = logging.getLogger(__name__)

# Model/endpoint identifier handed to the chat Context as `model=`.
DOUBAO_VLM_ENDPOINT = "doubao-1-5-pro-32k-250115"
26+
27+
28+
@task()
async def default_model_calling(
    request: ArkChatRequest,
) -> AsyncIterable[ChatCompletionChunk]:
    """Stream a chat completion that may call the ``vector_search`` tool.

    The URL of the last image part found in ``request.messages`` is captured
    and written into the ``image_url`` argument of every ``vector_search``
    tool call through a tool hook, overriding whatever the model produced.

    Args:
        request: Incoming chat request. Its messages are forwarded to the
            model and its fields are used to build the chat parameters.

    Yields:
        Every chunk of the completion stream. The first chunk that follows a
        chunk with ``finish_reason == "tool_calls"`` additionally carries one
        extra choice whose delta content is the content of the context's
        latest message (presumably the tool result -- confirm against the
        Context implementation).
    """
    parameters = ArkChatParameters(**request.__dict__)

    # NOTE(review): ChatCompletionContentPartImageParam appears to be a
    # TypedDict in the SDK, and isinstance() against a TypedDict raises
    # TypeError. Image parts are therefore recognised structurally instead.
    image_urls = []
    for message in request.messages:
        if not isinstance(message.content, list):
            continue
        for part in message.content:
            if isinstance(part, dict) and part.get("type") == "image_url":
                image_urls.append((part.get("image_url") or {}).get("url", ""))
    # Only the most recent image is used; empty string means "no image".
    image_url = image_urls[-1] if image_urls else ""

    async def modify_url_hook(
        state: State, param: ChatCompletionMessageToolCallParam
    ) -> ChatCompletionMessageToolCallParam:
        """Overwrite the tool call's ``image_url`` argument with the real URL."""
        # Treat missing/empty arguments as an empty JSON object instead of
        # letting json.loads fail on an empty string.
        arguments = json.loads(param["function"]["arguments"] or "{}")
        arguments["image_url"] = image_url
        param["function"]["arguments"] = json.dumps(arguments)
        return param

    async with Context(
        model=DOUBAO_VLM_ENDPOINT, tools=[vector_search], parameters=parameters
    ) as ctx:
        ctx.tool_hooks.update(vector_search=[modify_url_hook])
        stream = await ctx.completions.create(
            messages=[m.model_dump() for m in request.messages], stream=True
        )
        # True between a "tool_calls" finish and the next chunk, which is the
        # one that gets the tool result attached.
        tool_call = False
        async for chunk in stream:
            if tool_call:
                tool_result = ctx.get_latest_message()
                chunk.choices.append(
                    Choice(
                        role="tool",
                        delta=ChoiceDelta(content=tool_result.get("content")),
                        index=len(chunk.choices),
                    )
                )
                # Attach the result once per tool round; without this reset
                # every later chunk would repeat the same tool result.
                tool_call = False
            yield chunk
            if chunk.choices and chunk.choices[0].finish_reason == "tool_calls":
                tool_call = True
71+
72+
73+
@task()
async def main(request: ArkChatRequest) -> AsyncIterable[Response]:
    """Relay, unchanged, every item produced by ``default_model_calling``.

    Args:
        request: The incoming chat request, passed through as-is.

    Yields:
        Each item yielded by ``default_model_calling(request)``, in order.
    """
    async for item in default_model_calling(request):
        yield item
77+
78+
79+
if __name__ == "__main__":
    # Use the port from _FAAS_RUNTIME_PORT when it is set and non-empty;
    # otherwise fall back to 8888.
    env_port = os.environ.get("_FAAS_RUNTIME_PORT")
    launch_serve(
        package_path="main",
        port=int(env_port or 8888),
        health_check_path="/v1/ping",
        endpoint_path="/api/v3/bots/chat/completions",
    )
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import os
2+
3+
import tos
4+
5+
from httpx import Timeout
6+
from tos import HttpMethodType
7+
from volcenginesdkarkruntime import AsyncArk
8+
from volcenginesdkarkruntime.types.multimodal_embedding import (
9+
MultimodalEmbeddingContentPartTextParam,
10+
MultimodalEmbeddingResponse,
11+
MultimodalEmbeddingContentPartImageParam,
12+
)
13+
from volcengine.viking_db import *
14+
from volcenginesdkarkruntime.types.multimodal_embedding.embedding_content_part_image_param import (
15+
ImageURL,
16+
)
17+
18+
19+
# VikingDB collection and index that vector_search queries.
COLLECTION_NAME = "shopping_demo"
INDEX_NAME = "shopping_demo"
# Model name passed to the multimodal embedding API.
MODEL_NAME = "doubao-embedding-vision-241215"
# Maximum number of hits requested per vector search.
LIMIT = 6
# A hit is kept only when its score is strictly greater than this value.
SCORE_THRESHOLD = 300

# VikingDB client. Credentials are read from the VOLC_ACCESSKEY and
# VOLC_SECRETKEY environment variables (None is passed when they are unset).
vikingdb_service = VikingDBService(
    host="api-vikingdb.volces.com",
    region="cn-beijing",
    scheme="https",
    connection_timeout=30,
    socket_timeout=30,
)
vikingdb_service.set_ak(os.getenv("VOLC_ACCESSKEY"))
vikingdb_service.set_sk(os.getenv("VOLC_SECRETKEY"))

# TOS client built from the same credentials; vector_search uses it to
# pre-sign object URLs.
tos_client = tos.TosClientV2(
    os.environ.get("VOLC_ACCESSKEY"),
    os.environ.get("VOLC_SECRETKEY"),
    "tos-cn-beijing.volces.com",
    "cn-beijing",
)
41+
42+
43+
async def vector_search(text: str, image_url: str) -> str:
    """获取商品相关信息,当想要了解商品信息,比如价格,详细介绍,销量,评价时调用该工具

    Args:
        text: 商品的描述信息
        image_url: 固定填写为<image_url>
    """
    # NOTE(review): the docstring above is kept verbatim on purpose. This
    # function is registered as a model tool, so its docstring is presumably
    # used at runtime as the tool description -- confirm before rewording.
    # In English: "Get product information; call this tool to learn about a
    # product's price, details, sales or reviews. text: description of the
    # product. image_url: always fill in the literal <image_url>."
    #
    # Flow: embed the text (plus the image when a URL is given), search the
    # VikingDB index with that vector, keep hits scoring above
    # SCORE_THRESHOLD, and return them as a JSON array string in which each
    # item carries a pre-signed GET URL valid for 600 seconds.

    # Imported here because this module never imports json explicitly; it
    # would otherwise only be available if the star import happened to
    # re-export it.
    import json

    client = AsyncArk(timeout=Timeout(connect=1.0, timeout=60.0))
    embedding_input = [MultimodalEmbeddingContentPartTextParam(type="text", text=text)]
    # Truthiness also covers None, not just the empty string.
    if image_url:
        embedding_input.append(
            MultimodalEmbeddingContentPartImageParam(
                type="image_url", image_url=ImageURL(url=image_url)
            )
        )
    resp: MultimodalEmbeddingResponse = await client.multimodal_embeddings.create(
        model=MODEL_NAME,
        input=embedding_input,
    )
    embedding = resp.data.get("embedding", [])
    index = await vikingdb_service.async_get_index(COLLECTION_NAME, INDEX_NAME)
    retrieve = await index.async_search_by_vector(vector=embedding, limit=LIMIT)

    retrieve_fields = []
    for result in retrieve:
        if result.score <= SCORE_THRESHOLD:
            continue
        raw = result.fields.get("data")
        # Skip hits that have no payload instead of crashing in json.loads.
        if not raw:
            continue
        retrieve_fields.append(json.loads(raw))

    mock_data = [
        {
            "名称": item.get("Name", ""),
            "类别": item.get("category", ""),
            "子类别": item.get("sub_category", ""),
            # Price and sales fall back to placeholder values when absent.
            "价格": item.get("price", "99"),
            "销量": item.get("sales", "999"),
            "商品链接": tos_client.pre_signed_url(
                http_method=HttpMethodType.Http_Method_Get,
                bucket="shopping",
                key=item.get("key", ""),
                expires=600,
            ).signed_url,
        }
        for item in retrieve_fields
    ]
    # ensure_ascii=False keeps the Chinese keys and values readable.
    return json.dumps(mock_data, ensure_ascii=False)

0 commit comments

Comments
 (0)