Skip to content

Commit 767d747

Browse files
authored
support sse for vertex (#49)
1 parent 0a5959c commit 767d747

1 file changed

Lines changed: 55 additions & 7 deletions

File tree

app/services/providers/vertex_adapter.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
33
import time
4+
import uuid
45
from collections.abc import AsyncGenerator
56
from datetime import datetime, timezone
67
from typing import Any
@@ -236,15 +237,62 @@ def error_handler(error_text: str, http_status: int):
236237
# Vertex doesn't do actual streaming; it just returns a stream of JSON objects
237238
url = f"{self._base_url}/v1/projects/{self.project_id}/locations/{self.location}/publishers/{self.publisher}/models/{model_name}:streamRawPredict"
238239
async def custom_stream_response(url, headers, anthropic_payload, model_name):
240+
"""Call Vertex streamRawPredict and convert the *single* SSE frame into OpenAI chunk format."""
241+
239242
async def stream_response() -> AsyncGenerator[bytes, None]:
240-
resp = await AnthropicAdapter.process_regular_response(url, headers, anthropic_payload, model_name, error_handler)
241-
resp['object'] = 'chat.completion.chunk'
242-
for choice in resp['choices']:
243-
choice['delta'] = choice['message']
244-
del choice['message']
245-
yield f"data: {json.dumps(resp)}\n\n".encode()
246-
yield b"data: [DONE]\n\n"
243+
async with aiohttp.ClientSession() as session:
244+
async with session.post(url, headers=headers, json=anthropic_payload) as response:
245+
if response.status != 200:
246+
error_text = await response.text()
247+
error_handler(error_text, response.status)
248+
249+
# Read the entire event stream; Vertex currently responds with a few `data:` lines
250+
body = await response.text()
251+
252+
# Extract JSON payload(s) from SSE lines that start with "data: "
253+
payloads: list[dict[str, Any]] = []
254+
for line in body.splitlines():
255+
line = line.strip()
256+
if not line.startswith("data:"):
257+
continue
258+
data_part = line[len("data:"):].strip()
259+
if data_part == "[DONE]":
260+
continue
261+
try:
262+
payloads.append(json.loads(data_part))
263+
except json.JSONDecodeError:
264+
continue
265+
266+
if not payloads:
267+
raise ProviderAPIException("Vertex", response.status, "Empty response from Vertex streamRawPredict")
268+
269+
# Vertex typically returns a single JSON object – use the first
270+
vertex_resp = payloads[0]
271+
272+
# Convert to OpenAI chunk structure expected by Forge callers
273+
openai_chunk = {
274+
"id": vertex_resp.get("responseId", f"chatcmpl-{uuid.uuid4().hex}"),
275+
"object": "chat.completion.chunk",
276+
"created": int(time.time()),
277+
"model": model_name,
278+
"choices": [
279+
{
280+
"index": 0,
281+
"delta": {
282+
"role": "assistant",
283+
"content": vertex_resp.get("candidates", [{}])[0].get("content", "")
284+
},
285+
"finish_reason": None,
286+
}
287+
],
288+
}
289+
290+
# Yield the chunk then DONE, mimicking OpenAI stream format
291+
yield f"data: {json.dumps(openai_chunk)}\n\n".encode()
292+
yield b"data: [DONE]\n\n"
293+
247294
return stream_response()
295+
248296
return await custom_stream_response(url, headers, anthropic_payload, model_name)
249297
else:
250298
url = f"{self._base_url}/v1/projects/{self.project_id}/locations/{self.location}/publishers/{self.publisher}/models/{model_name}:rawPredict"

0 commit comments

Comments
 (0)