diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index 268a90c4b9..179b963b57 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -57,6 +57,8 @@ _create_trtllm_generate_request, _create_vllm_embedding_request, _create_vllm_generate_request, + _get_openai_chat_format_logprobs_from_vllm_response, + _get_openai_completion_format_logprobs_from_vllm_response, _get_output, _get_usage_from_response, _get_vllm_lora_names, @@ -66,6 +68,7 @@ from schemas.openai import ( ChatCompletionChoice, ChatCompletionFinishReason, + ChatCompletionLogprobs, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, ChatCompletionNamedToolChoice, @@ -255,13 +258,22 @@ async def chat( response, metadata.backend, RequestKind.GENERATION ) + # Parse logprobs if requested + logprobs_data = None + if request.logprobs: + openai_logprobs = _get_openai_chat_format_logprobs_from_vllm_response( + response + ) + if openai_logprobs: + logprobs_data = ChatCompletionLogprobs(content=openai_logprobs) + return CreateChatCompletionResponse( id=request_id, choices=[ ChatCompletionChoice( index=0, message=response_message, - logprobs=None, + logprobs=logprobs_data, finish_reason=finish_reason, ) ], @@ -360,10 +372,17 @@ async def completion( response, metadata.backend, RequestKind.GENERATION ) + # Parse logprobs if requested + logprobs_data = None + if request.logprobs is not None and request.logprobs > 0: + logprobs_data = _get_openai_completion_format_logprobs_from_vllm_response( + response + ) + choice = Choice( finish_reason=FinishReason.stop, index=0, - logprobs=None, + logprobs=logprobs_data, text=text, ) return CreateCompletionResponse( @@ -605,6 +624,15 @@ async def _streaming_chat_iterator( ) previous_text = current_text + # Parse logprobs for this chunk if requested + chunk_logprobs = None + if request.logprobs: + openai_logprobs = _get_openai_chat_format_logprobs_from_vllm_response( + response + ) + if openai_logprobs: + chunk_logprobs = ChatCompletionLogprobs(content=openai_logprobs) + # if the response delta is None (e.g. 
because it was a # "control token" for tool calls or the parser otherwise # wasn't ready to send a token, then @@ -618,7 +646,7 @@ async def _streaming_chat_iterator( choice = ChatCompletionStreamingResponseChoice( index=0, delta=response_delta, - logprobs=None, + logprobs=chunk_logprobs, finish_reason=finish_reason, ) @@ -791,8 +819,19 @@ def _validate_chat_request( f"Received n={request.n}, but only single choice (n=1) is currently supported" ) - if request.logit_bias is not None or request.logprobs: - raise ClientError("logit bias and log probs not currently supported") + if request.logit_bias is not None: + raise ClientError("logit bias is not currently supported") + + # Logprobs are only supported for vLLM backend currently + if metadata.backend != "vllm" and ( + request.logprobs is not None or request.top_logprobs is not None + ): + raise ClientError( + "logprobs are currently available only for the vLLM backend" + ) + + if request.top_logprobs is not None and not request.logprobs: + raise ClientError("`top_logprobs` can only be used when `logprobs` is True") self._verify_chat_tool_call_settings(request=request) @@ -847,16 +886,32 @@ async def _streaming_completion_iterator( model = request.model include_usage = request.stream_options and request.stream_options.include_usage usage_accumulator = _StreamingUsageAccumulator(backend) + current_offset = 0 async for response in responses: if include_usage: usage_accumulator.update(response) text = _get_output(response) + + # Parse logprobs for this chunk if requested + chunk_logprobs = None + if request.logprobs is not None and request.logprobs > 0: + chunk_logprobs = ( + _get_openai_completion_format_logprobs_from_vllm_response(response) + ) + # Adjust text offsets based on accumulated output + if chunk_logprobs and chunk_logprobs.text_offset: + chunk_logprobs.text_offset = [ + offset + current_offset for offset in chunk_logprobs.text_offset + ] + + current_offset += len(text) + choice = Choice( finish_reason=FinishReason.stop if response.final else None, index=0, - logprobs=None, + logprobs=chunk_logprobs, text=text, ) chunk = CreateCompletionResponse( @@ -942,8 +997,18 @@ def _validate_completion_request( f"Received best_of={request.best_of}, but only single choice (best_of=1) is currently supported" ) - if request.logit_bias is not None or request.logprobs is not None: - raise ClientError("logit bias and log probs not supported") + if request.logit_bias is not None: + raise ClientError("logit bias is not supported") + + # Logprobs are only supported for vLLM backend currently + if ( + request.logprobs is not None + and request.logprobs > 0 + and metadata.backend != "vllm" + ): + raise ClientError( + "logprobs are currently available only for the vLLM backend" + ) if request.stream_options and not request.stream: raise ClientError("`stream_options` can only be used when `stream` is True") diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py index 545d9a72db..e2c4cf92c9 100644 --- a/python/openai/openai_frontend/engine/utils/triton.py +++ b/python/openai/openai_frontend/engine/utils/triton.py @@ -27,22 +27,26 @@ import json import os import re +import sys from dataclasses import asdict, dataclass, field from enum import Enum from pathlib import Path -from typing import Iterable, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Union import numpy as np import tritonserver from pydantic import BaseModel from schemas.openai import ( 
ChatCompletionNamedToolChoice, + ChatCompletionTokenLogprob, ChatCompletionToolChoiceOption1, CompletionUsage, CreateChatCompletionRequest, CreateCompletionRequest, CreateEmbeddingRequest, EmbeddingUsage, + Logprobs, + TopLogprob, ) from utils.utils import ClientError, ServerError @@ -83,6 +87,8 @@ def _create_vllm_generate_request( "max_completion_tokens", # will be handled explicitly "max_tokens", + "logprobs", + "top_logprobs", } # NOTE: The exclude_none is important, as internals may not support @@ -92,6 +98,7 @@ def _create_vllm_generate_request( exclude_none=True, ) + request_logprobs = False # Indicates CreateChatCompletionRequest if hasattr(request, "max_completion_tokens"): if request.max_completion_tokens is not None: @@ -102,11 +109,31 @@ def _create_vllm_generate_request( # If neither is set, use a default value for max_tokens else: sampling_parameters["max_tokens"] = default_max_tokens + + # Handle logprobs for chat completions + # OpenAI API: logprobs (bool), top_logprobs (int 0-20) + # vLLM API: logprobs (int) - number of top token logprobs to return + if request.logprobs and request.top_logprobs is not None: + sampling_parameters["logprobs"] = request.top_logprobs + request_logprobs = True + elif request.logprobs: + # If logprobs=True but top_logprobs not specified, default to 1 + sampling_parameters["logprobs"] = 1 + request_logprobs = True # Indicates CreateCompletionRequest - elif request.max_tokens is not None: - sampling_parameters["max_tokens"] = request.max_tokens else: - sampling_parameters["max_tokens"] = default_max_tokens + if request.max_tokens is not None: + sampling_parameters["max_tokens"] = request.max_tokens + else: + sampling_parameters["max_tokens"] = default_max_tokens + + # Handle logprobs for completions + # OpenAI API: logprobs (int 0-5) - number of top token log probs + # vLLM API: logprobs (int) - same behavior, pass directly + if request.logprobs is not None and request.logprobs > 0: + sampling_parameters["logprobs"] = request.logprobs + request_logprobs = True + inputs["return_logprobs"] = np.bool_([request_logprobs]) if lora_name is not None: sampling_parameters["lora_name"] = lora_name @@ -376,6 +403,161 @@ def _get_output(response: tritonserver._api._response.InferenceResponse) -> str: return "" +def _get_logprobs_from_response( + response: tritonserver._api._response.InferenceResponse, +) -> Optional[List[Dict]]: + """ + Extracts logprobs from a Triton inference response (vLLM backend). + + Returns: + List of dictionaries containing logprobs data, or None if not available. + Format: [ + { + token_id: { + "logprob": float, + "rank": int, + "decoded_token": str + } + }, + ... + ] + """ + if "logprobs" not in response.outputs: + return None + + logprobs_tensor = response.outputs["logprobs"] + if logprobs_tensor is None: + return None + + # The logprobs are stored as JSON string (vLLM backend) + logprobs_str = _to_string(logprobs_tensor) + + if logprobs_str == "null": + return None + + try: + logprobs_data = json.loads(logprobs_str) + return logprobs_data + except json.JSONDecodeError: + return None + + +def _get_openai_chat_format_logprobs_from_vllm_response( + response: tritonserver._api._response.InferenceResponse, +) -> Optional[List[ChatCompletionTokenLogprob]]: + """ + Convert logprobs from a Triton inference response (vLLM backend) to OpenAI chat completion format. + + Args: + response: Triton inference response containing logprobs output. + + Returns: + List of ChatCompletionTokenLogprob objects, or None if no logprobs available. 
+ """ + vllm_logprobs = _get_logprobs_from_response(response) + + if not vllm_logprobs: + return None + + openai_logprobs = [] + for token_logprobs_dict in vllm_logprobs: + if not token_logprobs_dict: + continue + + # Sort by rank to identify the selected token (rank=1 is always the chosen token) + sorted_tokens = sorted( + token_logprobs_dict.items(), key=lambda x: x[1].get("rank", sys.maxsize) + ) + + # The first token (lowest rank) is the selected token + selected_token_id, selected_token_data = sorted_tokens[0] + selected_token = selected_token_data["decoded_token"] + selected_logprob = selected_token_data["logprob"] + + # Convert to bytes representation + token_bytes = list(selected_token.encode("utf-8")) + + top_logprobs_list = [] + for token_id, token_data in sorted_tokens: + decoded_token = token_data["decoded_token"] + top_logprobs_list.append( + TopLogprob( + token=decoded_token, + logprob=token_data["logprob"], + bytes=list(decoded_token.encode("utf-8")), + ) + ) + + openai_logprobs.append( + ChatCompletionTokenLogprob( + token=selected_token, + logprob=selected_logprob, + bytes=token_bytes, + top_logprobs=top_logprobs_list, + ) + ) + + return openai_logprobs + + +def _get_openai_completion_format_logprobs_from_vllm_response( + response: tritonserver._api._response.InferenceResponse, +) -> Optional[Logprobs]: + """ + Convert logprobs from a Triton inference response (vLLM backend) to OpenAI completion format. + + Args: + response: Triton inference response containing logprobs output. + + Returns: + Logprobs object for completions API, or None if no logprobs available. + """ + vllm_logprobs = _get_logprobs_from_response(response) + + if not vllm_logprobs: + return None + + text_offset = [] + token_logprobs = [] + tokens = [] + top_logprobs = [] + + current_offset = 0 + for token_logprobs_dict in vllm_logprobs: + if not token_logprobs_dict: + continue + + # Sort by rank to identify the selected token (rank=1 is always the chosen token) + sorted_tokens = sorted( + token_logprobs_dict.items(), key=lambda x: x[1].get("rank", sys.maxsize) + ) + + # The first token (lowest rank) is the selected token + selected_token_id, selected_token_data = sorted_tokens[0] + selected_token = selected_token_data["decoded_token"] + selected_logprob = selected_token_data["logprob"] + + text_offset.append(current_offset) + token_logprobs.append(selected_logprob) + tokens.append(selected_token) + + # Build top_logprobs dict for this position + top_logprobs_dict = {} + for token_id, token_data in sorted_tokens: + decoded_token = token_data["decoded_token"] + top_logprobs_dict[decoded_token] = token_data["logprob"] + top_logprobs.append(top_logprobs_dict) + + current_offset += len(selected_token) + + return Logprobs( + text_offset=text_offset, + token_logprobs=token_logprobs, + tokens=tokens, + top_logprobs=top_logprobs, + ) + + def _validate_triton_responses_non_streaming( responses: List[tritonserver._api._response.InferenceResponse], ): diff --git a/python/openai/openai_frontend/schemas/openai.py b/python/openai/openai_frontend/schemas/openai.py index 81ff6e93b3..57fb7b1017 100644 --- a/python/openai/openai_frontend/schemas/openai.py +++ b/python/openai/openai_frontend/schemas/openai.py @@ -530,7 +530,7 @@ class ChatCompletionTokenLogprob(BaseModel): ) -class Logprobs2(BaseModel): +class ChatCompletionLogprobs(BaseModel): content: List[ChatCompletionTokenLogprob] = Field( ..., description="A list of message content tokens with log probability information.", @@ -539,7 +539,7 @@ class 
Logprobs2(BaseModel): class ChatCompletionStreamingResponseChoice(BaseModel): delta: ChatCompletionStreamResponseDelta - logprobs: Optional[Logprobs2] = Field( + logprobs: Optional[ChatCompletionLogprobs] = Field( None, description="Log probability information for the choice." ) finish_reason: ChatCompletionFinishReason | None = Field( @@ -730,7 +730,7 @@ class ChatCompletionChoice(BaseModel): ..., description="The index of the choice in the list of choices." ) message: ChatCompletionResponseMessage - logprobs: Logprobs2 | None = Field( + logprobs: ChatCompletionLogprobs | None = Field( ..., description="Log probability information for the choice." ) diff --git a/python/openai/tests/test_chat_completions.py b/python/openai/tests/test_chat_completions.py index 4034def2cd..e708689e0e 100644 --- a/python/openai/tests/test_chat_completions.py +++ b/python/openai/tests/test_chat_completions.py @@ -154,13 +154,10 @@ def test_chat_completions_sampling_parameters( ) # FIXME: Add support and remove this check - unsupported_parameters = ["logprobs", "logit_bias"] + unsupported_parameters = ["logit_bias"] if param_key in unsupported_parameters: assert response.status_code == 400 - assert ( - response.json()["detail"] - == "logit bias and log probs not currently supported" - ) + assert response.json()["detail"] == "logit bias is not currently supported" return assert response.status_code == 200 @@ -522,10 +519,6 @@ def test_multi_lora(self): def test_request_n_choices(self): pass - @pytest.mark.skip(reason="Not Implemented Yet") - def test_request_logprobs(self): - pass - @pytest.mark.skip(reason="Not Implemented Yet") def test_request_logit_bias(self): pass @@ -548,6 +541,126 @@ def test_usage_response(self, client, model: str, messages: List[dict]): usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] ) + def test_chat_completions_logprobs( + self, client, backend: str, model: str, messages: List[dict] + ): + """Test logprobs parameter for chat completions.""" + response = client.post( + "/v1/chat/completions", + json={ + "model": model, + "messages": messages, + "logprobs": True, + "top_logprobs": 2, + "max_tokens": 10, + }, + ) + + # Non-vLLM backends should raise an error + if backend != "vllm": + assert response.status_code == 400 + assert ( + "logprobs are currently available only for the vLLM backend" + in response.json()["detail"] + ) + return + + assert response.status_code == 200 + response_json = response.json() + + # Check that logprobs are present in the response + choice = response_json["choices"][0] + assert "logprobs" in choice + logprobs = choice["logprobs"] + + assert logprobs is not None + assert "content" in logprobs + content = logprobs["content"] + assert isinstance(content, list) + assert len(content) > 0 + + # Validate structure of each token logprob + for token_logprob in content: + assert "token" in token_logprob + assert "logprob" in token_logprob + assert "bytes" in token_logprob + assert "top_logprobs" in token_logprob + + assert isinstance(token_logprob["token"], str) + assert isinstance(token_logprob["logprob"], (int, float)) + assert isinstance(token_logprob["bytes"], list) + assert isinstance(token_logprob["top_logprobs"], list) + + # Validate top_logprobs structure + for top_logprob in token_logprob["top_logprobs"]: + assert "token" in top_logprob + assert "logprob" in top_logprob + assert "bytes" in top_logprob + + def test_chat_completions_logprobs_false( + self, client, model: str, messages: List[dict] + ): + """Test that logprobs=False 
returns no logprobs.""" + response = client.post( + "/v1/chat/completions", + json={ + "model": model, + "messages": messages, + "logprobs": False, + "max_tokens": 10, + }, + ) + + assert response.status_code == 200 + response_json = response.json() + + # logprobs should be None when logprobs=False + choice = response_json["choices"][0] + assert choice.get("logprobs") is None + + @pytest.mark.parametrize("top_logprobs_value", [0, 5]) + def test_chat_completions_top_logprobs_without_logprobs( + self, client, model: str, messages: List[dict], top_logprobs_value: int + ): + """Test that top_logprobs without logprobs raises validation error.""" + response = client.post( + "/v1/chat/completions", + json={ + "model": model, + "messages": messages, + "top_logprobs": top_logprobs_value, + "max_tokens": 10, + }, + ) + + # Should raise validation error for any value when logprobs is not True + assert response.status_code == 400 + assert ( + "`top_logprobs` can only be used when `logprobs` is True" + in response.json()["detail"] + ) + + def test_chat_completions_top_logprobs_validation( + self, client, model: str, messages: List[dict] + ): + """Test that top_logprobs > 20 is rejected by schema validation.""" + response = client.post( + "/v1/chat/completions", + json={ + "model": model, + "messages": messages, + "logprobs": True, + "top_logprobs": 25, # Exceeds maximum of 20 + "max_tokens": 5, + }, + ) + + # Should raise schema validation error + assert response.status_code == 422 + assert "Input should be less than or equal to 20" in str( + response.json()["detail"] + ) + # For tests that won't use the same pytest fixture for server startup across # the whole class test suite. diff --git a/python/openai/tests/test_completions.py b/python/openai/tests/test_completions.py index 3c88a68acd..ea1cf62166 100644 --- a/python/openai/tests/test_completions.py +++ b/python/openai/tests/test_completions.py @@ -81,10 +81,10 @@ def test_completions_sampling_parameters( print("Response:", response.json()) # FIXME: Add support and remove this check - unsupported_parameters = ["logprobs", "logit_bias"] + unsupported_parameters = ["logit_bias"] if sampling_parameter in unsupported_parameters: assert response.status_code == 400 - assert response.json()["detail"] == "logit bias and log probs not supported" + assert response.json()["detail"] == "logit bias is not supported" return assert response.status_code == 200 @@ -397,3 +397,98 @@ def test_usage_response(self, client, model: str, prompt: str): assert ( usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] ) + + def test_completions_logprobs(self, client, backend: str, model: str, prompt: str): + """Test logprobs parameter for completions.""" + response = client.post( + "/v1/completions", + json={ + "model": model, + "prompt": prompt, + "logprobs": 3, + "max_tokens": 10, + }, + ) + + # Non-vLLM backends should raise an error + if backend != "vllm": + assert response.status_code == 400 + assert ( + "logprobs are currently available only for the vLLM backend" + in response.json()["detail"] + ) + return + + assert response.status_code == 200 + response_json = response.json() + + # Check that logprobs are present in the response + choice = response_json["choices"][0] + assert "logprobs" in choice + logprobs = choice["logprobs"] + + assert logprobs is not None + assert "text_offset" in logprobs + assert "token_logprobs" in logprobs + assert "tokens" in logprobs + assert "top_logprobs" in logprobs + + assert isinstance(logprobs["text_offset"], list) 
+ assert isinstance(logprobs["token_logprobs"], list) + assert isinstance(logprobs["tokens"], list) + assert isinstance(logprobs["top_logprobs"], list) + + # All lists should have the same length + num_tokens = len(logprobs["tokens"]) + assert len(logprobs["text_offset"]) == num_tokens + assert len(logprobs["token_logprobs"]) == num_tokens + assert len(logprobs["top_logprobs"]) == num_tokens + + # Validate each token + for i in range(num_tokens): + assert isinstance(logprobs["tokens"][i], str) + assert isinstance(logprobs["token_logprobs"][i], (int, float)) + assert isinstance(logprobs["text_offset"][i], int) + assert isinstance(logprobs["top_logprobs"][i], dict) + + # Validate top_logprobs dict contains token -> logprob mappings + for token, logprob in logprobs["top_logprobs"][i].items(): + assert isinstance(token, str) + assert isinstance(logprob, (int, float)) + + def test_completions_logprobs_zero(self, client, model: str, prompt: str): + """Test that logprobs=0 returns no logprobs.""" + response = client.post( + "/v1/completions", + json={ + "model": model, + "prompt": prompt, + "logprobs": 0, + "max_tokens": 10, + }, + ) + + assert response.status_code == 200 + response_json = response.json() + + # logprobs should be None when logprobs=0 + choice = response_json["choices"][0] + assert choice.get("logprobs") is None + + def test_completions_logprobs_validation(self, client, model: str, prompt: str): + """Test that logprobs > 5 is rejected by schema validation.""" + response = client.post( + "/v1/completions", + json={ + "model": model, + "prompt": prompt, + "logprobs": 7, # Exceeds maximum of 5 + "max_tokens": 5, + }, + ) + + # Should raise schema validation error + assert response.status_code == 422 + assert "Input should be less than or equal to 5" in str( + response.json()["detail"] + ) diff --git a/python/openai/tests/test_models/mock_llm/config.pbtxt b/python/openai/tests/test_models/mock_llm/config.pbtxt index 5f665ff543..cac10b5de0 100644 --- a/python/openai/tests/test_models/mock_llm/config.pbtxt +++ b/python/openai/tests/test_models/mock_llm/config.pbtxt @@ -1,4 +1,4 @@ -# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -41,6 +41,12 @@ input [ name: "stream" data_type: TYPE_BOOL dims: [ 1, 1 ] + }, + { + name: "return_logprobs" + data_type: TYPE_BOOL + dims: [ 1, 1 ] + optional: true } ] diff --git a/python/openai/tests/test_openai_client.py b/python/openai/tests/test_openai_client.py index 0c02f69db3..e5ce4ae55e 100644 --- a/python/openai/tests/test_openai_client.py +++ b/python/openai/tests/test_openai_client.py @@ -26,6 +26,7 @@ from typing import List +import numpy as np import openai import pytest @@ -482,3 +483,234 @@ async def test_stream_options_without_streaming( stream_options={"include_usage": True}, ) assert "`stream_options` can only be used when `stream` is True" in str(e.value) + + @pytest.mark.asyncio + async def test_chat_completion_logprobs( + self, client: openai.AsyncOpenAI, backend: str, model: str, messages: List[dict] + ): + """Test logprobs for chat completions and compare streaming vs non-streaming.""" + # Non-vLLM backends should raise an error + if backend != "vllm": + with pytest.raises(openai.BadRequestError) as exc_info: + await client.chat.completions.create( + model=model, + messages=messages, + logprobs=True, + top_logprobs=2, + max_tokens=10, + ) + assert "logprobs are currently available only for the vLLM backend" in str( + exc_info.value + ) + return + + # Test non-streaming + seed = 0 + temperature = 0.0 + chat_completion = await client.chat.completions.create( + model=model, + messages=messages, + logprobs=True, + top_logprobs=2, + max_tokens=10, + temperature=temperature, + seed=seed, + stream=False, + ) + + assert chat_completion.choices[0].message.content + assert chat_completion.choices[0].logprobs is not None + + logprobs = chat_completion.choices[0].logprobs + assert logprobs.content is not None + assert len(logprobs.content) > 0 + + # Validate each token logprob + for token_logprob in logprobs.content: + assert token_logprob.token + assert isinstance(token_logprob.logprob, float) + assert isinstance(token_logprob.bytes, list) + assert token_logprob.top_logprobs is not None + assert len(token_logprob.top_logprobs) > 0 + + # Test streaming and compare with non-streaming + stream = await client.chat.completions.create( + model=model, + messages=messages, + logprobs=True, + top_logprobs=2, + max_tokens=10, + temperature=temperature, + seed=seed, + stream=True, + ) + + chunks = [] + stream_logprobs = [] + async for chunk in stream: + if chunk.choices[0].delta.content: + chunks.append(chunk.choices[0].delta.content) + if chunk.choices[0].logprobs and chunk.choices[0].logprobs.content: + stream_logprobs.extend(chunk.choices[0].logprobs.content) + + # Assert streaming output matches non-streaming + streamed_output = "".join(chunks) + assert streamed_output == chat_completion.choices[0].message.content + + # Assert both streaming and non-streaming produce logprobs + assert len(stream_logprobs) > 0, "Streaming should produce logprobs" + assert len(stream_logprobs) == len(logprobs.content), "Same number of tokens" + + # Compare tokens and logprob values (using np.allclose for float comparison) + stream_tokens_list = [t.token for t in stream_logprobs] + non_stream_tokens_list = [t.token for t in logprobs.content] + stream_logprobs_values = [t.logprob for t in stream_logprobs] + non_stream_logprobs_values = [t.logprob for t in logprobs.content] + + assert stream_tokens_list == non_stream_tokens_list, "Tokens should match" + assert 
np.allclose( + stream_logprobs_values, non_stream_logprobs_values, rtol=0, atol=1e-2 + ), "Logprob values should be close" + + @pytest.mark.asyncio + async def test_completion_logprobs( + self, client: openai.AsyncOpenAI, backend: str, model: str, prompt: str + ): + """Test logprobs for completions.""" + # Non-vLLM backends should raise an error + if backend != "vllm": + with pytest.raises(openai.BadRequestError) as exc_info: + await client.completions.create( + model=model, + prompt=prompt, + logprobs=3, + max_tokens=10, + ) + assert "logprobs are currently available only for the vLLM backend" in str( + exc_info.value + ) + return + + # Test non-streaming + seed = 0 + temperature = 0.0 + completion = await client.completions.create( + model=model, + prompt=prompt, + logprobs=3, + max_tokens=10, + temperature=temperature, + seed=seed, + stream=False, + ) + + assert completion.choices[0].text + assert completion.choices[0].logprobs is not None + + logprobs = completion.choices[0].logprobs + assert logprobs.tokens is not None + assert logprobs.token_logprobs is not None + assert logprobs.text_offset is not None + assert logprobs.top_logprobs is not None + + num_tokens = len(logprobs.tokens) + assert len(logprobs.token_logprobs) == num_tokens + assert len(logprobs.text_offset) == num_tokens + assert len(logprobs.top_logprobs) == num_tokens + + # Test streaming and compare with non-streaming + stream = await client.completions.create( + model=model, + prompt=prompt, + logprobs=3, + max_tokens=10, + temperature=temperature, + seed=seed, + stream=True, + ) + + chunks = [] + stream_tokens = [] + stream_token_logprobs = [] + stream_text_offsets = [] + stream_top_logprobs = [] + + async for chunk in stream: + if chunk.choices[0].text: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].logprobs: + lp = chunk.choices[0].logprobs + if lp.tokens: + stream_tokens.extend(lp.tokens) + if lp.token_logprobs: + stream_token_logprobs.extend(lp.token_logprobs) + if lp.text_offset: + stream_text_offsets.extend(lp.text_offset) + if lp.top_logprobs: + stream_top_logprobs.extend(lp.top_logprobs) + + # Assert streaming output matches non-streaming + streamed_output = "".join(chunks) + assert streamed_output == completion.choices[0].text + + # Compare values (using np.allclose for float comparison) + assert stream_tokens == logprobs.tokens, "Tokens should match" + assert stream_text_offsets == logprobs.text_offset, "Text offsets should match" + assert stream_top_logprobs == logprobs.top_logprobs, "Top logprobs should match" + assert np.allclose( + stream_token_logprobs, logprobs.token_logprobs, rtol=0, atol=1e-2 + ), "Token logprob values should be close" + + @pytest.mark.parametrize("top_logprobs_value", [0, 5]) + @pytest.mark.asyncio + async def test_top_logprobs_requires_logprobs( + self, + client: openai.AsyncOpenAI, + model: str, + messages: List[dict], + top_logprobs_value: int, + ): + """ + Test that top_logprobs without logprobs raises an error + """ + with pytest.raises(openai.BadRequestError) as exc_info: + await client.chat.completions.create( + model=model, + messages=messages, + top_logprobs=top_logprobs_value, # Without logprobs=True + max_tokens=5, + ) + assert "`top_logprobs` can only be used when `logprobs` is True" in str( + exc_info.value + ) + + @pytest.mark.asyncio + async def test_chat_top_logprobs_exceeds_max( + self, client: openai.AsyncOpenAI, model: str, messages: List[dict] + ): + """Test that top_logprobs > 20 raises schema validation error.""" + with 
pytest.raises(openai.UnprocessableEntityError) as exc_info: + await client.chat.completions.create( + model=model, + messages=messages, + logprobs=True, + top_logprobs=25, # Exceeds maximum of 20 + max_tokens=5, + ) + # Pydantic validation error + assert "less than or equal to 20" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_completion_logprobs_exceeds_max( + self, client: openai.AsyncOpenAI, model: str, prompt: str + ): + """Test that logprobs > 5 raises schema validation error.""" + with pytest.raises(openai.UnprocessableEntityError) as exc_info: + await client.completions.create( + model=model, + prompt=prompt, + logprobs=7, # Exceeds maximum of 5 + max_tokens=5, + ) + # Pydantic validation error + assert "less than or equal to 5" in str(exc_info.value).lower()
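
The helpers added above convert the vLLM backend's "logprobs" output tensor (a JSON list of per-position dicts keyed by token id, each entry carrying "logprob", "rank", and "decoded_token", with rank 1 marking the sampled token) into the OpenAI response shapes. The following standalone sketch is illustrative only and is not part of the patch: it fakes a single-position payload with made-up token ids, strings, and logprob values, and shows the chat-format mapping the new code performs; the completions path flattens the same data analogously into parallel tokens / token_logprobs / text_offset / top_logprobs lists.

# Illustrative sketch, not the frontend implementation: the payload shape follows the
# docstring of _get_logprobs_from_response; all concrete values are invented.
import json

# One generated token position with two candidates; rank 1 is the sampled token.
vllm_logprobs_json = json.dumps(
    [
        {
            "11": {"logprob": -0.12, "rank": 1, "decoded_token": "Hello"},
            "42": {"logprob": -2.31, "rank": 2, "decoded_token": "Hi"},
        }
    ]
)

positions = json.loads(vllm_logprobs_json)
chat_logprobs_content = []
for position in positions:
    # Sort candidates by rank; the lowest rank is the token actually emitted.
    ranked = sorted(position.items(), key=lambda kv: kv[1]["rank"])
    _, chosen = ranked[0]
    chat_logprobs_content.append(
        {
            "token": chosen["decoded_token"],
            "logprob": chosen["logprob"],
            "bytes": list(chosen["decoded_token"].encode("utf-8")),
            "top_logprobs": [
                {
                    "token": data["decoded_token"],
                    "logprob": data["logprob"],
                    "bytes": list(data["decoded_token"].encode("utf-8")),
                }
                for _, data in ranked
            ],
        }
    )

print(chat_logprobs_content[0]["token"])  # -> "Hello"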