diff --git a/pylock.toml b/pylock.toml index 11fecb48..c9329a22 100644 --- a/pylock.toml +++ b/pylock.toml @@ -4227,4 +4227,4 @@ strategy = ["inherit_metadata", "static_urls"] requires_python = "~=3.12" [[tool.pdm.targets]] -requires_python = ">=3.10.0,<3.12" +requires_python = ">=3.10.0,<3.12" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f1624d3f..8fe6d950 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,13 @@ include = ["*"] [tool.pdm] distribution = true +[[tool.pdm.source]] +name = "torch" +type = "find_links" +#url = "https://download.pytorch.org/whl/cpu/torch_stable.html" +url = "https://download.pytorch.org/whl/cpu/torch/" +include_packages = ["torch"] + # ************************************************ # ********** Project Metadata ********** @@ -54,29 +61,25 @@ dependencies = [ "httpx[http2]<1.0.0", "loguru", "msgpack", - "numpy", + "numpy<2.0.0", "pillow", "protobuf", "pydantic>=2.11.7", "pydantic-settings>=2.0.0", + "pydub", "pyyaml>=6.0.0", "rich", "sanic", "transformers", "uvloop>=0.18", + "librosa>=0.11.0", + "torch", ] [project.optional-dependencies] -perf = [ - "orjson", - "msgpack", - "msgspec", - "uvloop", -] -recommended = [ - "tiktoken>=0.11.0", # For OpenAI tokenizer - "blobfile>=3.1.0", # For OpenAI tokenizer -] +perf = ["orjson", "msgpack", "msgspec", "uvloop"] +openai = ["tiktoken>=0.11.0", "blobfile>=3.1.0"] +recommended = ["guidellm[perf,openai]"] dev = [ # build "build>=1.0.0", @@ -118,7 +121,7 @@ dev = [ ] [dependency-groups] -dev = [ "guidellm[dev]" ] +dev = ["guidellm[dev]"] [project.urls] homepage = "https://github.com/vllm-project/guidellm" diff --git a/src/guidellm/__init__.py b/src/guidellm/__init__.py index f2206e94..f466073e 100644 --- a/src/guidellm/__init__.py +++ b/src/guidellm/__init__.py @@ -7,6 +7,8 @@ import logging import os +from datasets import config + with ( open(os.devnull, "w") as devnull, # noqa: PTH123 contextlib.redirect_stderr(devnull), @@ -19,6 +21,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Silence warnings for tokenizers hf_logging.set_verbosity_error() logging.getLogger("transformers").setLevel(logging.ERROR) + config.USE_AUDIO_DECODE = False from .logger import configure_logger, logger from .settings import ( diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index dbc8e1da..680ac852 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -28,39 +28,26 @@ import asyncio import codecs from pathlib import Path -from typing import Annotated import click -from pydantic import ValidationError try: import uvloop - - HAS_UVLOOP: Annotated[ - bool, "Flag indicating if uvloop is available for event loop optimization" - ] = True except ImportError: uvloop = None - HAS_UVLOOP: Annotated[ - bool, "Flag indicating if uvloop is available for event loop optimization" - ] = False - from guidellm.backends import BackendType from guidellm.benchmark import ( GenerativeConsoleBenchmarkerProgress, - InjectExtrasAggregator, ProfileType, benchmark_generative_text, reimport_benchmarks_report, ) -from guidellm.benchmark.scenario import ( - GenerativeTextScenario, - get_builtin_scenarios, -) +from guidellm.benchmark.scenario import GenerativeTextScenario from guidellm.mock_server import MockServer, MockServerConfig from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset from guidellm.scheduler import StrategyType +from guidellm.schemas import GenerativeRequestType from guidellm.settings import print_config from guidellm.utils import Console, 
DefaultGroupHandler, get_literal_vals from guidellm.utils import cli as cli_tools @@ -136,25 +123,25 @@ def benchmark(): help="Run a benchmark against a generative model using the specified arguments.", context_settings={"auto_envvar_prefix": "GUIDELLM"}, ) -@click.option( - "--scenario", - type=cli_tools.Union( - click.Path( - exists=True, - readable=True, - file_okay=True, - dir_okay=False, - path_type=Path, - ), - click.Choice(get_builtin_scenarios()), - ), - default=None, - help=( - "The name of a builtin scenario or path to a config file. " - "Missing values from the config will use defaults. " - "Options specified on the commandline will override the scenario." - ), -) +# @click.option( +# "--scenario", +# type=cli_tools.Union( +# click.Path( +# exists=True, +# readable=True, +# file_okay=True, +# dir_okay=False, +# path_type=Path, +# ), +# click.Choice(get_builtin_scenarios()), +# ), +# default=None, +# help=( +# "The name of a builtin scenario or path to a config file. " +# "Missing values from the config will use defaults. " +# "Options specified on the commandline will override the scenario." +# ), +# ) @click.option( "--target", type=str, @@ -163,6 +150,7 @@ def benchmark(): @click.option( "--data", type=str, + multiple=True, help=( "The HuggingFace dataset ID, a path to a HuggingFace dataset, " "a path to a data file csv, json, jsonl, or txt, " @@ -191,12 +179,6 @@ def benchmark(): "For rate-type=synchronous,throughput, this must not be set." ), ) -@click.option( - "--random-seed", - default=GenerativeTextScenario.get_default("random_seed"), - type=int, - help="The random seed to use for benchmarking to ensure reproducibility.", -) # Backend configuration @click.option( "--backend", @@ -217,9 +199,7 @@ def benchmark(): default=GenerativeTextScenario.get_default("backend_kwargs"), help=( "A JSON string containing any arguments to pass to the backend as a " - "dict with **kwargs. Headers can be removed by setting their value to " - "null. For example: " - """'{"headers": {"Authorization": null, "Custom-Header": "Custom-Value"}}'""" + "dict with **kwargs." ), ) @click.option( @@ -232,6 +212,24 @@ def benchmark(): ), ) # Data configuration +@click.option( + "--request-type", + default="chat_completions", + type=click.Choice(list(get_literal_vals(GenerativeRequestType))), + help=( + "The type of request to create for each data sample and send to the backend. " + f"Supported types: {list(get_literal_vals(GenerativeRequestType))}." + ), +) +@click.option( + "--request-formatter-kwargs", + default=None, + callback=cli_tools.parse_json, + help=( + "A JSON string containing any arguments to pass to the request formatter " + "as a dict with **kwargs." + ), +) @click.option( "--processor", default=GenerativeTextScenario.get_default("processor"), @@ -253,22 +251,60 @@ def benchmark(): ) @click.option( "--data-args", - default=GenerativeTextScenario.get_default("data_args"), + multiple=True, + default=None, callback=cli_tools.parse_json, help=( "A JSON string containing any arguments to pass to the dataset creation " "as a dict with **kwargs." ), ) +@click.option( + "--data-samples", + default=-1, + type=int, + help=( + "The number of samples to use from the dataset. If -1 (default), will use all " + "samples in the dataset and dynamically generate samples. " + "If >1, will precompile that number of items from the dataset configs." 
+ ), +) +@click.option( + "--data-column-mappings", + default=None, + callback=cli_tools.parse_json, + help=( + "A JSON string of column mappings to apply to the dataset to map into request " + "column types." + ), +) @click.option( "--data-sampler", - default=GenerativeTextScenario.get_default("data_sampler"), - type=click.Choice(["random"]), + default=None, + type=click.Choice(["shuffle"]), + help="The data sampler type to use.", +) +@click.option( + "--data-num-workers", + default=None, + type=int, + help="The number of worker processes to use for data loading.", +) +@click.option( + "--dataloader_kwargs", + default=None, + callback=cli_tools.parse_json, help=( - "The data sampler type to use. 'random' will add a random shuffle on the data. " - "Defaults to None" + "A JSON string containing any arguments to pass to the dataloader constructor " + "as a dict with **kwargs." ), ) +@click.option( + "--random-seed", + default=GenerativeTextScenario.get_default("random_seed"), + type=int, + help="The random seed to use for benchmarking to ensure reproducibility.", +) # Output configuration @click.option( "--output-path", @@ -311,11 +347,6 @@ def benchmark(): help="Set this flag to display stats for the processes running the benchmarks", ) # Aggregators configuration -@click.option( - "--output-extras", - callback=cli_tools.parse_json, - help="A JSON string of extra data to save with the output benchmarks", -) @click.option( "--warmup", "--warmup-percent", # legacy alias @@ -345,10 +376,9 @@ def benchmark(): ), ) @click.option( - "--request-samples", + "--sample-requests", "--output-sampling", # legacy alias - "request_samples", - default=GenerativeTextScenario.get_default("request_samples"), + "sample_requests", type=int, help=( "The number of samples for each request status and each benchmark to save " @@ -393,7 +423,45 @@ def benchmark(): default=GenerativeTextScenario.get_default("max_global_error_rate"), help="Maximum global error rate allowed across all benchmarks", ) -def run(**kwargs): +def run( + target, + data, + profile, + rate, + # Backend Configuration + backend, + backend_kwargs, + model, + # Data configuration + request_type, + request_formatter_kwargs, + processor, + processor_args, + data_args, + data_samples, + data_column_mappings, + data_sampler, + data_num_workers, + dataloader_kwargs, + random_seed, + # Output configuration + output_path, + output_formats, + # Updates configuration + disable_console_outputs, + disable_progress, + display_scheduler_stats, + # Benchmarker configuration + sample_requests, + warmup, + cooldown, + # Constraints configuration + max_seconds, + max_requests, + max_errors, + max_error_rate, + max_global_error_rate, +): """ Execute a generative text benchmark against a target model backend. @@ -402,53 +470,58 @@ def run(**kwargs): Supports multiple backends, data sources, output formats, and constraint types for flexible benchmark configuration. 
""" - scenario = kwargs.pop("scenario") - click_ctx = click.get_current_context() - overrides = cli_tools.set_if_not_default(click_ctx, **kwargs) + data_request_formatter = ( + request_type + if not request_formatter_kwargs + else {"request_type": request_type, **request_formatter_kwargs} + ) - try: - # If a scenario file was specified read from it - if scenario is None: - _scenario = GenerativeTextScenario.model_validate(overrides) - elif isinstance(scenario, Path): - _scenario = GenerativeTextScenario.from_file(scenario, overrides) - else: # Only builtins can make it here; click will catch anything else - _scenario = GenerativeTextScenario.from_builtin(scenario, overrides) - except ValidationError as e: - # Translate pydantic valdation error to click argument error - errs = e.errors(include_url=False, include_context=True, include_input=True) - param_name = "--" + str(errs[0]["loc"][0]).replace("_", "-") - raise click.BadParameter( - errs[0]["msg"], ctx=click_ctx, param_hint=param_name - ) from e - - if HAS_UVLOOP: + if uvloop is not None: asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.run( benchmark_generative_text( - scenario=_scenario, + target=target, + data=list(data), + # Benchmark configuration + profile=profile, + rate=rate, + # Backend configuration + backend=backend, + backend_kwargs=backend_kwargs, + model=model, + # Data configuration + processor=processor, + processor_args=processor_args, + data_args=data_args, + data_samples=data_samples, + data_column_mapper=data_column_mappings, + data_request_formatter=data_request_formatter, + data_sampler=data_sampler, + data_num_workers=data_num_workers, + dataloader_kwargs=dataloader_kwargs, + random_seed=random_seed, # Output configuration - output_path=kwargs["output_path"], - output_formats=[ - fmt - for fmt in kwargs["output_formats"] - if not kwargs["disable_console_outputs"] or fmt != "console" - ], + output_path=output_path, + output_formats=output_formats, # Updates configuration progress=( - [ - GenerativeConsoleBenchmarkerProgress( - display_scheduler_stats=kwargs["display_scheduler_stats"] - ) - ] - if not kwargs["disable_progress"] + GenerativeConsoleBenchmarkerProgress( + display_scheduler_stats=display_scheduler_stats + ) + if not disable_progress else None ), - print_updates=not kwargs["disable_console_outputs"], - # Aggregators configuration - add_aggregators={ - "extras": InjectExtrasAggregator(extras=kwargs["output_extras"]) - }, + print_updates=not disable_console_outputs, + # Benchmarker configuration + sample_requests=sample_requests, + warmup=warmup, + cooldown=cooldown, + # Constraints configuration + max_seconds=max_seconds, + max_requests=max_requests, + max_errors=max_errors, + max_error_rate=max_error_rate, + max_global_error_rate=max_global_error_rate, ) ) diff --git a/src/guidellm/backends/__init__.py b/src/guidellm/backends/__init__.py index 064722ac..6577fa72 100644 --- a/src/guidellm/backends/__init__.py +++ b/src/guidellm/backends/__init__.py @@ -1,26 +1,33 @@ """ Backend infrastructure for GuideLLM language model interactions. -Provides abstract base classes, implemented backends, request/response objects, -and timing utilities for standardized communication with LLM providers. +Provides abstract base classes, concrete backend implementations, and response +handlers for standardized communication with generative AI model providers. +The backend system supports distributed execution across worker processes with +pluggable response handlers for different API formats. 
Key components include +the abstract Backend base class, OpenAI-compatible HTTP backend, and response +handlers for processing streaming and non-streaming API responses. """ -from .backend import ( - Backend, - BackendType, -) -from .objects import ( - GenerationRequest, - GenerationRequestTimings, - GenerationResponse, -) +from __future__ import annotations + +from .backend import Backend, BackendType from .openai import OpenAIHTTPBackend +from .response_handlers import ( + AudioResponseHandler, + ChatCompletionsResponseHandler, + GenerationResponseHandler, + GenerationResponseHandlerFactory, + TextCompletionsResponseHandler, +) __all__ = [ + "AudioResponseHandler", "Backend", "BackendType", - "GenerationRequest", - "GenerationRequestTimings", - "GenerationResponse", + "ChatCompletionsResponseHandler", + "GenerationResponseHandler", + "GenerationResponseHandlerFactory", "OpenAIHTTPBackend", + "TextCompletionsResponseHandler", ] diff --git a/src/guidellm/backends/backend.py b/src/guidellm/backends/backend.py index 8f91d5e7..89169a48 100644 --- a/src/guidellm/backends/backend.py +++ b/src/guidellm/backends/backend.py @@ -2,13 +2,8 @@ Backend interface and registry for generative AI model interactions. Provides the abstract base class for implementing backends that communicate with -generative AI models. Backends handle the lifecycle of generation requests. - -Classes: - Backend: Abstract base class for generative AI backends with registry support. - -Type Aliases: - BackendType: Literal type defining supported backend implementations. +generative AI models. Backends handle the lifecycle of generation requests and +provide a standard interface for distributed execution across worker processes. """ from __future__ import annotations @@ -16,11 +11,8 @@ from abc import abstractmethod from typing import Literal -from guidellm.backends.objects import ( - GenerationRequest, - GenerationResponse, -) from guidellm.scheduler import BackendInterface +from guidellm.schemas import GenerationRequest, GenerationResponse from guidellm.utils import RegistryMixin __all__ = [ @@ -37,11 +29,12 @@ class Backend( BackendInterface[GenerationRequest, GenerationResponse], ): """ - Base class for generative AI backends with registry and lifecycle. + Base class for generative AI backends with registry and lifecycle management. Provides a standard interface for backends that communicate with generative AI models. Combines the registry pattern for automatic discovery with a defined - lifecycle for process-based distributed execution. + lifecycle for process-based distributed execution. Backend state must be + pickleable for distributed execution across process boundaries. Backend lifecycle phases: 1. Creation and configuration @@ -50,9 +43,6 @@ class Backend( 4. Request resolution - Process generation requests 5. Process shutdown - Clean up resources - Backend state (excluding process_startup resources) must be pickleable for - distributed execution across process boundaries. - Example: :: @Backend.register("my_backend") @@ -72,10 +62,10 @@ def create(cls, type_: BackendType, **kwargs) -> Backend: """ Create a backend instance based on the backend type. - :param type_: The type of backend to create. - :param kwargs: Additional arguments for backend initialization. - :return: An instance of a subclass of Backend. - :raises ValueError: If the backend type is not registered. 
+ :param type_: The type of backend to create + :param kwargs: Additional arguments for backend initialization + :return: An instance of a subclass of Backend + :raises ValueError: If the backend type is not registered """ backend = cls.get_registered_object(type_) @@ -92,28 +82,29 @@ def __init__(self, type_: BackendType): """ Initialize a backend instance. - :param type_: The backend type identifier. + :param type_: The backend type identifier """ self.type_ = type_ @property def processes_limit(self) -> int | None: """ - :return: Maximum number of worker processes supported. None if unlimited. + :return: Maximum number of worker processes supported, None if unlimited """ return None @property def requests_limit(self) -> int | None: """ - :return: Maximum number of concurrent requests supported globally. - None if unlimited. + :return: Maximum number of concurrent requests supported globally, + None if unlimited """ return None @abstractmethod async def default_model(self) -> str | None: """ - :return: The default model name or identifier for generation requests. + :return: The default model name or identifier for generation requests, + None if no default model is available """ ... diff --git a/src/guidellm/backends/objects.py b/src/guidellm/backends/objects.py deleted file mode 100644 index 001aeb70..00000000 --- a/src/guidellm/backends/objects.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -Backend object models for request and response handling. - -Provides standardized models for generation requests, responses, and timing -information to ensure consistent data handling across different backend -implementations. -""" - -import uuid -from typing import Any, Literal - -from pydantic import Field - -from guidellm.scheduler import ( - MeasuredRequestTimings, - SchedulerMessagingPydanticRegistry, -) -from guidellm.utils import StandardBaseModel - -__all__ = [ - "GenerationRequest", - "GenerationRequestTimings", - "GenerationResponse", -] - - -@SchedulerMessagingPydanticRegistry.register() -class GenerationRequest(StandardBaseModel): - """Request model for backend generation operations.""" - - request_id: str = Field( - default_factory=lambda: str(uuid.uuid4()), - description="Unique identifier for the request.", - ) - request_type: Literal["text_completions", "chat_completions"] = Field( - default="text_completions", - description=( - "Type of request. 'text_completions' uses backend.text_completions(), " - "'chat_completions' uses backend.chat_completions()." - ), - ) - content: Any = Field( - description=( - "Request content. For text_completions: string or list of strings. " - "For chat_completions: string, list of messages, or raw content " - "(set raw_content=True in params)." - ) - ) - params: dict[str, Any] = Field( - default_factory=dict, - description=( - "Additional parameters passed to backend methods. " - "Common: max_tokens, temperature, stream." - ), - ) - stats: dict[Literal["prompt_tokens"], int] = Field( - default_factory=dict, - description="Request statistics including prompt token count.", - ) - constraints: dict[Literal["output_tokens"], int] = Field( - default_factory=dict, - description="Request constraints such as maximum output tokens.", - ) - - -@SchedulerMessagingPydanticRegistry.register() -class GenerationResponse(StandardBaseModel): - """Response model for backend generation operations.""" - - request_id: str = Field( - description="Unique identifier matching the original GenerationRequest." 
- ) - request_args: dict[str, Any] = Field( - description="Arguments passed to the backend for this request." - ) - value: str | None = Field( - default=None, - description="Complete generated text content. None for streaming responses.", - ) - delta: str | None = Field( - default=None, description="Incremental text content for streaming responses." - ) - iterations: int = Field( - default=0, description="Number of generation iterations completed." - ) - request_prompt_tokens: int | None = Field( - default=None, description="Token count from the original request prompt." - ) - request_output_tokens: int | None = Field( - default=None, - description="Expected output token count from the original request.", - ) - response_prompt_tokens: int | None = Field( - default=None, description="Actual prompt token count reported by the backend." - ) - response_output_tokens: int | None = Field( - default=None, description="Actual output token count reported by the backend." - ) - - @property - def prompt_tokens(self) -> int | None: - """ - :return: The number of prompt tokens used in the request - (response_prompt_tokens if available, otherwise request_prompt_tokens). - """ - return self.response_prompt_tokens or self.request_prompt_tokens - - @property - def output_tokens(self) -> int | None: - """ - :return: The number of output tokens generated in the response - (response_output_tokens if available, otherwise request_output_tokens). - """ - return self.response_output_tokens or self.request_output_tokens - - @property - def total_tokens(self) -> int | None: - """ - :return: The total number of tokens used in the request and response. - Sum of prompt_tokens and output_tokens. - """ - if self.prompt_tokens is None or self.output_tokens is None: - return None - return self.prompt_tokens + self.output_tokens - - def preferred_prompt_tokens( - self, preferred_source: Literal["request", "response"] - ) -> int | None: - if preferred_source == "request": - return self.request_prompt_tokens or self.response_prompt_tokens - else: - return self.response_prompt_tokens or self.request_prompt_tokens - - def preferred_output_tokens( - self, preferred_source: Literal["request", "response"] - ) -> int | None: - if preferred_source == "request": - return self.request_output_tokens or self.response_output_tokens - else: - return self.response_output_tokens or self.request_output_tokens - - -@SchedulerMessagingPydanticRegistry.register() -@MeasuredRequestTimings.register("generation_request_timings") -class GenerationRequestTimings(MeasuredRequestTimings): - """Timing model for tracking generation request lifecycle events.""" - - timings_type: Literal["generation_request_timings"] = "generation_request_timings" - first_iteration: float | None = Field( - default=None, - description="Unix timestamp when the first generation iteration began.", - ) - last_iteration: float | None = Field( - default=None, - description="Unix timestamp when the last generation iteration completed.", - ) diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py index c8eb70f3..1e74fc6e 100644 --- a/src/guidellm/backends/openai.py +++ b/src/guidellm/backends/openai.py @@ -4,44 +4,27 @@ Provides HTTP-based backend for OpenAI-compatible servers including OpenAI API, vLLM servers, and other compatible inference engines. Supports text and chat completions with streaming, authentication, and multimodal capabilities. - -Classes: - UsageStats: Token usage statistics for generation requests. 
- OpenAIHTTPBackend: HTTP backend for OpenAI-compatible API servers. +Handles request formatting, response parsing, error handling, and token usage +tracking with flexible parameter customization. """ -import base64 -import contextlib -import copy -import json +from __future__ import annotations + +import asyncio import time from collections.abc import AsyncIterator -from pathlib import Path -from typing import Any, ClassVar +from typing import Any import httpx -from PIL import Image -from pydantic import dataclasses from guidellm.backends.backend import Backend -from guidellm.backends.objects import ( - GenerationRequest, - GenerationRequestTimings, - GenerationResponse, +from guidellm.backends.response_handlers import ( + GenerationResponseHandler, + GenerationResponseHandlerFactory, ) -from guidellm.scheduler import ScheduledRequestInfo - -__all__ = ["OpenAIHTTPBackend", "UsageStats"] - -ContentT = str | list[str | dict[str, str | dict[str, str]] | Path | Image.Image] | Any - +from guidellm.schemas import GenerationRequest, GenerationResponse, RequestInfo -@dataclasses.dataclass -class UsageStats: - """Token usage statistics for generation requests.""" - - prompt_tokens: int | None = None - output_tokens: int | None = None +__all__ = ["OpenAIHTTPBackend"] @Backend.register("openai_http") @@ -68,74 +51,54 @@ class OpenAIHTTPBackend(Backend): await backend.process_shutdown() """ - HEALTH_PATH: ClassVar[str] = "/health" - MODELS_PATH: ClassVar[str] = "/v1/models" - TEXT_COMPLETIONS_PATH: ClassVar[str] = "/v1/completions" - CHAT_COMPLETIONS_PATH: ClassVar[str] = "/v1/chat/completions" - - MODELS_KEY: ClassVar[str] = "models" - TEXT_COMPLETIONS_KEY: ClassVar[str] = "text_completions" - CHAT_COMPLETIONS_KEY: ClassVar[str] = "chat_completions" - def __init__( self, target: str, model: str | None = None, - api_key: str | None = None, - organization: str | None = None, - project: str | None = None, + api_routes: dict[str, str] | None = None, + response_handlers: dict[str, Any] | None = None, timeout: float = 60.0, http2: bool = True, follow_redirects: bool = True, - max_output_tokens: int | None = None, - stream_response: bool = True, - extra_query: dict | None = None, - extra_body: dict | None = None, - remove_from_body: list[str] | None = None, - headers: dict | None = None, verify: bool = False, + validate_backend: bool | str | dict[str, Any] = True, ): """ - Initialize OpenAI HTTP backend. - - :param target: Target URL for the OpenAI server (e.g., "http://localhost:8000"). - :param model: Model to use for requests. If None, uses first available model. - :param api_key: API key for authentication. Adds Authorization header - if provided. - :param organization: Organization ID. Adds OpenAI-Organization header - if provided. - :param project: Project ID. Adds OpenAI-Project header if provided. - :param timeout: Request timeout in seconds. Defaults to 60 seconds. - :param http2: Whether to use HTTP/2. Defaults to True. - :param follow_redirects: Whether to follow redirects. Default True. - :param max_output_tokens: Maximum tokens for completions. If None, none is set. - :param stream_response: Whether to stream responses by default. Can be - overridden per request. Defaults to True. - :param extra_query: Additional query parameters. Both general and - endpoint-specific with type keys supported. - :param extra_body: Additional body parameters. Both general and - endpoint-specific with type keys supported. - :param remove_from_body: Parameter names to remove from request bodies. 
- :param headers: Additional HTTP headers. - :param verify: Whether to verify SSL certificates. Default False. + Initialize OpenAI HTTP backend with server configuration. + + :param target: Base URL of the OpenAI-compatible server + :param model: Model identifier for generation requests + :param api_routes: Custom API endpoint routes mapping + :param response_handlers: Custom response handlers for different request types + :param timeout: Request timeout in seconds + :param http2: Enable HTTP/2 protocol support + :param follow_redirects: Follow HTTP redirects automatically + :param verify: Enable SSL certificate verification + :param validate_backend: Backend validation configuration """ super().__init__(type_="openai_http") # Request Values self.target = target.rstrip("/").removesuffix("/v1") self.model = model - self.headers = self._build_headers(api_key, organization, project, headers) # Store configuration + self.api_routes = api_routes or { + "health": "health", + "models": "v1/models", + "text_completions": "v1/completions", + "chat_completions": "v1/chat/completions", + "audio_transcriptions": "v1/audio/transcriptions", + "audio_translations": "v1/audio/translations", + } + self.response_handlers = response_handlers self.timeout = timeout self.http2 = http2 self.follow_redirects = follow_redirects self.verify = verify - self.max_output_tokens = max_output_tokens - self.stream_response = stream_response - self.extra_query = extra_query or {} - self.extra_body = extra_body or {} - self.remove_from_body = remove_from_body or [] + self.validate_backend: dict[str, Any] | None = self._resolve_validate_kwargs( + validate_backend + ) # Runtime state self._in_process = False @@ -144,33 +107,27 @@ def __init__( @property def info(self) -> dict[str, Any]: """ - :return: Dictionary containing backend configuration details. + Get backend configuration details. + + :return: Dictionary containing backend configuration details """ return { "target": self.target, "model": self.model, - "headers": self.headers, "timeout": self.timeout, "http2": self.http2, "follow_redirects": self.follow_redirects, "verify": self.verify, - "max_output_tokens": self.max_output_tokens, - "stream_response": self.stream_response, - "extra_query": self.extra_query, - "extra_body": self.extra_body, - "remove_from_body": self.remove_from_body, - "health_path": self.HEALTH_PATH, - "models_path": self.MODELS_PATH, - "text_completions_path": self.TEXT_COMPLETIONS_PATH, - "chat_completions_path": self.CHAT_COMPLETIONS_PATH, + "openai_paths": self.api_routes, + "validate_backend": self.validate_backend, } async def process_startup(self): """ Initialize HTTP client and backend resources. - :raises RuntimeError: If backend is already initialized. - :raises httpx.Exception: If HTTP client cannot be created. + :raises RuntimeError: If backend is already initialized + :raises httpx.RequestError: If HTTP client cannot be created """ if self._in_process: raise RuntimeError("Backend already started up for process.") @@ -187,8 +144,8 @@ async def process_shutdown(self): """ Clean up HTTP client and backend resources. - :raises RuntimeError: If backend was not properly initialized. - :raises httpx.Exception: If HTTP client cannot be closed. 
+ :raises RuntimeError: If backend was not properly initialized + :raises httpx.RequestError: If HTTP client cannot be closed """ if not self._in_process: raise RuntimeError("Backend not started up for process.") @@ -199,69 +156,38 @@ async def process_shutdown(self): async def validate(self): """ - Validate backend configuration and connectivity. + Validate backend connectivity and configuration. - Validate backend configuration and connectivity through test requests, - and auto-selects first available model if none is configured. - - :raises RuntimeError: If backend cannot connect or validate configuration. + :raises RuntimeError: If backend cannot connect or validate configuration """ - self._check_in_process() - - if self.model: - with contextlib.suppress(httpx.TimeoutException, httpx.HTTPStatusError): - # Model is set, use /health endpoint as first check - target = f"{self.target}{self.HEALTH_PATH}" - headers = self._get_headers() - response = await self._async_client.get(target, headers=headers) # type: ignore [union-attr] - response.raise_for_status() - - return - - with contextlib.suppress(httpx.TimeoutException, httpx.HTTPStatusError): - # Check if models endpoint is available next - models = await self.available_models() - if models and not self.model: - self.model = models[0] - elif not self.model: - raise RuntimeError( - "No model available and could not set a default model " - "from the server's available models." - ) - - return - - with contextlib.suppress(httpx.TimeoutException, httpx.HTTPStatusError): - # Last check, fall back on dummy request to text completions - async for _, __ in self.text_completions( - prompt="Validate backend", - request_id="validate", - output_token_count=1, - stream_response=False, - ): - pass + if self._async_client is None: + raise RuntimeError("Backend not started up for process.") + if not self.validate_backend: return - raise RuntimeError( - "Backend validation failed. Could not connect to the server or " - "validate the backend configuration." - ) + try: + response = await self._async_client.request(**self.validate_backend) + response.raise_for_status() + except Exception as exc: + raise RuntimeError( + "Backend validation request failed. Could not connect to the server " + "or validate the backend configuration." + ) from exc async def available_models(self) -> list[str]: """ Get available models from the target server. - :return: List of model identifiers. - :raises HTTPError: If models endpoint returns an error. - :raises RuntimeError: If backend is not initialized. + :return: List of model identifiers + :raises httpx.HTTPError: If models endpoint returns an error + :raises RuntimeError: If backend is not initialized """ - self._check_in_process() + if self._async_client is None: + raise RuntimeError("Backend not started up for process.") - target = f"{self.target}{self.MODELS_PATH}" - headers = self._get_headers() - params = self._get_params(self.MODELS_KEY) - response = await self._async_client.get(target, headers=headers, params=params) # type: ignore [union-attr] + target = f"{self.target}/{self.api_routes['models']}" + response = await self._async_client.get(target) response.raise_for_status() return [item["id"] for item in response.json()["data"]] @@ -270,7 +196,7 @@ async def default_model(self) -> str | None: """ Get the default model for this backend. - :return: Model name or None if no model is available. 
+ :return: Model name or None if no model is available """ if self.model or not self._in_process: return self.model @@ -281,363 +207,149 @@ async def default_model(self) -> str | None: async def resolve( self, request: GenerationRequest, - request_info: ScheduledRequestInfo, + request_info: RequestInfo, history: list[tuple[GenerationRequest, GenerationResponse]] | None = None, - ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]: + ) -> AsyncIterator[tuple[GenerationResponse, RequestInfo]]: """ - Process a generation request and yield progressive responses. + Process generation request and yield progressive responses. Handles request formatting, timing tracking, API communication, and response parsing with streaming support. - :param request: Generation request with content and parameters. - :param request_info: Request tracking info updated with timing metadata. - :param history: Conversation history. Currently not supported. - :raises NotImplementedError: If history is provided. - :yields: Tuples of (response, updated_request_info) as generation progresses. + :param request: Generation request with content and parameters + :param request_info: Request tracking info updated with timing metadata + :param history: Conversation history (currently not supported) + :raises NotImplementedError: If history is provided + :raises RuntimeError: If backend is not initialized + :raises ValueError: If request type is unsupported + :yields: Tuples of (response, updated_request_info) as generation progresses """ - self._check_in_process() - if history is not None: - raise NotImplementedError( - "Multi-turn requests with conversation history are not yet supported" - ) + if self._async_client is None: + raise RuntimeError("Backend not started up for process.") - response = GenerationResponse( - request_id=request.request_id, - request_args={ - "request_type": request.request_type, - "output_token_count": request.constraints.get("output_tokens"), - **request.params, - }, - value="", - request_prompt_tokens=request.stats.get("prompt_tokens"), - request_output_tokens=request.constraints.get("output_tokens"), - ) - request_info.request_timings = GenerationRequestTimings() - request_info.request_timings.request_start = time.time() + if history is not None: + raise NotImplementedError("Multi-turn requests not yet supported") - completion_method = ( - self.text_completions - if request.request_type == "text_completions" - else self.chat_completions + response_handler = self._resolve_response_handler( + request_type=request.request_type ) - completion_kwargs = ( + if (request_path := self.api_routes.get(request.request_type)) is None: + raise ValueError(f"Unsupported request type '{request.request_type}'") + request_url = f"{self.target}/{request_path}" + request_files = ( { - "prompt": request.content, - "request_id": request.request_id, - "output_token_count": request.constraints.get("output_tokens"), - "stream_response": request.params.get("stream", self.stream_response), - **request.params, - } - if request.request_type == "text_completions" - else { - "content": request.content, - "request_id": request.request_id, - "output_token_count": request.constraints.get("output_tokens"), - "stream_response": request.params.get("stream", self.stream_response), - **request.params, + key: tuple(value) if isinstance(value, list) else value + for key, value in request.arguments.files.items() } + if request.arguments.files + else None ) - - async for delta, usage_stats in 
completion_method(**completion_kwargs): - if request_info.request_timings.request_start is None: - request_info.request_timings.request_start = time.time() - - if delta is not None: - if request_info.request_timings.first_iteration is None: - request_info.request_timings.first_iteration = time.time() - response.value += delta # type: ignore [operator] - response.delta = delta - request_info.request_timings.last_iteration = time.time() - response.iterations += 1 - - if usage_stats is not None: - request_info.request_timings.request_end = time.time() - response.response_output_tokens = usage_stats.output_tokens - response.response_prompt_tokens = usage_stats.prompt_tokens - - yield response, request_info - - if request_info.request_timings.request_end is None: - request_info.request_timings.request_end = time.time() - response.delta = None - yield response, request_info - - async def text_completions( - self, - prompt: str | list[str], - request_id: str | None, # noqa: ARG002 - output_token_count: int | None = None, - stream_response: bool = True, - **kwargs, - ) -> AsyncIterator[tuple[str | None, UsageStats | None]]: - """ - Generate text completions using the /v1/completions endpoint. - - :param prompt: Text prompt(s) for completion. Single string or list. - :param request_id: Request identifier for tracking. - :param output_token_count: Maximum tokens to generate. Overrides default - if specified. - :param stream_response: Whether to stream response progressively. - :param kwargs: Additional request parameters (temperature, top_p, etc.). - :yields: Tuples of (generated_text, usage_stats). First yield is (None, None). - :raises RuntimeError: If backend is not initialized. - :raises HTTPError: If API request fails. - """ - self._check_in_process() - target = f"{self.target}{self.TEXT_COMPLETIONS_PATH}" - headers = self._get_headers() - params = self._get_params(self.TEXT_COMPLETIONS_KEY) - body = self._get_body( - endpoint_type=self.TEXT_COMPLETIONS_KEY, - request_kwargs=kwargs, - max_output_tokens=output_token_count, - prompt=prompt, - ) - yield None, None # Initial yield for async iterator to signal start - - if not stream_response: - response = await self._async_client.post( # type: ignore [union-attr] - target, - headers=headers, - params=params, - json=body, + request_json = request.arguments.body if not request_files else None + request_data = request.arguments.body if request_files else None + + if not request.arguments.stream: + request_info.timings.request_start = time.time() + response = await self._async_client.request( + request.arguments.method or "POST", + request_url, + params=request.arguments.params, + headers=request.arguments.headers, + json=request_json, + data=request_data, + files=request_files, ) + request_info.timings.request_end = time.time() response.raise_for_status() data = response.json() - yield ( - self._get_completions_text_content(data), - self._get_completions_usage_stats(data), - ) + yield response_handler.compile_non_streaming(request, data), request_info return - body.update({"stream": True, "stream_options": {"include_usage": True}}) - async with self._async_client.stream( # type: ignore [union-attr] - "POST", - target, - headers=headers, - params=params, - json=body, - ) as stream: - stream.raise_for_status() - async for line in stream.aiter_lines(): - if not line or not line.strip().startswith("data:"): - continue - if line.strip() == "data: [DONE]": - break - data = json.loads(line.strip()[len("data: ") :]) - yield ( - 
self._get_completions_text_content(data), - self._get_completions_usage_stats(data), - ) - - async def chat_completions( - self, - content: ContentT, - request_id: str | None = None, # noqa: ARG002 - output_token_count: int | None = None, - raw_content: bool = False, - stream_response: bool = True, - **kwargs, - ) -> AsyncIterator[tuple[str | None, UsageStats | None]]: - """ - Generate chat completions using the /v1/chat/completions endpoint. - - Supports multimodal inputs including text and images with message formatting. - - :param content: Chat content - string, list of mixed content, or raw content - when raw_content=True. - :param request_id: Request identifier (currently unused). - :param output_token_count: Maximum tokens to generate. Overrides default - if specified. - :param raw_content: If True, passes content directly without formatting. - :param stream_response: Whether to stream response progressively. - :param kwargs: Additional request parameters (temperature, top_p, tools, etc.). - :yields: Tuples of (generated_text, usage_stats). First yield is (None, None). - :raises RuntimeError: If backend is not initialized. - :raises HTTPError: If API request fails. - """ - self._check_in_process() - target = f"{self.target}{self.CHAT_COMPLETIONS_PATH}" - headers = self._get_headers() - params = self._get_params(self.CHAT_COMPLETIONS_KEY) - body = self._get_body( - endpoint_type=self.CHAT_COMPLETIONS_KEY, - request_kwargs=kwargs, - max_output_tokens=output_token_count, - messages=self._get_chat_messages(content) if not raw_content else content, - **kwargs, - ) - yield None, None # Initial yield for async iterator to signal start - - if not stream_response: - response = await self._async_client.post( # type: ignore [union-attr] - target, headers=headers, params=params, json=body - ) - response.raise_for_status() - data = response.json() - yield ( - self._get_completions_text_content(data), - self._get_completions_usage_stats(data), - ) - return - - body.update({"stream": True, "stream_options": {"include_usage": True}}) - async with self._async_client.stream( # type: ignore [union-attr] - "POST", target, headers=headers, params=params, json=body - ) as stream: - stream.raise_for_status() - async for line in stream.aiter_lines(): - if not line or not line.strip().startswith("data:"): - continue - if line.strip() == "data: [DONE]": - break - data = json.loads(line.strip()[len("data: ") :]) - yield ( - self._get_completions_text_content(data), - self._get_completions_usage_stats(data), - ) - - def _build_headers( - self, - api_key: str | None, - organization: str | None, - project: str | None, - user_headers: dict | None, - ) -> dict[str, str]: - headers = {} - - if api_key: - headers["Authorization"] = ( - f"Bearer {api_key}" if not api_key.startswith("Bearer") else api_key - ) - if organization: - headers["OpenAI-Organization"] = organization - if project: - headers["OpenAI-Project"] = project - if user_headers: - headers.update(user_headers) - - return {key: val for key, val in headers.items() if val is not None} - - def _check_in_process(self): - if not self._in_process or self._async_client is None: - raise RuntimeError( - "Backend not started up for process, cannot process requests." 
- ) - - def _get_headers(self) -> dict[str, str]: - return { - "Content-Type": "application/json", - **self.headers, - } + try: + request_info.timings.request_start = time.time() + + async with self._async_client.stream( + request.arguments.method or "POST", + request_url, + params=request.arguments.params, + headers=request.arguments.headers, + json=request_json, + data=request_data, + files=request_files, + ) as stream: + stream.raise_for_status() + end_reached = False + + async for chunk in stream.aiter_lines(): + iter_time = time.time() + + if ( + (iterations := response_handler.add_streaming_line(chunk)) + is None + or iterations < 0 + or end_reached + ): + end_reached = end_reached or iterations is None + continue + + if ( + request_info.timings.first_iteration is None + or request_info.timings.iterations is None + ): + request_info.timings.first_iteration = iter_time + request_info.timings.iterations = 0 + + request_info.timings.last_iteration = iter_time + request_info.timings.iterations += iterations + + request_info.timings.request_end = time.time() + yield response_handler.compile_streaming(request), request_info + except asyncio.CancelledError as err: + # Yield current result to store iterative results before propagating + yield response_handler.compile_streaming(request), request_info + raise err + + def _resolve_validate_kwargs( + self, validate_backend: bool | str | dict[str, Any] + ) -> dict[str, Any] | None: + if not (validate_kwargs := validate_backend): + return None - def _get_params(self, endpoint_type: str) -> dict[str, str]: - if endpoint_type in self.extra_query: - return copy.deepcopy(self.extra_query[endpoint_type]) - return copy.deepcopy(self.extra_query) + if validate_kwargs is True: + validate_kwargs = "health" - def _get_chat_messages( - self, - content: ContentT, - ) -> list[dict[str, Any]]: - if isinstance(content, str): - return [{"role": "user", "content": content}] - - if not isinstance(content, list): - raise ValueError(f"Unsupported content type: {type(content)}") - - resolved_content = [] - for item in content: - if isinstance(item, dict): - resolved_content.append(item) - elif isinstance(item, str): - resolved_content.append({"type": "text", "text": item}) - elif isinstance(item, Image.Image | Path): - resolved_content.append(self._get_chat_message_media_item(item)) - else: - raise ValueError(f"Unsupported content item type: {type(item)}") - - return [{"role": "user", "content": resolved_content}] - - def _get_chat_message_media_item( - self, item: Path | Image.Image - ) -> dict[str, Any]: - if isinstance(item, Image.Image): - encoded = base64.b64encode(item.tobytes()).decode("utf-8") - return { - "type": "image", - "image": {"url": f"data:image/jpeg;base64,{encoded}"}, - } + if isinstance(validate_kwargs, str) and validate_kwargs in self.api_routes: + validate_kwargs = f"{self.target}/{self.api_routes[validate_kwargs]}" - # Handle file paths - suffix = item.suffix.lower() - if suffix in [".jpg", ".jpeg"]: - image = Image.open(item) - encoded = base64.b64encode(image.tobytes()).decode("utf-8") - return { - "type": "image", - "image": {"url": f"data:image/jpeg;base64,{encoded}"}, - } - elif suffix == ".wav": - encoded = base64.b64encode(item.read_bytes()).decode("utf-8") - return { - "type": "input_audio", - "input_audio": {"data": encoded, "format": "wav"}, + if isinstance(validate_kwargs, str): + validate_kwargs = { + "method": "GET", + "url": validate_kwargs, } - else: - raise ValueError(f"Unsupported file type: {suffix}") - def _get_body( - self, - 
endpoint_type: str, - request_kwargs: dict[str, Any] | None, - max_output_tokens: int | None = None, - **kwargs, - ) -> dict[str, Any]: - # Start with endpoint-specific extra body parameters - extra_body: dict = self.extra_body.get(endpoint_type, self.extra_body) - - body = copy.deepcopy(extra_body) - body.update(request_kwargs or {}) - body.update(kwargs) - body["model"] = self.model - - # Handle token limits - max_tokens = max_output_tokens or self.max_output_tokens - if max_tokens is not None: - body.update( - { - "max_tokens": max_tokens, - "max_completion_tokens": max_tokens, - } + if not isinstance(validate_kwargs, dict) or "url" not in validate_kwargs: + raise ValueError( + "validate_backend must be a boolean, string, or dictionary and contain " + f"a target URL. Got: {validate_kwargs}" ) - # Set stop conditions only for request-level limits - if max_output_tokens: - body.update({"stop": None, "ignore_eos": True}) - if self.remove_from_body: - for key in self.remove_from_body: - body.pop(key, None) + if "method" not in validate_kwargs: + validate_kwargs["method"] = "GET" - return {key: val for key, val in body.items() if val is not None} + return validate_kwargs - def _get_completions_text_content(self, data: dict) -> str | None: - if not data.get("choices"): - return None + def _resolve_response_handler(self, request_type: str) -> GenerationResponseHandler: + if ( + self.response_handlers is not None + and (handler := self.response_handlers.get(request_type)) is not None + ): + return handler - choice: dict = data["choices"][0] - return ( - choice.get("text") - or choice.get("delta", {}).get("content") - or choice.get("message", {}).get("content") + handler_class = GenerationResponseHandlerFactory.get_registered_object( + request_type ) + if handler_class is None: + raise ValueError( + f"No response handler registered for request type '{request_type}'" + ) - def _get_completions_usage_stats(self, data: dict) -> UsageStats | None: - if not data.get("usage"): - return None - - return UsageStats( - prompt_tokens=data["usage"].get("prompt_tokens"), - output_tokens=data["usage"].get("completion_tokens"), - ) + return handler_class() diff --git a/src/guidellm/backends/response_handlers.py b/src/guidellm/backends/response_handlers.py new file mode 100644 index 00000000..44c949e6 --- /dev/null +++ b/src/guidellm/backends/response_handlers.py @@ -0,0 +1,456 @@ +""" +Response handlers for processing API responses from different generation backends. + +This module provides a pluggable system for handling responses from various language +model backends, supporting both streaming and non-streaming responses. Each handler +implements the GenerationResponseHandler protocol to parse API responses, extract +usage metrics, and convert them into standardized GenerationResponse objects for the +benchmark system. +""" + +from __future__ import annotations + +from typing import Any, Protocol + +from guidellm.schemas import GenerationRequest, GenerationResponse, UsageMetrics +from guidellm.utils import RegistryMixin, json + +__all__ = [ + "AudioResponseHandler", + "ChatCompletionsResponseHandler", + "GenerationResponseHandler", + "GenerationResponseHandlerFactory", + "TextCompletionsResponseHandler", +] + + +class GenerationResponseHandler(Protocol): + """ + Protocol defining the interface for handling generation API responses. 
+ + Response handlers implement this protocol to process both streaming and + non-streaming responses from different backend APIs, converting them into + standardized GenerationResponse objects with consistent metrics extraction. + """ + + def compile_non_streaming( + self, request: GenerationRequest, response: Any + ) -> GenerationResponse: + """ + Process a complete non-streaming API response. + + :param request: The original generation request + :param response: Raw API response data from the backend + :return: Standardized GenerationResponse with extracted metrics + """ + ... + + def add_streaming_line(self, line: str) -> int | None: + """ + Process a single line from a streaming response. + + :param line: Raw line from the streaming response + :return: 1 if content was updated, 0 if line was ignored, None if done + """ + ... + + def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: + """ + Compile accumulated streaming data into a final response. + + :param request: The original generation request + :return: Standardized GenerationResponse with extracted metrics + """ + ... + + +class GenerationResponseHandlerFactory(RegistryMixin[type[GenerationResponseHandler]]): + """ + Factory for registering and creating response handlers by backend type. + + Provides a registry-based system for associating handler classes with specific + backend API types, enabling automatic selection of the appropriate handler + for processing responses from different generation services. + """ + + +@GenerationResponseHandlerFactory.register("text_completions") +class TextCompletionsResponseHandler(GenerationResponseHandler): + """ + Response handler for OpenAI-style text completion endpoints. + + Processes responses from text completion APIs that return generated text + in the 'choices' array with 'text' fields. Handles both streaming and + non-streaming responses, extracting usage metrics for input and output tokens. + + Example: + :: + handler = TextCompletionsResponseHandler() + response = handler.compile_non_streaming(request, api_response) + """ + + def __init__(self): + """ + Initialize the text completions response handler. + + Sets up internal state for accumulating streaming response data including + text chunks and usage metrics. + """ + self.streaming_texts: list[str] = [] + self.streaming_usage: dict[str, int | dict[str, int]] | None = None + + def compile_non_streaming( + self, request: GenerationRequest, response: dict + ) -> GenerationResponse: + """ + Process a complete text completion response. + + :param request: The original generation request + :param response: Complete API response containing choices and usage data + :return: Standardized GenerationResponse with extracted text and metrics + """ + choices, usage = self.extract_choices_and_usage(response) + input_metrics, output_metrics = self.extract_metrics(usage) + + return GenerationResponse( + request_id=request.request_id, + request_args=str( + request.arguments.model_dump() if request.arguments else None + ), + text=choices[0].get("text", "") if choices else "", + input_metrics=input_metrics, + output_metrics=output_metrics, + ) + + def add_streaming_line(self, line: str) -> int | None: + """ + Process a single line from a text completion streaming response. + + Parses Server-Sent Events (SSE) formatted lines, extracting text content + and usage metrics. Accumulates text chunks for final response compilation. 
+ + :param line: Raw SSE line from the streaming response + :return: 1 if text content was extracted, 0 if line ignored, None if done + """ + if not (data := self.extract_line_data(line)): + return None if data is None else 0 + + updated = False + choices, usage = self.extract_choices_and_usage(data) + + if text := choices[0].get("text"): + self.streaming_texts.append(text) + updated = True + + if usage: + self.streaming_usage = usage + + return 1 if updated else 0 + + def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: + """ + Compile accumulated streaming text chunks into a final response. + + :param request: The original generation request + :return: Standardized GenerationResponse with concatenated text and metrics + """ + input_metrics, output_metrics = self.extract_metrics(self.streaming_usage) + + return GenerationResponse( + request_id=request.request_id, + request_args=str( + request.arguments.model_dump() if request.arguments else None + ), + text="".join(self.streaming_texts), + input_metrics=input_metrics, + output_metrics=output_metrics, + ) + + def extract_line_data(self, line: str) -> dict[str, Any] | None: + """ + Extract JSON data from a streaming response line. + + :param line: Raw line from the streaming response + :return: Parsed JSON data as a dictionary, or None if line is invalid + """ + if line == "data: [DONE]": + return None + + if not line or not (line := line.strip()) or not line.startswith("data:"): + return {} + + line = line[len("data:") :].strip() + + return json.loads(line) + + def extract_choices_and_usage( + self, response: dict + ) -> tuple[list[dict], dict[str, int | dict[str, int]]]: + """ + Extract choices and usage data from the API response. + + :param response: Complete API response containing choices and usage data + :return: Tuple of (choices list, usage dictionary) + """ + return response.get("choices", []), response.get("usage", {}) + + def extract_metrics( + self, usage: dict[str, int | dict[str, int]] | None + ) -> tuple[UsageMetrics, UsageMetrics]: + """ + Extract input and output usage metrics from API response usage data. + + :param usage: Usage data dictionary from API response + :return: Tuple of (input_metrics, output_metrics) as UsageMetrics objects + """ + if not usage: + return UsageMetrics(), UsageMetrics() + + input_details: dict[str, int] = usage.get("prompt_tokens_details", {}) or {} + output_details: dict[str, int] = ( + usage.get("completion_tokens_details", {}) or {} + ) + + return UsageMetrics( + text_tokens=( + input_details.get("prompt_tokens") or usage.get("prompt_tokens") + ), + image_tokens=input_details.get("image_tokens"), + video_tokens=input_details.get("video_tokens"), + audio_tokens=input_details.get("audio_tokens"), + audio_seconds=input_details.get("seconds"), + ), UsageMetrics( + text_tokens=( + output_details.get("completion_tokens") + or usage.get("completion_tokens") + ), + image_tokens=output_details.get("image_tokens"), + video_tokens=output_details.get("video_tokens"), + audio_tokens=output_details.get("audio_tokens"), + audio_seconds=output_details.get("seconds"), + ) + + +@GenerationResponseHandlerFactory.register("chat_completions") +class ChatCompletionsResponseHandler(TextCompletionsResponseHandler): + """ + Response handler for OpenAI-style chat completion endpoints. + + Extends TextCompletionsResponseHandler to handle chat completion responses + where generated text is nested within message objects in the choices array. 
+ Processes both streaming and non-streaming chat completion responses. + """ + + def compile_non_streaming( + self, request: GenerationRequest, response: dict + ) -> GenerationResponse: + """ + Process a complete chat completion response. + + Extracts content from the message object within choices, handling the + nested structure specific to chat completion endpoints. + + :param request: The original generation request + :param response: Complete API response containing choices and usage data + :return: Standardized GenerationResponse with extracted content and metrics + """ + choices, usage = self.extract_choices_and_usage(response) + input_metrics, output_metrics = self.extract_metrics(usage) + + return GenerationResponse( + request_id=request.request_id, + request_args=str( + request.arguments.model_dump() if request.arguments else None + ), + text=(choices[0].get("message", {}).get("content", "") if choices else ""), + input_metrics=input_metrics, + output_metrics=output_metrics, + ) + + def add_streaming_line(self, line: str) -> int | None: + """ + Process a single line from a chat completion streaming response. + + Handles the chat completion specific delta structure where content + is nested within delta objects in the streaming response chunks. + + :param line: Raw SSE line from the streaming response + :return: 1 if content was extracted, 0 if line ignored, None if done + """ + if not (data := self.extract_line_data(line)): + return None if data is None else 0 + + updated = False + choices, usage = self.extract_choices_and_usage(data) + + if choices and (content := choices[0].get("delta", {}).get("content")): + self.streaming_texts.append(content) + updated = True + + if usage: + self.streaming_usage = usage + + return 1 if updated else 0 + + def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: + """ + Compile accumulated streaming chat completion content into a final response. + + :param request: The original generation request + :return: Standardized GenerationResponse with concatenated content and metrics + """ + input_metrics, output_metrics = self.extract_metrics(self.streaming_usage) + + return GenerationResponse( + request_id=request.request_id, + request_args=str( + request.arguments.model_dump() if request.arguments else None + ), + text="".join(self.streaming_texts), + input_metrics=input_metrics, + output_metrics=output_metrics, + ) + + +@GenerationResponseHandlerFactory.register( + ["audio_transcriptions", "audio_translations"] +) +class AudioResponseHandler: + """ + Response handler for audio transcription and translation endpoints. + + Processes responses from audio processing APIs that convert speech to text, + handling both transcription and translation services. Manages audio-specific + usage metrics including audio tokens and processing duration. + + Example: + :: + handler = AudioResponseHandler() + response = handler.compile_non_streaming(request, api_response) + """ + + def __init__(self): + """ + Initialize the audio response handler. + + Sets up internal state for accumulating streaming response data including + audio buffers, text chunks, and usage metrics. + """ + self.streaming_buffer: bytearray = bytearray() + self.streaming_texts: list[str] = [] + self.streaming_usage: dict[str, int | dict[str, int]] | None = None + + def compile_non_streaming( + self, request: GenerationRequest, response: dict + ) -> GenerationResponse: + """ + Process a complete audio transcription or translation response. 
+ + Extracts transcribed or translated text and audio-specific usage metrics + including processing duration and token counts for audio content. + + :param request: The original generation request + :param response: Complete API response containing text and usage data + :return: Standardized GenerationResponse with extracted text and metrics + """ + usage: dict[str, int | dict[str, int]] = response.get("usage", {}) + input_details: dict[str, int] = usage.get("input_token_details", {}) or {} + output_details: dict[str, int] = usage.get("output_token_details", {}) or {} + text: str = response.get("text", "") + + return GenerationResponse( + request_id=request.request_id, + request_args=str( + request.arguments.model_dump() if request.arguments else None + ), + text=text, + input_metrics=UsageMetrics( + text_tokens=input_details.get("text_tokens", usage.get("input_tokens")), + audio_tokens=input_details.get( + "audio_tokens", usage.get("input_tokens") + ), + audio_seconds=input_details.get("seconds", usage.get("seconds")), + ), + output_metrics=UsageMetrics( + text_tokens=output_details.get( + "text_tokens", usage.get("output_tokens") + ), + ), + ) + + def add_streaming_line(self, line: str) -> int | None: + """ + Process a single line from an audio streaming response. + + Handles JSON-formatted streaming responses from audio processing endpoints, + extracting text content and usage metrics as they become available. + + :param line: Raw JSON line from the streaming response + :return: 1 if text content was extracted, 0 if line ignored, None if done + """ + if line == "data: [DONE]": + return None + + if not line or not (line := line.strip()) or not line.startswith("{"): + return 0 + + data: dict[str, Any] = json.loads(line) + text: str + usage: dict[str, int | dict[str, int]] + updated = False + + if text := data.get("text"): + self.streaming_texts.append(text) + updated = True + + if usage := data.get("usage"): + self.streaming_usage = usage + + return 1 if updated else 0 + + def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: + """ + Compile accumulated streaming audio text into a final response. + + :param request: The original generation request + :return: Standardized GenerationResponse with concatenated text and metrics + """ + input_metrics, output_metrics = self.extract_metrics(self.streaming_usage) + + return GenerationResponse( + request_id=request.request_id, + request_args=str( + request.arguments.model_dump() if request.arguments else None + ), + text="".join(self.streaming_texts), + input_metrics=input_metrics, + output_metrics=output_metrics, + ) + + def extract_metrics( + self, usage: dict[str, int | dict[str, int]] | None + ) -> tuple[UsageMetrics, UsageMetrics]: + """ + Extract input and output usage metrics from audio API response usage data. + + Handles audio-specific metrics including processing duration and audio tokens + in addition to standard text token counts. 
+ + :param usage: Usage data dictionary from audio API response + :return: Tuple of (input_metrics, output_metrics) as UsageMetrics objects + """ + if not usage: + return UsageMetrics(), UsageMetrics() + + input_details: dict[str, int] = usage.get("input_token_details", {}) or {} + output_details: dict[str, int] = usage.get("output_token_details", {}) or {} + + return UsageMetrics( + text_tokens=(input_details.get("text_tokens") or usage.get("input_tokens")), + audio_tokens=( + input_details.get("audio_tokens") or usage.get("audio_tokens") + ), + audio_seconds=(input_details.get("seconds") or usage.get("seconds")), + ), UsageMetrics( + text_tokens=output_details.get("text_tokens") or usage.get("output_tokens"), + ) diff --git a/src/guidellm/benchmark/__init__.py b/src/guidellm/benchmark/__init__.py index 9fdb231d..4c7cc4a5 100644 --- a/src/guidellm/benchmark/__init__.py +++ b/src/guidellm/benchmark/__init__.py @@ -1,25 +1,5 @@ -from .aggregator import ( - Aggregator, - AggregatorState, - CompilableAggregator, - GenerativeRequestsAggregator, - GenerativeStatsProgressAggregator, - InjectExtrasAggregator, - SchedulerStatsAggregator, - SerializableAggregator, -) from .benchmarker import Benchmarker from .entrypoints import benchmark_generative_text, reimport_benchmarks_report -from .objects import ( - Benchmark, - BenchmarkMetrics, - BenchmarkSchedulerStats, - BenchmarkT, - GenerativeBenchmark, - GenerativeBenchmarksReport, - GenerativeMetrics, - GenerativeRequestStats, -) from .output import ( GenerativeBenchmarkerConsole, GenerativeBenchmarkerCSV, @@ -35,40 +15,34 @@ SynchronousProfile, ThroughputProfile, ) -from .progress import ( - BenchmarkerProgress, - BenchmarkerProgressGroup, - GenerativeConsoleBenchmarkerProgress, -) -from .scenario import ( - GenerativeTextScenario, - Scenario, - enable_scenarios, - get_builtin_scenarios, -) -from .types import ( - AggregatorInputT, - DataInputT, - OutputFormatT, - ProcessorInputT, - ProgressInputT, +from .progress import BenchmarkerProgress, GenerativeConsoleBenchmarkerProgress +from .schemas import ( + Benchmark, + BenchmarkArgs, + BenchmarkerDict, + BenchmarkSchedulerStats, + EstimatedBenchmarkState, + GenerativeAudioMetricsSummary, + GenerativeBenchmark, + GenerativeBenchmarksReport, + GenerativeImageMetricsSummary, + GenerativeMetrics, + GenerativeMetricsSummary, + GenerativeVideoMetricsSummary, + SchedulerDict, ) __all__ = [ - "Aggregator", - "AggregatorInputT", - "AggregatorState", "AsyncProfile", "Benchmark", - "BenchmarkMetrics", + "BenchmarkArgs", "BenchmarkSchedulerStats", - "BenchmarkT", "Benchmarker", + "BenchmarkerDict", "BenchmarkerProgress", - "BenchmarkerProgressGroup", - "CompilableAggregator", "ConcurrentProfile", - "DataInputT", + "EstimatedBenchmarkState", + "GenerativeAudioMetricsSummary", "GenerativeBenchmark", "GenerativeBenchmarkerCSV", "GenerativeBenchmarkerConsole", @@ -76,20 +50,13 @@ "GenerativeBenchmarkerOutput", "GenerativeBenchmarksReport", "GenerativeConsoleBenchmarkerProgress", + "GenerativeImageMetricsSummary", "GenerativeMetrics", - "GenerativeRequestStats", - "GenerativeRequestsAggregator", - "GenerativeStatsProgressAggregator", - "GenerativeTextScenario", - "InjectExtrasAggregator", - "OutputFormatT", - "ProcessorInputT", + "GenerativeMetricsSummary", + "GenerativeVideoMetricsSummary", "Profile", "ProfileType", - "ProgressInputT", - "Scenario", - "SchedulerStatsAggregator", - "SerializableAggregator", + "SchedulerDict", "SweepProfile", "SynchronousProfile", "ThroughputProfile", diff --git 
a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py deleted file mode 100644 index b33a7b14..00000000 --- a/src/guidellm/benchmark/aggregator.py +++ /dev/null @@ -1,1260 +0,0 @@ -""" -Benchmark result aggregation and compilation interfaces. - -Provides protocols and implementations for collecting, processing, and compiling -benchmark data from scheduler executions into final metrics and statistics. - -Classes: - Aggregator: Protocol for processing benchmark data updates. - CompilableAggregator: Protocol for aggregators that can compile final results. - SchedulerStatsAggregator: Aggregates scheduler timing and performance metrics. - GenerativeRequestsStatsProgressAggregator: Tracks generation metrics during run. - GenerativeRequestsAggregator: Compiles complete generative benchmark results. - -Functions: - add_aggregate_metric: Helper for accumulating timing and count metrics. - -Type Variables: - RequestT: Generic request object type. - ResponseT: Generic response object type. - RequestTimingsT: Generic request timing object type. -""" - -from __future__ import annotations - -import math -import random -from abc import ABC, abstractmethod -from typing import ( - Any, - ClassVar, - Generic, - Literal, - Protocol, - runtime_checkable, -) - -from pydantic import Field, PrivateAttr - -from guidellm.backends import ( - GenerationRequest, - GenerationResponse, -) -from guidellm.benchmark.objects import ( - BenchmarkSchedulerStats, - GenerativeMetrics, - GenerativeRequestStats, -) -from guidellm.scheduler import ( - RequestT, - ResponseT, - ScheduledRequestInfo, - SchedulerState, -) -from guidellm.settings import settings -from guidellm.utils import ( - InfoMixin, - PydanticClassRegistryMixin, - StatusBreakdown, - StatusDistributionSummary, - all_defined, - safe_divide, - safe_getattr, -) - -__all__ = [ - "Aggregator", - "AggregatorState", - "CompilableAggregator", - "GenerativeRequestsAggregator", - "GenerativeStatsProgressAggregator", - "InjectExtrasAggregator", - "SchedulerStatsAggregator", - "SerializableAggregator", -] - - -class AggregatorState(dict[str, Any]): - def add_metric( - self, - key: str, - value: int | float | None, - start_val: int | float | None = 0.0, - count: int | None = 1, - duration: float | None = None, - duration_div: Literal["total", "avg"] = "total", - prefix: str | None = None, - ): - """ - Add timing or count metrics to aggregation state. 
- """ - if prefix: - self.add_metric( - key=f"{prefix}_{key}", - value=value, - start_val=start_val, - count=count, - duration=duration, - duration_div=duration_div, - ) - return - - if not all_defined(value, start_val, count): - return - - delta_val = value - start_val - self[f"{key}_total"] = self.get(f"{key}_total", 0) + delta_val - self[f"{key}_count"] = self.get(f"{key}_count", 0) + count - self[f"{key}_avg"] = safe_divide( - self.get(f"{key}_total"), self.get(f"{key}_count") - ) - - if all_defined(duration): - self[f"{key}_duration"] = duration - self[f"{key}_rate"] = safe_divide( - self.get(f"{key}_{duration_div}"), duration - ) - - def set_metric( - self, - key: str, - value: int | float | None, - type_: Literal["total", "count", "avg", "duration", "rate"], - prefix: str | None = None, - ): - if prefix: - self.set_metric( - key=f"{prefix}_{key}", - value=value, - type_=type_, - prefix=None, - ) - return - - self[f"{key}_{type_}"] = value - - def get_metric( - self, - key: str, - type_: Literal["total", "count", "avg", "duration", "rate"], - default: int | float | None = None, - prefix: str | None = None, - ) -> int | float | None: - if prefix: - return self.get_metric( - key=f"{prefix}_{key}", - type_=type_, - default=default, - ) - - return self.get(f"{key}_{type_}", default) - - -@runtime_checkable -class Aggregator(Protocol[ResponseT, RequestT]): - """ - Protocol for processing benchmark data updates during execution. - - Defines the interface for aggregators that collect and process request/response - data from scheduler executions. Implementations update aggregation state with - each completed request for eventual compilation into final metrics. - """ - - def __call__( - self, - state: AggregatorState, - response: ResponseT | None, - request: RequestT, - request_info: ScheduledRequestInfo, - scheduler_state: SchedulerState, - ) -> dict[str, Any] | None: - """ - Process a completed request and update aggregation state. - - :param state: Current aggregation state to update in-place. - :param response: Response generated for the request, if successful. - :param request: The processed request object. - :param request_info: Scheduling metadata and timing information. - :param scheduler_state: Current scheduler execution state. - :return: Optional intermediate updates for progress reporting. - """ - - -@runtime_checkable -class CompilableAggregator(Protocol[ResponseT, RequestT]): - """ - Protocol for aggregators that compile final results from aggregated state. - - Extends the Aggregator protocol with the ability to transform accumulated - state into final benchmark results and metrics after execution completes. - """ - - def __call__( - self, - state: AggregatorState, - response: ResponseT | None, - request: RequestT, - request_info: ScheduledRequestInfo, - scheduler_state: SchedulerState, - ) -> dict[str, Any] | None: - """ - Process a completed request and update aggregation state. - - :param state: Current aggregation state to update in-place. - :param response: Response generated for the request, if successful. - :param request: The processed request object. - :param request_info: Scheduling metadata and timing information. - :param scheduler_state: Current scheduler execution state. - :return: Optional intermediate updates for progress reporting. - """ - - def compile( - self, state: AggregatorState, scheduler_state: SchedulerState - ) -> dict[str, Any]: - """ - Compile aggregated state into final benchmark results. - - :param agg_state: The accumulated aggregation state. 
- :param scheduler_state: Final scheduler execution state. - :return: Compiled benchmark results and metrics. - """ - - -class SerializableAggregator( - PydanticClassRegistryMixin[type["SerializableAggregator"]], - ABC, - Generic[ResponseT, RequestT], -): - schema_discriminator: ClassVar[str] = "type_" - - @classmethod - def __pydantic_schema_base_type__(cls) -> type[SerializableAggregator]: - if cls.__name__ == "SerializableAggregator": - return cls - - return SerializableAggregator - - @classmethod - @abstractmethod - def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: - """ - Validate and process arguments for constraint creation. - - Must be implemented by subclasses to handle their specific parameter patterns. - - :param args: Positional arguments passed to the constraint - :param kwargs: Keyword arguments passed to the constraint - :return: Validated dictionary of parameters for constraint creation - :raises NotImplementedError: Must be implemented by subclasses - """ - ... - - @classmethod - def resolve( - cls, - aggregators: dict[ - str, - Any | dict[str, Any] | Aggregator | CompilableAggregator, - ], - ) -> dict[str, Aggregator | CompilableAggregator]: - """ - Resolve mixed aggregator specifications to callable aggregators. - - :param aggregators: Dictionary mapping aggregator keys to specifications - :return: Dictionary mapping aggregator keys to callable functions - :raises ValueError: If any key is not registered in the factory - """ - resolved = {} - - for key, val in aggregators.items(): - if isinstance(val, Aggregator | CompilableAggregator): - resolved[key] = val - else: - aggregator_class = cls.get_registered_object(key) - kwargs = aggregator_class.validated_kwargs(**val) - resolved[key] = aggregator_class(**kwargs) - - return resolved - - type_: Literal["aggregator"] = Field(default="aggregator", description="") - - @abstractmethod - def __call__( - self, - state: AggregatorState, - response: ResponseT | None, - request: RequestT, - request_info: ScheduledRequestInfo, - scheduler_state: SchedulerState, - ) -> dict[str, Any] | None: - """ - Process a completed request and update aggregation state. - - :param agg_state: Current aggregation state to update in-place. - :param response: Response generated for the request, if successful. - :param request: The processed request object. - :param request_info: Scheduling metadata and timing information. - :param scheduler_state: Current scheduler execution state. - :return: Optional intermediate updates for progress reporting. - """ - - @abstractmethod - def compile( - self, state: AggregatorState, scheduler_state: SchedulerState - ) -> dict[str, Any]: - """ - Compile aggregated state into final benchmark results. - - :param agg_state: The accumulated aggregation state. - :param scheduler_state: Final scheduler execution state. - :return: Compiled benchmark results and metrics. - """ - - -@SerializableAggregator.register("inject_extras") -class InjectExtrasAggregator(SerializableAggregator[ResponseT, RequestT], InfoMixin): - """ - Aggregator for injecting extra metadata into the output. 
- """ - - @classmethod - def validated_kwargs(cls, extras: dict[str, Any], **_kwargs) -> dict[str, Any]: - return {"extras": extras} - - type_: Literal["inject_extras"] = Field(default="inject_extras") - extras: dict[str, Any] | None = Field(default_factory=None) - - def __call__( - self, - state: AggregatorState, - response: ResponseT | None, - request: RequestT, - request_info: ScheduledRequestInfo, - scheduler_state: SchedulerState, - ) -> dict[str, Any] | None: - """ - Inject extra metadata into the aggregation state. - - :param agg_state: Current aggregation state to update. - :param response: Response generated for the request, if successful. - :param request: The processed request object. - :param request_info: Scheduling metadata and timing information. - :param scheduler_state: Current scheduler execution state. - :return: Updated aggregation state with injected extras. - """ - _ = (state, response, request, request_info, scheduler_state) # unused - return None - - def compile( - self, state: AggregatorState, scheduler_state: SchedulerState - ) -> dict[str, Any]: - _ = (state, scheduler_state) # unused - return {"extras": self.extras} if self.extras else {} - - -@SerializableAggregator.register("scheduler_stats") -class SchedulerStatsAggregator(SerializableAggregator[ResponseT, RequestT], InfoMixin): - """ - Aggregates scheduler timing and performance metrics. - - Collects timing data for various scheduler phases including queuing, - resolution, and processing delays to generate performance statistics. - """ - - @classmethod - def validated_kwargs(cls, *_args, **_kwargs) -> dict[str, Any]: - return {} - - type_: Literal["scheduler_stats"] = Field(default="scheduler_stats") - - def __call__( - self, - state: AggregatorState, - response: ResponseT | None, - request: RequestT, - request_info: ScheduledRequestInfo, - scheduler_state: SchedulerState, - ) -> dict[str, Any] | None: - """ - Aggregate scheduler timing metrics for a completed request. - - :param agg_state: Current aggregation state to update. - :param response: Response generated for the request, if successful. - :param request: The processed request object. - :param request_info: Scheduling metadata and timing information. - :param scheduler_state: Current scheduler execution state. - :return: Updated aggregation state for intermediate reporting. 
- """ - _ = (response, request, scheduler_state) # unused - if request_info.status not in ("completed", "errored", "cancelled"): - # Only compile scheduler stats for processed requests - return None - - state["updated_scheduler_stats"] = True - state.add_metric( - key="queued_time", - value=request_info.scheduler_timings.dequeued, - start_val=request_info.scheduler_timings.queued, - ) - state.add_metric( - key="worker_resolve_start_delay", - value=request_info.scheduler_timings.resolve_start, - start_val=request_info.scheduler_timings.scheduled_at, - ) - state.add_metric( - key="worker_resolve_time", - value=request_info.scheduler_timings.resolve_end, - start_val=request_info.scheduler_timings.resolve_start, - ) - state.add_metric( - key="worker_resolve_end_delay", - value=request_info.scheduler_timings.resolve_end, - start_val=safe_getattr(request_info.request_timings, "request_end"), - ) - state.add_metric( - key="finalized_delay", - value=request_info.scheduler_timings.finalized, - start_val=request_info.scheduler_timings.resolve_end, - ) - state.add_metric( - key="worker_targeted_start_delay", - value=request_info.scheduler_timings.resolve_start, - start_val=request_info.scheduler_timings.targeted_start, - ) - state.add_metric( - key="request_start_delay", - value=request_info.scheduler_timings.resolve_start, - start_val=safe_getattr(request_info.request_timings, "request_start"), - ) - state.add_metric( - key="request_time", - value=safe_getattr(request_info.request_timings, "request_end"), - start_val=safe_getattr(request_info.request_timings, "request_start"), - ) - state.add_metric( - key="request_targeted_start_delay", - value=safe_getattr(request_info.request_timings, "request_start"), - start_val=request_info.scheduler_timings.targeted_start, - ) - - return state - - def compile( - self, state: AggregatorState, scheduler_state: SchedulerState - ) -> dict[Literal["run_stats"], BenchmarkSchedulerStats]: - """ - Compile scheduler timing metrics into benchmark statistics. - - :param agg_state: Accumulated timing data and counts. - :param scheduler_state: Final scheduler execution state. - :return: Dictionary containing compiled scheduler statistics. 
- """ - return { - "run_stats": BenchmarkSchedulerStats( - start_time=scheduler_state.start_time, - end_time=scheduler_state.end_time, - requests_made=StatusBreakdown[int, int, int, int]( - successful=scheduler_state.successful_requests, - incomplete=scheduler_state.cancelled_requests, - errored=scheduler_state.errored_requests, - total=( - scheduler_state.successful_requests - + scheduler_state.cancelled_requests - + scheduler_state.errored_requests - ), - ), - queued_time_avg=state.get_metric( - key="queued_time", type_="avg", default=0.0 - ), - worker_resolve_start_delay_avg=state.get_metric( - key="worker_resolve_start_delay", type_="avg", default=0.0 - ), - worker_resolve_time_avg=state.get_metric( - key="worker_resolve_time", type_="avg", default=0.0 - ), - worker_resolve_end_delay_avg=state.get_metric( - key="worker_resolve_end_delay", type_="avg", default=0.0 - ), - finalized_delay_avg=state.get_metric( - key="finalized_delay", type_="avg", default=0.0 - ), - worker_targeted_start_delay_avg=state.get_metric( - key="worker_targeted_start_delay", type_="avg", default=0.0 - ), - request_start_delay_avg=state.get_metric( - key="request_start_delay", type_="avg", default=0.0 - ), - request_time_avg=state.get_metric( - key="request_time", type_="avg", default=0.0 - ), - request_targeted_start_delay_avg=state.get_metric( - key="request_targeted_start_delay", type_="avg", default=0.0 - ), - ), - } - - -@SerializableAggregator.register("generative_stats_progress") -class GenerativeStatsProgressAggregator( - SerializableAggregator[GenerationResponse, GenerationRequest] -): - """ - Tracks generative model metrics during benchmark execution. - - Aggregates token-level metrics including time to first token, inter-token - latency, and token counts for real-time progress monitoring. - """ - - @classmethod - def validated_kwargs(cls, *_args, **_kwargs) -> dict[str, Any]: - return {} - - type_: Literal["generative_stats_progress"] = Field( - default="generative_stats_progress" - ) - - def __call__( - self, - state: AggregatorState, - response: GenerationResponse | None, - request: GenerationRequest, - request_info: ScheduledRequestInfo, - scheduler_state: SchedulerState, - ) -> dict[str, Any] | None: - """ - Aggregate generative model metrics for a completed request. - - :param agg_state: Current aggregation state to update. - :param response: Generation response with token and timing data. - :param request: The processed generation request. - :param request_info: Scheduling metadata and timing information. - :param scheduler_state: Current scheduler execution state. - :return: Updated aggregation state for progress reporting. 
- """ - _ = (request,) # unused - if request_info.status not in {"completed", "errored", "cancelled"}: - # Only compile progress stats for processed requests - return None - - state["updated_generative_stats"] = True - start_time = scheduler_state.start_time - end_time = ( - safe_getattr(request_info.request_timings, "request_end") - or request_info.scheduler_timings.resolve_end - ) - duration = end_time - start_time if end_time else None - - for prefix in (request_info.status, None): - requests_count = ( - scheduler_state.processed_requests - if prefix is None - else scheduler_state.successful_requests - if request_info.status == "completed" - else scheduler_state.cancelled_requests - if request_info.status == "cancelled" - else scheduler_state.errored_requests - ) - - # Requests per Second - if duration is not None: - state.set_metric( - key="requests", - value=safe_divide(requests_count, duration), - type_="rate", - prefix=prefix, - ) - - # Request Concurrency - state.set_metric( - key="requests", - value=scheduler_state.processing_requests, - type_="avg", - prefix=prefix, - ) - - # Request Latency - state.add_metric( - key="request_latency", - value=safe_getattr(request_info.request_timings, "request_end"), - start_val=safe_getattr(request_info.request_timings, "request_start"), - prefix=prefix, - ) - - # Time to First Token - state.add_metric( - key="time_to_first_token", - value=safe_getattr(request_info.request_timings, "first_iteration"), - start_val=safe_getattr(request_info.request_timings, "request_start"), - prefix=prefix, - ) - - output_tokens = safe_getattr(response, "output_tokens") - prompt_tokens = safe_getattr(response, "prompt_tokens") - - # Inter Token Latency - state.add_metric( - key="inter_token_latency", - value=safe_getattr(request_info.request_timings, "last_iteration"), - start_val=safe_getattr(request_info.request_timings, "first_iteration"), - count=( - output_tokens - 1 if output_tokens and output_tokens > 1 else None - ), - prefix=prefix, - ) - - # Time per Output Token - state.add_metric( - key="time_per_output_token", - value=safe_getattr(request_info.request_timings, "request_start"), - start_val=safe_getattr(request_info.request_timings, "last_iteration"), - count=output_tokens, - prefix=prefix, - ) - - # Prompt Tokens - state.add_metric( - key="prompt_tokens", - value=prompt_tokens, - duration=duration, - prefix=prefix, - ) - - # Output Tokens - state.add_metric( - key="output_tokens", - value=output_tokens, - duration=duration, - prefix=prefix, - ) - - # Total Tokens - state.add_metric( - key="total_tokens", - value=( - prompt_tokens + output_tokens - if all_defined(prompt_tokens, output_tokens) - else prompt_tokens - if all_defined(prompt_tokens) - else output_tokens - if all_defined(output_tokens) - else None - ), - duration=duration, - prefix=prefix, - ) - - return state - - def compile( - self, state: AggregatorState, scheduler_state: SchedulerState - ) -> dict[str, Any]: - """ - Compile progress metrics into final results. - - GenerativeStatsProgressAggregator is primarily for progress tracking, - so compilation returns the aggregated state as-is. - - :param agg_state: The accumulated aggregation state. - :param scheduler_state: Final scheduler execution state. - :return: The aggregated state as final results. 
- """ - _ = (state, scheduler_state) # unused - return {} - - -@SerializableAggregator.register("generative_requests") -class GenerativeRequestsAggregator( - SerializableAggregator[GenerationResponse, GenerationRequest], -): - """ - Compiles complete generative benchmark results with warmup/cooldown filtering. - - Aggregates request data during execution and compiles comprehensive metrics - including timing distributions, token statistics, and throughput measurements. - Supports filtering warmup and cooldown periods from final results. - """ - - @classmethod - def validated_kwargs( - cls, - request_samples: int | None = 20, - warmup: int | float | None = None, - cooldown: int | float | None = None, - **_kwargs, - ) -> dict[str, Any]: - return { - "request_samples": request_samples, - "warmup": warmup, - "cooldown": cooldown, - } - - type_: Literal["generative_requests"] = Field(default="generative_requests") - - request_samples: int | None = Field(default=20, description="") - warmup: int | float | None = Field( - default=None, - description="Number of warmup requests to ignore at benchmark start", - ) - cooldown: int | float | None = Field( - default=None, - description="Number of cooldown requests to ignore at benchmark end", - ) - _in_cooldown: bool = PrivateAttr(False) - _in_warmup: bool = PrivateAttr(False) - - def __call__( - self, - state: AggregatorState, - response: GenerationResponse | None, - request: GenerationRequest, - request_info: ScheduledRequestInfo, - scheduler_state: SchedulerState, - ) -> dict[str, Any] | None: - """ - Collect completed requests for final compilation. - - Filters requests based on warmup/cooldown settings and categorizes by - completion status for comprehensive benchmark analysis. - - :param agg_state: Current aggregation state to update. - :param response: Generation response data. - :param request: The processed generation request. - :param request_info: Scheduling metadata and timing information. - :param scheduler_state: Current scheduler execution state. - :return: None, as this aggregator only collects for final compilation. - """ - # Skip invalid requests - if request_info.status not in {"completed", "canceled", "errored"} or ( - request_info.status == "canceled" - and safe_getattr(request_info.scheduler_timings, "resolve_start") is None - # Canceled requests that never started should not be kept - ): - return None - - status = { - "updated_generative_requests": True, - "requests_in_warmup": False, - "requests_in_cooldown": False, - } - - if self._is_in_warmup(request_info, scheduler_state): - status["requests_in_warmup"] = True - return status - - if self._is_in_cooldown(request_info, scheduler_state): - status["requests_in_cooldown"] = True - return status - - if "completed" not in state: - state["completed"] = [] - state["errored"] = [] - state["incomplete"] = [] - - # Categorize request by status - if request_info.status == "completed": - state["completed"].append((response, request, request_info)) - elif request_info.status == "canceled": - state["incomplete"].append((response, request, request_info)) - else: - state["errored"].append((response, request, request_info)) - - return status - - def compile( - self, - state: AggregatorState, - scheduler_state: SchedulerState, # noqa: ARG002 - ) -> dict[str, Any]: - """ - Compile aggregated requests into comprehensive benchmark results. - - Transforms collected request data into detailed metrics including timing - distributions, token statistics, throughput measurements, and status breakdowns. 
- - :param agg_state: Accumulated request data categorized by completion status. - :param scheduler_state: Final scheduler execution state. - :return: Complete benchmark results with metrics and request statistics. - """ - successful: list[GenerativeRequestStats] = [ - self._create_generative_request_stats(response, request, request_info) - for (response, request, request_info) in state.get("completed", []) - ] - incomplete: list[GenerativeRequestStats] = [ - self._create_generative_request_stats(response, request, request_info) - for (response, request, request_info) in state.get("incomplete", []) - ] - errored: list[GenerativeRequestStats] = [ - self._create_generative_request_stats(response, request, request_info) - for (response, request, request_info) in state.get("errored", []) - ] - - # Use all requests for metrics calculations (not sampled) - total: list[GenerativeRequestStats] = successful + incomplete + errored - total_types: list[Literal["successful", "incomplete", "error"]] = [ - *["successful"] * len(successful), - *["incomplete"] * len(incomplete), - *["error"] * len(errored), - ] - start_time = min( - [math.inf] - + [ - req.scheduler_info.request_timings.request_start - for req in total - if req.scheduler_info.request_timings.request_start is not None - ] - ) - end_time = max( - [-1 * math.inf] - + [ - req.scheduler_info.request_timings.request_end - for req in total - if req.scheduler_info.request_timings.request_end is not None - ] - ) - - return { - "start_time": start_time, - "end_time": end_time, - "request_totals": StatusBreakdown[int, int, int, int]( - successful=len(successful), - incomplete=len(incomplete), - errored=len(errored), - total=len(total), - ), - "requests": StatusBreakdown[ - list[GenerativeRequestStats], - list[GenerativeRequestStats], - list[GenerativeRequestStats], - list[GenerativeRequestStats], - ]( - successful=self._sample_request_stats(successful, self.request_samples), - incomplete=self._sample_request_stats(incomplete, self.request_samples), - errored=self._sample_request_stats(errored, self.request_samples), - ), - "metrics": GenerativeMetrics( - requests_per_second=self._calculate_requests_per_second( - statuses=total_types, requests=total - ), - request_concurrency=self._calculate_request_concurrency( - statuses=total_types, requests=total - ), - request_latency=self._calculate_request_latency( - statuses=total_types, requests=total - ), - prompt_token_count=self._calculate_prompt_token_count( - statuses=total_types, requests=total - ), - output_token_count=self._calculate_output_token_count( - statuses=total_types, requests=total - ), - total_token_count=self._calculate_total_token_count( - statuses=total_types, requests=total - ), - time_to_first_token_ms=self._calculate_time_to_first_token_ms( - statuses=total_types, requests=total - ), - time_per_output_token_ms=self._calculate_time_per_output_token_ms( - statuses=total_types, requests=total - ), - inter_token_latency_ms=self._calculate_inter_token_latency_ms( - statuses=total_types, requests=total - ), - output_tokens_per_second=self._calculate_output_tokens_per_second( - statuses=total_types, requests=total - ), - tokens_per_second=self._calculate_tokens_per_second( - statuses=total_types, requests=total - ), - ), - } - - def _is_in_warmup( - self, - request_info: ScheduledRequestInfo, - scheduler_state: SchedulerState, - ) -> bool: - """Check if the current request is within the warmup period.""" - if self.warmup is None: - return False - - if 0 < self.warmup < 1: # Percentage-based 
warmup - return ( - scheduler_state.remaining_fraction is not None - and scheduler_state.remaining_fraction > (1 - self.warmup) - ) - - if self.warmup >= 1: # Count/time-based warmup - if scheduler_state.processed_requests < self.warmup: - return True - - current_time = request_info.scheduler_timings.targeted_start - return ( - current_time is not None - and (current_time - scheduler_state.start_time) < self.warmup - ) - - return False - - def _is_in_cooldown( - self, - request_info: ScheduledRequestInfo, - scheduler_state: SchedulerState, - ) -> bool: - """Check if the current request is within the cooldown period.""" - if self.cooldown is None: - return False - - if 0 < self.cooldown < 1: # Percentage-based cooldown - return ( - scheduler_state.remaining_fraction is not None - and scheduler_state.remaining_fraction < self.cooldown - ) - - if self.cooldown >= 1: # Count/time-based cooldown - if scheduler_state.remaining_requests <= self.cooldown: - return True - - current_time = ( - request_info.scheduler_timings.resolve_end - or request_info.scheduler_timings.targeted_start - ) - return ( - current_time is not None - and scheduler_state.remaining_duration is not None - and scheduler_state.remaining_duration < self.cooldown - ) - - return False - - @classmethod - def _create_generative_request_stats( - cls, - response: GenerationResponse, - request: GenerationRequest, - request_info: ScheduledRequestInfo, - ) -> GenerativeRequestStats: - prompt_tokens = response.preferred_prompt_tokens( - settings.preferred_prompt_tokens_source - ) - output_tokens = response.preferred_output_tokens( - settings.preferred_output_tokens_source - ) - - return GenerativeRequestStats( - request_id=request.request_id, - request_type=request.request_type, - prompt=str(request.content), - request_args=response.request_args, - output=response.value, - iterations=response.iterations, - prompt_tokens=prompt_tokens, - output_tokens=output_tokens, - total_tokens=( - prompt_tokens + output_tokens - if prompt_tokens is not None and output_tokens is not None - else None - ), - scheduler_info=request_info, - ) - - @classmethod - def _sample_request_stats( - cls, stats: list[GenerativeRequestStats], sample_size: int | None - ) -> list[GenerativeRequestStats]: - if sample_size is None or sample_size <= 0 or not stats: - return stats - - return random.sample(stats, min(sample_size, len(stats))) - - @classmethod - def _calculate_requests_per_second( - cls, - statuses: list[Literal["successful", "incomplete", "error"]], - requests: list[GenerativeRequestStats], - ) -> StatusDistributionSummary: - filtered_statuses = [] - filtered_times = [] - - for status, request in zip(statuses, requests, strict=False): - if not all_defined( - safe_getattr(request.scheduler_info.request_timings, "request_start"), - safe_getattr(request.scheduler_info.request_timings, "request_end"), - ): - continue - - filtered_statuses.append(status) - filtered_times.append( - ( - request.scheduler_info.request_timings.request_start, - request.scheduler_info.request_timings.request_end, - ) - ) - - return StatusDistributionSummary.from_request_times( - request_types=filtered_statuses, - requests=filtered_times, - distribution_type="rate", - ) - - @classmethod - def _calculate_request_concurrency( - cls, - statuses: list[Literal["successful", "incomplete", "error"]], - requests: list[GenerativeRequestStats], - ) -> StatusDistributionSummary: - filtered_statuses = [] - filtered_times = [] - - for status, request in zip(statuses, requests, strict=False): - if 
not all_defined( - safe_getattr(request.scheduler_info.request_timings, "request_start"), - safe_getattr(request.scheduler_info.request_timings, "request_end"), - ): - continue - - filtered_statuses.append(status) - filtered_times.append( - ( - request.scheduler_info.request_timings.request_start, - request.scheduler_info.request_timings.request_end, - ) - ) - - return StatusDistributionSummary.from_request_times( - request_types=filtered_statuses, - requests=filtered_times, - distribution_type="concurrency", - ) - - @classmethod - def _calculate_request_latency( - cls, - statuses: list[Literal["successful", "incomplete", "error"]], - requests: list[GenerativeRequestStats], - ) -> StatusDistributionSummary: - filtered_statuses = [] - filtered_values = [] - - for status, request in zip(statuses, requests, strict=False): - if not all_defined(request.request_latency): - continue - - filtered_statuses.append(status) - filtered_values.append(request.request_latency) - - return StatusDistributionSummary.from_values( - value_types=filtered_statuses, - values=filtered_values, - ) - - @classmethod - def _calculate_prompt_token_count( - cls, - statuses: list[Literal["successful", "incomplete", "error"]], - requests: list[GenerativeRequestStats], - ) -> StatusDistributionSummary: - filtered_statuses = [] - filtered_values = [] - - for status, request in zip(statuses, requests, strict=False): - if not all_defined(request.prompt_tokens): - continue - - filtered_statuses.append(status) - filtered_values.append(request.prompt_tokens) - - return StatusDistributionSummary.from_values( - value_types=filtered_statuses, - values=filtered_values, - ) - - @classmethod - def _calculate_output_token_count( - cls, - statuses: list[Literal["successful", "incomplete", "error"]], - requests: list[GenerativeRequestStats], - ) -> StatusDistributionSummary: - filtered_statuses = [] - filtered_values = [] - - for status, request in zip(statuses, requests, strict=False): - if not all_defined(request.output_tokens): - continue - - filtered_statuses.append(status) - filtered_values.append(request.output_tokens) - - return StatusDistributionSummary.from_values( - value_types=filtered_statuses, - values=filtered_values, - ) - - @classmethod - def _calculate_total_token_count( - cls, - statuses: list[Literal["successful", "incomplete", "error"]], - requests: list[GenerativeRequestStats], - ) -> StatusDistributionSummary: - filtered_statuses = [] - filtered_values = [] - - for status, request in zip(statuses, requests, strict=False): - if not all_defined(request.total_tokens): - continue - - filtered_statuses.append(status) - filtered_values.append(request.total_tokens) - - return StatusDistributionSummary.from_values( - value_types=filtered_statuses, - values=filtered_values, - ) - - @classmethod - def _calculate_time_to_first_token_ms( - cls, - statuses: list[Literal["successful", "incomplete", "error"]], - requests: list[GenerativeRequestStats], - ) -> StatusDistributionSummary: - filtered_statuses = [] - filtered_values = [] - - for status, request in zip(statuses, requests, strict=False): - if not all_defined(request.time_to_first_token_ms): - continue - - filtered_statuses.append(status) - filtered_values.append(request.time_to_first_token_ms) - - return StatusDistributionSummary.from_values( - value_types=filtered_statuses, - values=filtered_values, - ) - - @classmethod - def _calculate_time_per_output_token_ms( - cls, - statuses: list[Literal["successful", "incomplete", "error"]], - requests: 
list[GenerativeRequestStats], - ) -> StatusDistributionSummary: - filtered_statuses = [] - filtered_values = [] - filtered_weights = [] - - for status, request in zip(statuses, requests, strict=False): - if not all_defined(request.time_to_first_token_ms): - continue - - # Add time to first token separately to better reflect in distribution - filtered_statuses.append(status) - filtered_values.append(request.time_to_first_token_ms) - filtered_weights.append(1) - - if not all_defined(request.inter_token_latency_ms): - continue - - # Add tokens after the first token to get the full distribution - filtered_statuses.append(status) - filtered_values.append(request.inter_token_latency_ms) - filtered_weights.append(request.output_tokens - 1) - - return StatusDistributionSummary.from_values( - value_types=filtered_statuses, - values=filtered_values, - weights=filtered_weights, - ) - - @classmethod - def _calculate_inter_token_latency_ms( - cls, - statuses: list[Literal["successful", "incomplete", "error"]], - requests: list[GenerativeRequestStats], - ) -> StatusDistributionSummary: - filtered_statuses = [] - filtered_values = [] - filtered_weights = [] - - for status, request in zip(statuses, requests, strict=False): - if not all_defined(request.inter_token_latency_ms): - continue - - filtered_statuses.append(status) - filtered_values.append(request.inter_token_latency_ms) - filtered_weights.append(request.output_tokens - 1) - - return StatusDistributionSummary.from_values( - value_types=filtered_statuses, - values=filtered_values, - weights=filtered_weights, - ) - - @classmethod - def _calculate_output_tokens_per_second( - cls, - statuses: list[Literal["successful", "incomplete", "error"]], - requests: list[GenerativeRequestStats], - ) -> StatusDistributionSummary: - filtered_statuses = [] - filtered_request_times = [] - filtered_first_iter_times = [] - filtered_iter_counts = [] - - for status, request in zip(statuses, requests, strict=False): - if not all_defined(request.output_tokens_per_second): - continue - - filtered_statuses.append(status) - filtered_request_times.append( - ( - request.scheduler_info.request_timings.request_start, - request.scheduler_info.request_timings.request_end, - ) - ) - filtered_first_iter_times.append( - request.scheduler_info.request_timings.first_iteration - ) - filtered_iter_counts.append(request.output_tokens) - - return StatusDistributionSummary.from_iterable_request_times( - request_types=filtered_statuses, - requests=filtered_request_times, - first_iter_times=filtered_first_iter_times, - iter_counts=filtered_iter_counts, - ) - - @classmethod - def _calculate_tokens_per_second( - cls, - statuses: list[Literal["successful", "incomplete", "error"]], - requests: list[GenerativeRequestStats], - ) -> StatusDistributionSummary: - filtered_statuses = [] - filtered_request_times = [] - filtered_first_iter_times = [] - filtered_iter_counts = [] - filtered_first_iter_counts = [] - - for status, request in zip(statuses, requests, strict=False): - if not all_defined(request.tokens_per_second): - continue - - filtered_statuses.append(status) - filtered_request_times.append( - ( - request.scheduler_info.request_timings.request_start, - request.scheduler_info.request_timings.request_end, - ) - ) - filtered_first_iter_times.append( - request.scheduler_info.request_timings.first_iteration - ) - filtered_iter_counts.append(request.output_tokens - 1) - filtered_first_iter_counts.append(request.prompt_tokens + 1) - - return StatusDistributionSummary.from_iterable_request_times( - 
request_types=filtered_statuses, - requests=filtered_request_times, - first_iter_times=filtered_first_iter_times, - iter_counts=filtered_iter_counts, - first_iter_counts=filtered_first_iter_counts, - ) diff --git a/src/guidellm/benchmark/benchmarker.py b/src/guidellm/benchmark/benchmarker.py index 99410e4c..ed9d789b 100644 --- a/src/guidellm/benchmark/benchmarker.py +++ b/src/guidellm/benchmark/benchmarker.py @@ -20,31 +20,24 @@ import uuid from abc import ABC from collections.abc import AsyncIterator, Iterable -from typing import ( - Any, - Generic, -) +from typing import Generic -from guidellm.benchmark.aggregator import ( - Aggregator, - AggregatorState, - CompilableAggregator, -) -from guidellm.benchmark.objects import BenchmarkerDict, BenchmarkT, SchedulerDict from guidellm.benchmark.profile import Profile +from guidellm.benchmark.progress import BenchmarkerProgress +from guidellm.benchmark.schemas import ( + BenchmarkArgs, + BenchmarkT, + EstimatedBenchmarkState, +) +from guidellm.logger import logger from guidellm.scheduler import ( BackendInterface, - Constraint, Environment, - NonDistributedEnvironment, RequestT, ResponseT, Scheduler, - SchedulerState, - SchedulingStrategy, ) -from guidellm.utils import InfoMixin, ThreadSafeSingletonMixin -from guidellm.utils.pydantic_utils import StandardBaseDict +from guidellm.utils import ThreadSafeSingletonMixin __all__ = ["Benchmarker"] @@ -67,23 +60,17 @@ class Benchmarker( async def run( self, + benchmark_class: type[BenchmarkT], requests: Iterable[RequestT | Iterable[RequestT | tuple[RequestT, float]]], backend: BackendInterface[RequestT, ResponseT], profile: Profile, - benchmark_class: type[BenchmarkT], - benchmark_aggregators: dict[ - str, - Aggregator[ResponseT, RequestT] | CompilableAggregator[ResponseT, RequestT], - ], - environment: Environment | None = None, - ) -> AsyncIterator[ - tuple[ - AggregatorState | None, - BenchmarkT | None, - SchedulingStrategy, - SchedulerState | None, - ] - ]: + environment: Environment, + progress: BenchmarkerProgress[BenchmarkT] | None = None, + sample_requests: int | None = 20, + warmup: float | None = None, + cooldown: float | None = None, + prefer_response_metrics: bool = True, + ) -> AsyncIterator[BenchmarkT]: """ Execute benchmark runs across multiple scheduling strategies. @@ -101,18 +88,27 @@ async def run( :raises Exception: If benchmark execution or compilation fails. 
""" with self.thread_lock: - if environment is None: - environment = NonDistributedEnvironment() + if progress: + await progress.on_initialize(profile) run_id = str(uuid.uuid4()) strategies_generator = profile.strategies_generator() strategy, constraints = next(strategies_generator) while strategy is not None: - yield None, None, strategy, None - aggregators_state = { - key: AggregatorState() for key in benchmark_aggregators - } + if progress: + await progress.on_benchmark_start(strategy) + + args = BenchmarkArgs( + run_id=run_id, + run_index=len(profile.completed_strategies), + sample_requests=sample_requests, + warmup=warmup, + cooldown=cooldown, + prefer_response_metrics=prefer_response_metrics, + ) + estimated_state = EstimatedBenchmarkState() + scheduler_state = None async for ( response, @@ -126,34 +122,39 @@ async def run( env=environment, **constraints or {}, ): - aggregators_update = AggregatorState() - for key, aggregator in benchmark_aggregators.items(): - update = aggregator( - aggregators_state[key], + try: + benchmark_class.update_estimate( + args, + estimated_state, response, request, request_info, scheduler_state, ) - if update: - aggregators_update.update(update) - yield aggregators_update, None, strategy, scheduler_state + if progress: + await progress.on_benchmark_update( + estimated_state, scheduler_state + ) + except Exception as err: + logger.error( + f"Error updating benchmark estimate/progress: {err}" + ) - benchmark_kwargs = self._compile_benchmark_kwargs( - run_id=run_id, - run_index=len(profile.completed_strategies), + benchmark = benchmark_class.compile( + args=args, + estimated_state=estimated_state, + scheduler_state=scheduler_state, profile=profile, requests=requests, backend=backend, environment=environment, - aggregators=benchmark_aggregators, - aggregators_state=aggregators_state, strategy=strategy, constraints=constraints, - scheduler_state=scheduler_state, ) - benchmark = benchmark_class(**benchmark_kwargs) - yield None, benchmark, strategy, None + if progress: + await progress.on_benchmark_complete(benchmark) + + yield benchmark try: strategy, constraints = strategies_generator.send(benchmark) @@ -161,106 +162,5 @@ async def run( strategy = None constraints = None - @classmethod - def _compile_benchmark_kwargs( - cls, - run_id: str, - run_index: int, - profile: Profile, - requests: Iterable[RequestT | Iterable[RequestT | tuple[RequestT, float]]], - backend: BackendInterface[RequestT, ResponseT], - environment: Environment, - aggregators: dict[ - str, - Aggregator[ResponseT, RequestT] | CompilableAggregator[ResponseT, RequestT], - ], - aggregators_state: dict[str, dict[str, Any]], - strategy: SchedulingStrategy, - constraints: dict[str, Any | dict[str, Any] | Constraint], - scheduler_state: SchedulerState | None, - ) -> dict[str, Any]: - """ - Compile benchmark construction parameters from execution results. - - Aggregates metadata from scheduler execution and compiles it into - structured parameters for benchmark object construction. - - :param run_id: Unique identifier for the benchmark run. - :param run_index: Index of this strategy in the benchmark profile. - :param profile: Benchmark profile containing strategy configuration. - :param requests: Request datasets used for the benchmark. - :param backend: Backend interface used for request processing. - :param environment: Execution environment for coordination. - :param aggregators: Metric aggregation functions by name. - :param aggregators_state: Current state of metric aggregators. 
- :param strategy: Scheduling strategy that was executed. - :param constraints: Runtime constraints applied during execution. - :param scheduler_state: Final state of scheduler execution. - :return: Dictionary of parameters for benchmark object construction. - :raises ValueError: If aggregator output conflicts with existing keys. - """ - benchmark_kwargs = { - "run_id": run_id, - "run_index": run_index, - "scheduler": SchedulerDict( - strategy=strategy, - constraints={ - key: InfoMixin.extract_from_obj(val) - for key, val in constraints.items() - }, - state=scheduler_state, - ), - "benchmarker": BenchmarkerDict( - profile=profile, - requests=InfoMixin.extract_from_obj(requests), - backend=backend.info, - environment=environment.info, - aggregators={ - key: InfoMixin.extract_from_obj(aggregator) - for key, aggregator in aggregators.items() - }, - ), - "env_args": StandardBaseDict(), - "extras": StandardBaseDict(), - } - - def _combine( - existing: dict[str, Any] | StandardBaseDict, - addition: dict[str, Any] | StandardBaseDict, - ) -> dict[str, Any] | StandardBaseDict: - if not isinstance(existing, dict | StandardBaseDict): - raise ValueError( - f"Existing value {existing} (type: {type(existing).__name__}) " - f"is not a valid type for merging." - ) - if not isinstance(addition, dict | StandardBaseDict): - raise ValueError( - f"Addition value {addition} (type: {type(addition).__name__}) " - f"is not a valid type for merging." - ) - - add_kwargs = ( - addition if isinstance(addition, dict) else addition.model_dump() - ) - - if isinstance(existing, dict): - return {**add_kwargs, **existing} - - return existing.__class__(**{**add_kwargs, **existing.model_dump()}) - - for key, aggregator in aggregators.items(): - if not isinstance(aggregator, CompilableAggregator): - continue - - compiled = aggregator.compile(aggregators_state[key], scheduler_state) - - for field_name, field_val in compiled.items(): - if field_name in benchmark_kwargs: - # If the key already exists, merge the values - benchmark_kwargs[field_name] = _combine( - benchmark_kwargs[field_name], field_val - ) - else: - benchmark_kwargs[field_name] = field_val - - return benchmark_kwargs + if progress: + await progress.on_finalize() diff --git a/src/guidellm/benchmark/entrypoints.py b/src/guidellm/benchmark/entrypoints.py index b926394f..18768216 100644 --- a/src/guidellm/benchmark/entrypoints.py +++ b/src/guidellm/benchmark/entrypoints.py @@ -1,41 +1,32 @@ from __future__ import annotations +from collections.abc import Callable from pathlib import Path from typing import Any, Literal -from guidellm.backends import ( - Backend, - BackendType, - GenerationRequest, - GenerationResponse, -) -from guidellm.benchmark.aggregator import ( - GenerativeRequestsAggregator, - GenerativeStatsProgressAggregator, - SchedulerStatsAggregator, - SerializableAggregator, -) +from torch.utils.data import Sampler + +from guidellm.backends import Backend, BackendType from guidellm.benchmark.benchmarker import Benchmarker -from guidellm.benchmark.objects import GenerativeBenchmark, GenerativeBenchmarksReport -from guidellm.benchmark.output import ( - GenerativeBenchmarkerOutput, -) +from guidellm.benchmark.output import GenerativeBenchmarkerOutput from guidellm.benchmark.profile import Profile, ProfileType -from guidellm.benchmark.progress import BenchmarkerProgressGroup -from guidellm.benchmark.scenario import enable_scenarios -from guidellm.benchmark.types import ( - AggregatorInputT, - DataInputT, - OutputFormatT, - ProcessorInputT, - ProgressInputT, 
+from guidellm.benchmark.progress import BenchmarkerProgress +from guidellm.benchmark.schemas import GenerativeBenchmark, GenerativeBenchmarksReport +from guidellm.benchmark.types import OutputFormatT, ProcessorInputT +from guidellm.data import ( + DataLoader, + DatasetPreprocessor, + GenerativeRequestCollator, + PreprocessorRegistry, + ProcessorFactory, ) -from guidellm.request import GenerativeRequestLoader +from guidellm.data.preprocessors import GenerativeColumnMapper from guidellm.scheduler import ( ConstraintInitializer, NonDistributedEnvironment, StrategyType, ) +from guidellm.schemas import GenerationRequest, GenerationResponse from guidellm.utils import Console, InfoMixin __all__ = [ @@ -44,88 +35,249 @@ ] +# Helper Variables + _CURRENT_WORKING_DIR = Path.cwd() -# Helper functions +# Helper Functions -async def initialize_backend( +async def resolve_backend( backend: BackendType | Backend, target: str, model: str | None, - backend_kwargs: dict[str, Any] | None, -) -> Backend: + console: Console | None = None, + **backend_kwargs: dict[str, Any], +) -> tuple[Backend, str | None]: + console_step = ( + console.print_update_step(title=f"Initializing backend {backend}") + if console + else None + ) backend = ( Backend.create(backend, target=target, model=model, **(backend_kwargs or {})) if not isinstance(backend, Backend) else backend ) + + if console_step: + console_step.update(f"{backend.__class__.__name__} backend initialized") + await backend.process_startup() await backend.validate() - return backend + + if model is None: + if console_step: + console_step.update( + title="Resolving default model from backend.default_model", + status_level="info", + ) + model = await backend.default_model() + + await backend.process_shutdown() + + if console_step: + console_step.finish( + title=( + f"{backend.__class__.__name__} backend validated with model {model}" + ), + details=backend.info, + status_level="success", + ) + + return backend, model + + +async def resolve_processor( + processor: ProcessorInputT | None, + model: str | None, + console: Console | None = None, +) -> ProcessorInputT | None: + console_step = ( + console.print_update_step(title=f"Resolving processor {processor}") + if console + else None + ) + + if processor is not None: + if console_step: + console_step.finish( + title="Processor resolved", + details=f"Using processor '{processor}'", + status_level="success", + ) + else: + processor = model + if console_step: + console_step.finish( + title="Processor resolved", + details=f"Using model '{processor}' as processor", + status_level="success", + ) + + return processor + + +async def resolve_request_loader( + data: list[Any], + model: str | None, + data_args: list[dict[str, Any]] | None, + data_samples: int, + processor: ProcessorInputT | None, + processor_args: dict[str, Any] | None, + data_column_mapper: ( + DatasetPreprocessor | dict[str, str] | Literal["generative_column_mapper"] + ), + data_request_formatter: (DatasetPreprocessor | dict[str, str] | str), + data_collator: Callable | Literal["generative"] | None, + data_sampler: Sampler[int] | Literal["shuffle"] | None, + data_num_workers: int | None, + random_seed: int, + console: Console | None = None, + **dataloader_kwargs: dict[str, Any] | None, +) -> DataLoader[GenerationRequest]: + console_step = ( + console.print_update_step(title=f"Initializing request loader from {data}") + if console + else None + ) + + if not isinstance(data_column_mapper, DatasetPreprocessor): + column_mappings = ( + data_column_mapper if 
isinstance(data_column_mapper, dict) else None + ) + data_column_mapper = GenerativeColumnMapper( + column_mappings=column_mappings, + ) + if not isinstance(data_request_formatter, DatasetPreprocessor): + request_type = ( + data_request_formatter + if isinstance(data_request_formatter, str) + else data_request_formatter.pop("request_type", "chat_completions") + ) + data_request_formatter = PreprocessorRegistry.get_registered_object( + request_type + )( + model=model, + **( + data_request_formatter + if isinstance(data_request_formatter, dict) + else {} + ), + ) + + request_loader = DataLoader( + data=data, + data_args=data_args, + data_samples=data_samples, + processor_factory=ProcessorFactory( + processor=processor, processor_args=processor_args + ), + preprocessors=[data_column_mapper, data_request_formatter], + collator=( + data_collator if callable(data_collator) else GenerativeRequestCollator() + ), + sampler=data_sampler, + num_workers=data_num_workers, + random_seed=random_seed, + **(dataloader_kwargs or {}), + ) + + if console_step: + console_step.finish( + title=( + f"Request loader initialized with " + f"{data_samples if data_samples > 0 else 'inf'} " + f"unique requests from {data}" + ), + details=InfoMixin.extract_from_obj(request_loader), + status_level="success", + ) + + return request_loader async def resolve_profile( - constraint_inputs: dict[str, int | float], - profile: Profile | str | None, - rate: list[float] | None, + profile: StrategyType | ProfileType | Profile, + rate: float | list[float] | None, random_seed: int, constraints: dict[str, ConstraintInitializer | Any], -): - for key, val in constraint_inputs.items(): + max_seconds: int | float | None, + max_requests: int | None, + max_errors: int | None, + max_error_rate: float | None, + max_global_error_rate: float | None, + console: Console | None = None, +) -> Profile: + console_step = ( + console.print_update_step(title=f"Resolving profile {profile}") + if console + else None + ) + + for key, val in { + "max_seconds": max_seconds, + "max_requests": max_requests, + "max_errors": max_errors, + "max_error_rate": max_error_rate, + "max_global_error_rate": max_global_error_rate, + }.items(): if val is not None: constraints[key] = val if not isinstance(profile, Profile): - if isinstance(profile, str): - profile = Profile.create( - rate_type=profile, - rate=rate, - random_seed=random_seed, - constraints={**constraints}, - ) - else: - raise ValueError(f"Expected string for profile; got {type(profile)}") - + profile = Profile.create( + rate_type=profile, + rate=rate, + random_seed=random_seed, + constraints={**constraints}, + ) elif constraints: raise ValueError( "Constraints must be empty when providing a Profile instance. 
" f"Provided constraints: {constraints} ; provided profile: {profile}" ) + + if console_step: + console_step.finish( + title=f"{profile.__class__.__name__} profile resolved", + details=InfoMixin.extract_from_obj(profile), + status_level="success", + ) + return profile async def resolve_output_formats( output_formats: OutputFormatT, output_path: str | Path | None, + console: Console | None = None, ) -> dict[str, GenerativeBenchmarkerOutput]: - return GenerativeBenchmarkerOutput.resolve( - output_formats=(output_formats or {}), output_path=output_path + console_step = ( + console.print_update_step(title="Resolving output formats") if console else None ) + resolved = GenerativeBenchmarkerOutput.resolve( + output_formats=output_formats, output_path=output_path + ) -async def finalize_outputs( - report: GenerativeBenchmarksReport, - resolved_output_formats: dict[str, GenerativeBenchmarkerOutput], -): - output_format_results = {} - for key, output in resolved_output_formats.items(): - output_result = await output.finalize(report) - output_format_results[key] = output_result - return output_format_results - + if console_step: + console_step.finish( + title="Output formats resolved", + details={key: str(val) for key, val in resolved.items()}, + status_level="success", + ) -# Complete entrypoints + return resolved -# @validate_call(config={"arbitrary_types_allowed": True}) -@enable_scenarios -async def benchmark_generative_text( # noqa: C901 +async def benchmark_generative_text( # noqa: C901, PLR0915, PLR0912 + # Required target: str, - data: DataInputT, - profile: StrategyType | ProfileType | Profile, - rate: list[float] | None = None, - random_seed: int = 42, + data: list[Any], + # Benchmark configuration + profile: StrategyType | ProfileType | Profile = "sweep", + rate: float | list[float] | None = None, # Backend configuration backend: BackendType | Backend = "openai_http", backend_kwargs: dict[str, Any] | None = None, @@ -133,19 +285,35 @@ async def benchmark_generative_text( # noqa: C901 # Data configuration processor: ProcessorInputT | None = None, processor_args: dict[str, Any] | None = None, - data_args: dict[str, Any] | None = None, - data_sampler: Literal["random"] | None = None, + data_args: list[dict[str, Any]] | None = None, + data_samples: int = -1, + data_column_mapper: ( + DatasetPreprocessor | dict[str, str] | Literal["generative_column_mapper"] + ) = "generative_column_mapper", + data_request_formatter: ( + DatasetPreprocessor | dict[str, str] | str + ) = "chat_completions", + data_collator: Callable | Literal["generative"] | None = "generative", + data_sampler: Sampler[int] | Literal["shuffle"] | None = None, + data_num_workers: int | None = None, + dataloader_kwargs: dict[str, Any] | None = None, + random_seed: int = 42, # Output configuration output_path: str | Path | None = _CURRENT_WORKING_DIR, - output_formats: OutputFormatT = ("console", "json", "html", "csv"), + output_formats: ( + tuple[str, ...] 
+ | list[str] + | dict[str, str | dict[str, Any] | GenerativeBenchmarkerOutput] + | None + ) = ("console", "json", "html", "csv"), # Updates configuration - progress: ProgressInputT | None = None, + progress: BenchmarkerProgress | None = None, print_updates: bool = False, - # Aggregators configuration - add_aggregators: AggregatorInputT | None = None, + # Benchmarker configuration + benchmark_cls: type[GenerativeBenchmark] = GenerativeBenchmark, + sample_requests: int | None = 10, warmup: float | None = None, cooldown: float | None = None, - request_samples: int | None = 20, # Constraints configuration max_seconds: int | float | None = None, max_requests: int | None = None, @@ -155,156 +323,80 @@ async def benchmark_generative_text( # noqa: C901 **constraints: dict[str, ConstraintInitializer | Any], ) -> tuple[GenerativeBenchmarksReport, dict[str, Any]]: console = Console(quiet=not print_updates) - - with console.print_update_step( - title=f"Initializing backend {backend}" - ) as console_step: - backend = await initialize_backend(backend, target, model, backend_kwargs) - console_step.finish( - title=f"{backend.__class__.__name__} backend initialized", - details=backend.info, - status_level="success", - ) - - with console.print_update_step(title="Resolving processor") as console_step: - if processor is not None: - console_step.finish( - title="Processor resolved", - details=f"Using processor '{processor}'", - status_level="success", - ) - elif model is not None: - console_step.finish( - title="Processor resolved", - details=f"Using model '{model}' as processor", - status_level="success", - ) - processor = model - else: - console_step.update( - title="Resolving processor from backend.default_model", - status_level="info", - ) - processor = await backend.default_model() - console_step.finish( - title="Processor resolved", - details=( - f"Using model '{processor}' from backend " - f"{backend.__class__.__name__} as processor" - ), - status_level="success", - ) - await backend.process_shutdown() - - with console.print_update_step( - title=f"Initializing request loader from {data}" - ) as console_step: - request_loader = GenerativeRequestLoader( - data=data, - data_args=data_args, - processor=processor, - processor_args=processor_args, - shuffle=data_sampler == "random", - random_seed=random_seed, - ) - unique_requests = request_loader.num_unique_items(raise_err=False) - console_step.finish( - title=( - f"Request loader initialized with {unique_requests} unique requests " - f"from {data}" - ), - details=InfoMixin.extract_from_obj(request_loader), - status_level="success", - ) - - with console.print_update_step( - title=f"Resolving profile {profile}" - ) as console_step: - profile = await resolve_profile( - { - "max_seconds": max_seconds, - "max_requests": max_requests, - "max_errors": max_errors, - "max_error_rate": max_error_rate, - "max_global_error_rate": max_global_error_rate, - }, - profile, - rate, - random_seed, - constraints, - ) - console_step.finish( - title=f"{profile.__class__.__name__} profile resolved", - details=InfoMixin.extract_from_obj(profile), - status_level="success", - ) - - with console.print_update_step( - title="Creating benchmark aggregators" - ) as console_step: - aggregators = { - "scheduler_stats": SchedulerStatsAggregator(), - "requests_progress": GenerativeStatsProgressAggregator(), - "requests": GenerativeRequestsAggregator( - request_samples=request_samples, - warmup=warmup, - cooldown=cooldown, - ), - **SerializableAggregator.resolve(add_aggregators or {}), - } - 
console_step.finish( - title="Benchmark aggregators created", - details={key: str(val) for key, val in aggregators.items()}, - status_level="success", - ) - - with console.print_update_step(title="Resolving output formats") as console_step: - resolved_output_formats = await resolve_output_formats( - output_formats, output_path - ) - console_step.finish( - title="Output formats resolved", - details={key: str(val) for key, val in resolved_output_formats.items()}, - status_level="success", - ) - - progress_group = BenchmarkerProgressGroup( - instances=progress or [], enabled=bool(progress) + backend, model = await resolve_backend( + backend=backend, + target=target, + model=model, + console=console, + **(backend_kwargs or {}), ) + processor = await resolve_processor( + processor=processor, model=model, console=console + ) + request_loader = await resolve_request_loader( + data=data, + model=model, + data_args=data_args, + data_samples=data_samples, + processor=processor, + processor_args=processor_args, + data_column_mapper=data_column_mapper, + data_request_formatter=data_request_formatter, + data_collator=data_collator, + data_sampler=data_sampler, + data_num_workers=data_num_workers, + random_seed=random_seed, + console=console, + **(dataloader_kwargs or {}), + ) + profile = await resolve_profile( + profile=profile, + rate=rate, + random_seed=random_seed, + constraints=constraints, + max_seconds=max_seconds, + max_requests=max_requests, + max_errors=max_errors, + max_error_rate=max_error_rate, + max_global_error_rate=max_global_error_rate, + console=console, + ) + output_formats = await resolve_output_formats( + output_formats=output_formats, output_path=output_path, console=console + ) + report = GenerativeBenchmarksReport() console.print_update( title="Setup complete, starting benchmarks...", status="success" ) console.print("\n\n") - async for ( - _aggregator_update, - benchmark, - _strategy, - _scheduler_state, - ) in progress_group( - profile, - Benchmarker[ - GenerativeBenchmark, - GenerationRequest, - GenerationResponse, - ]().run( - requests=request_loader, - backend=backend, - profile=profile, - environment=NonDistributedEnvironment(), - benchmark_aggregators=aggregators, - benchmark_class=GenerativeBenchmark, - ), + benchmarker: Benchmarker[ + GenerativeBenchmark, GenerationRequest, GenerationResponse + ] = Benchmarker() + async for benchmark in benchmarker.run( + benchmark_class=benchmark_cls, + requests=request_loader, + backend=backend, + profile=profile, + environment=NonDistributedEnvironment(), + progress=progress, + sample_requests=sample_requests, + warmup=warmup, + cooldown=cooldown, + prefer_response_metrics=True, ): if benchmark: report.benchmarks.append(benchmark) - output_format_results = await finalize_outputs(report, resolved_output_formats) + output_format_results = {} + for key, output in output_formats.items(): + output_result = await output.finalize(report) + output_format_results[key] = output_result console.print("\n\n") console.print_update( - title=f"Benchmarking complete; generated {len(report.benchmarks)} benchmark(s)", + title=f"Benchmarking complete, generated {len(report.benchmarks)} benchmark(s)", status="success", ) for key, value in output_format_results.items(): @@ -320,12 +412,13 @@ async def reimport_benchmarks_report( ) -> tuple[GenerativeBenchmarksReport, dict[str, Any]]: """ The command-line entry point for re-importing and displaying an - existing benchmarks report. Can also specify an output format. + existing benchmarks report. 
Can also specify Assumes the file provided exists. """ console = Console() + with console.print_update_step( - title=f"Loading benchmarks from {file}" + title=f"Loading benchmarks from {file}..." ) as console_step: report = GenerativeBenchmarksReport.load_file(file) console_step.finish( @@ -333,17 +426,13 @@ async def reimport_benchmarks_report( f" loaded {len(report.benchmarks)} benchmark(s)" ) - with console.print_update_step(title="Resolving output formats") as console_step: - resolved_output_formats = await resolve_output_formats( - output_formats, output_path - ) - console_step.finish( - title="Output formats resolved", - details={key: str(val) for key, val in resolved_output_formats.items()}, - status_level="success", - ) - - output_format_results = await finalize_outputs(report, resolved_output_formats) + output_formats = await resolve_output_formats( + output_formats, output_path, console=console + ) + output_format_results = {} + for key, output in output_formats.items(): + output_result = await output.finalize(report) + output_format_results[key] = output_result for key, value in output_format_results.items(): console.print_update(title=f" {key:<8}: {value}", status="debug") diff --git a/src/guidellm/benchmark/objects.py b/src/guidellm/benchmark/objects.py deleted file mode 100644 index 8afabba9..00000000 --- a/src/guidellm/benchmark/objects.py +++ /dev/null @@ -1,473 +0,0 @@ -""" -Benchmark data models and metrics for performance measurement and analysis. - -Provides comprehensive data structures for capturing, storing, and analyzing -benchmark results from scheduler executions. Includes timing measurements, -token statistics, and performance metrics for generative AI workloads. - -Classes: - BenchmarkSchedulerStats: Scheduler timing and performance statistics. - BenchmarkMetrics: Core benchmark metrics and distributions. - BenchmarkRequestStats: Individual request processing statistics. - Benchmark: Base benchmark result container with generic metrics. - GenerativeRequestStats: Request statistics for generative AI workloads. - GenerativeMetrics: Comprehensive metrics for generative benchmarks. - GenerativeBenchmark: Complete generative benchmark results and analysis. - GenerativeBenchmarksReport: Container for multiple benchmark results. - -Type Variables: - BenchmarkMetricsT: Generic benchmark metrics type. - BenchmarkRequestStatsT: Generic request statistics type. - BenchmarkT: Generic benchmark container type. 
-""" - -from __future__ import annotations - -import json -import uuid -from pathlib import Path -from typing import Any, ClassVar, Generic, Literal, TypeVar - -import yaml -from pydantic import Field, computed_field - -from guidellm.benchmark.profile import ( - Profile, -) -from guidellm.scheduler import ( - ScheduledRequestInfo, - SchedulerState, - SchedulingStrategy, -) -from guidellm.utils import ( - StandardBaseDict, - StandardBaseModel, - StatusBreakdown, - StatusDistributionSummary, -) - -__all__ = [ - "Benchmark", - "BenchmarkMetrics", - "BenchmarkSchedulerStats", - "BenchmarkT", - "GenerativeBenchmark", - "GenerativeBenchmarksReport", - "GenerativeMetrics", - "GenerativeRequestStats", -] - - -class BenchmarkSchedulerStats(StandardBaseDict): - """Scheduler timing and performance statistics.""" - - start_time: float = Field( - description="Unix timestamp when the benchmark run started" - ) - end_time: float = Field(description="Unix timestamp when the benchmark run ended") - requests_made: StatusBreakdown[int, int, int, int] = Field( - description="Request counts by status: successful, incomplete, errored, total" - ) - queued_time_avg: float = Field( - description="Avg time requests spent in the queue (seconds)" - ) - worker_resolve_start_delay_avg: float = Field( - description="Avg delay before worker begins resolving req after dequeue (sec)" - ) - worker_resolve_time_avg: float = Field( - description="Avg time for worker to resolve requests (seconds)" - ) - worker_resolve_end_delay_avg: float = Field( - description="Avg delay after request end till worker resolves (seconds)" - ) - finalized_delay_avg: float = Field( - description="Avg delay after resolve til finalized with in scheduler (sec)" - ) - worker_targeted_start_delay_avg: float = Field( - description="Avg delay from targeted start to actual worker start (seconds)" - ) - request_start_delay_avg: float = Field( - description="Avg delay after resolve til request start (seconds)" - ) - request_time_avg: float = Field(description="Avg request processing time (seconds)") - request_targeted_start_delay_avg: float = Field( - description="Avg delay from targeted start to actual request start" - ) - - -class SchedulerDict(StandardBaseDict): - """Scheduler configuration and execution state dictionary.""" - - strategy: SchedulingStrategy - constraints: dict[str, dict[str, Any]] - state: SchedulerState - - -class BenchmarkerDict(StandardBaseDict): - """Benchmarker configuration and component settings dictionary.""" - - profile: Profile - requests: dict[str, Any] - backend: dict[str, Any] - environment: dict[str, Any] - aggregators: dict[str, dict[str, Any]] - - -class BenchmarkMetrics(StandardBaseDict): - """Core benchmark metrics and statistical distributions.""" - - requests_per_second: StatusDistributionSummary = Field( - description="Distribution of requests per second across benchmark execution" - ) - request_concurrency: StatusDistributionSummary = Field( - description="Distribution of concurrent request counts during execution" - ) - request_latency: StatusDistributionSummary = Field( - description="Distribution of request latencies for completed requests" - ) - - -BenchmarkMetricsT = TypeVar("BenchmarkMetricsT", bound=BenchmarkMetrics) - - -class BenchmarkRequestStats(StandardBaseDict): - """Individual request processing statistics and scheduling metadata.""" - - scheduler_info: ScheduledRequestInfo = Field( - description="Scheduler metadata and timing information for the request" - ) - - -BenchmarkRequestStatsT = 
TypeVar("BenchmarkRequestStatsT", bound=BenchmarkRequestStats) - - -class Benchmark(StandardBaseDict, Generic[BenchmarkMetricsT, BenchmarkRequestStatsT]): - """Base benchmark result container with execution metadata.""" - - type_: Literal["benchmark"] = "benchmark" - id_: str = Field( - default_factory=lambda: str(uuid.uuid4()), - description="Unique identifier for this benchmark execution", - ) - run_id: str = Field( - description="Identifier for the benchmarker run containing this benchmark" - ) - run_index: int = Field( - description="Sequential index of this benchmark within the benchmarker run" - ) - scheduler: SchedulerDict = Field( - description="Scheduler configuration and execution state" - ) - benchmarker: BenchmarkerDict = Field( - description="Benchmarker configuration and component settings" - ) - env_args: StandardBaseDict = Field( - description="Environment arguments and runtime configuration" - ) - extras: StandardBaseDict = Field( - description="Additional metadata and custom benchmark parameters" - ) - run_stats: BenchmarkSchedulerStats = Field( - description="Scheduler timing and performance statistics" - ) - start_time: float = Field( - default=-1.0, description="Unix timestamp when the first request was initiated" - ) - end_time: float = Field( - default=-1.0, description="Unix timestamp when the last request completed" - ) - - @computed_field # type: ignore[misc] - @property - def duration(self) -> float: - """ - Benchmark execution duration in seconds. - - :return: Time elapsed from first request start to last request completion. - """ - return self.end_time - self.start_time - - metrics: BenchmarkMetricsT = Field( - description="Performance metrics and statistical distributions" - ) - request_totals: StatusBreakdown[int, int, int, int] = Field( - description="Request counts by status: successful, incomplete, errored, total" - ) - requests: StatusBreakdown[ - list[BenchmarkRequestStatsT], - list[BenchmarkRequestStatsT], - list[BenchmarkRequestStatsT], - None, - ] = Field( - description="Request details grouped by status: successful, incomplete, errored" - ) - - -BenchmarkT = TypeVar("BenchmarkT", bound=Benchmark) - - -class GenerativeRequestStats(BenchmarkRequestStats): - """Request statistics for generative AI text generation workloads.""" - - type_: Literal["generative_request_stats"] = "generative_request_stats" - request_id: str = Field(description="Unique identifier for the request") - request_type: Literal["text_completions", "chat_completions"] = Field( - description="Type of generative request: text or chat completion" - ) - prompt: str = Field(description="Input text prompt for generation") - request_args: dict[str, Any] = Field( - description="Generation parameters and configuration options" - ) - output: str | None = Field( - description="Generated text output, if request completed successfully" - ) - iterations: int = Field( - description="Number of processing iterations for the request" - ) - prompt_tokens: int | None = Field( - description="Number of tokens in the input prompt" - ) - output_tokens: int | None = Field( - description="Number of tokens in the generated output" - ) - - @computed_field # type: ignore[misc] - @property - def total_tokens(self) -> int | None: - """ - Total token count including prompt and output tokens. - - :return: Sum of prompt and output tokens, or None if either is unavailable. 
- """ - if self.prompt_tokens is None and self.output_tokens is None: - return None - - return (self.prompt_tokens or 0) + (self.output_tokens or 0) - - @computed_field # type: ignore[misc] - @property - def request_latency(self) -> float | None: - """ - End-to-end request processing latency in seconds. - - :return: Duration from request start to completion, or None if unavailable. - """ - if ( - not self.scheduler_info.request_timings.request_end - or not self.scheduler_info.request_timings.request_start - ): - return None - - return ( - self.scheduler_info.request_timings.request_end - - self.scheduler_info.request_timings.request_start - ) - - @computed_field # type: ignore[misc] - @property - def time_to_first_token_ms(self) -> float | None: - """ - Time to first token generation in milliseconds. - - :return: Latency from request start to first token, or None if unavailable. - """ - if ( - not self.scheduler_info.request_timings.first_iteration - or not self.scheduler_info.request_timings.request_start - ): - return None - - return 1000 * ( - self.scheduler_info.request_timings.first_iteration - - self.scheduler_info.request_timings.request_start - ) - - @computed_field # type: ignore[misc] - @property - def time_per_output_token_ms(self) -> float | None: - """ - Average time per output token in milliseconds. - - Includes time for first token and all subsequent tokens. - - :return: Average milliseconds per output token, or None if unavailable. - """ - if ( - not self.scheduler_info.request_timings.request_start - or not self.scheduler_info.request_timings.last_iteration - or not self.output_tokens - ): - return None - - return ( - 1000 - * ( - self.scheduler_info.request_timings.last_iteration - - self.scheduler_info.request_timings.request_start - ) - / self.output_tokens - ) - - @computed_field # type: ignore[misc] - @property - def inter_token_latency_ms(self) -> float | None: - """ - Average inter-token latency in milliseconds. - - Measures time between token generations, excluding first token. - - :return: Average milliseconds between tokens, or None if unavailable. - """ - if ( - not self.scheduler_info.request_timings.first_iteration - or not self.scheduler_info.request_timings.last_iteration - or not self.output_tokens - or self.output_tokens <= 1 - ): - return None - - return ( - 1000 - * ( - self.scheduler_info.request_timings.last_iteration - - self.scheduler_info.request_timings.first_iteration - ) - / (self.output_tokens - 1) - ) - - @computed_field # type: ignore[misc] - @property - def tokens_per_second(self) -> float | None: - """ - Overall token throughput including prompt and output tokens. - - :return: Total tokens per second, or None if unavailable. - """ - if not (latency := self.request_latency) or not (tokens := self.total_tokens): - return None - - return tokens / latency - - @computed_field # type: ignore[misc] - @property - def output_tokens_per_second(self) -> float | None: - """ - Output token generation throughput. - - :return: Output tokens per second, or None if unavailable. 
- """ - if not (latency := self.request_latency) or not self.output_tokens: - return None - - return self.output_tokens / latency - - -class GenerativeMetrics(BenchmarkMetrics): - """Comprehensive metrics for generative AI benchmarks.""" - - prompt_token_count: StatusDistributionSummary = Field( - description="Distribution of prompt token counts by request status" - ) - output_token_count: StatusDistributionSummary = Field( - description="Distribution of output token counts by request status" - ) - total_token_count: StatusDistributionSummary = Field( - description="Distribution of total token counts by request status" - ) - time_to_first_token_ms: StatusDistributionSummary = Field( - description="Distribution of first token latencies in milliseconds" - ) - time_per_output_token_ms: StatusDistributionSummary = Field( - description="Distribution of average time per output token in milliseconds" - ) - inter_token_latency_ms: StatusDistributionSummary = Field( - description="Distribution of inter-token latencies in milliseconds" - ) - output_tokens_per_second: StatusDistributionSummary = Field( - description="Distribution of output token generation rates" - ) - tokens_per_second: StatusDistributionSummary = Field( - description="Distribution of total token throughput including prompt and output" - ) - - -class GenerativeBenchmark(Benchmark[GenerativeMetrics, GenerativeRequestStats]): - """Complete generative AI benchmark results with specialized metrics.""" - - type_: Literal["generative_benchmark"] = "generative_benchmark" # type: ignore[assignment] - - -class GenerativeBenchmarksReport(StandardBaseModel): - """Container for multiple benchmark results with load/save functionality.""" - - DEFAULT_FILE: ClassVar[str] = "benchmarks.json" - - @staticmethod - def load_file( - path: str | Path, type_: Literal["json", "yaml"] | None = None - ) -> GenerativeBenchmarksReport: - """ - Load a report from a file. - - :param path: The path to load the report from. - :param type_: File type override, auto-detected from extension if None. - :return: The loaded report. - :raises ValueError: If file type is unsupported. - """ - path = Path(path) if not isinstance(path, Path) else path - - if path.is_dir(): - path = path / GenerativeBenchmarksReport.DEFAULT_FILE - - path.parent.mkdir(parents=True, exist_ok=True) - path_suffix = path.suffix.lower()[1:] - - with path.open("r") as file: - if (type_ or path_suffix) == "json": - model_dict = json.loads(file.read()) - elif (type_ or path_suffix) in ["yaml", "yml"]: - model_dict = yaml.safe_load(file) - else: - raise ValueError(f"Unsupported file type: {type_} for {path}.") - - return GenerativeBenchmarksReport.model_validate(model_dict) - - benchmarks: list[GenerativeBenchmark] = Field( - description="The list of completed benchmarks contained within the report.", - default_factory=list, - ) - - def save_file( - self, path: str | Path | None, type_: Literal["json", "yaml"] | None = None - ) -> Path: - """ - Save the report to a file. - - :param path: The path to save the report to. - :param type_: File type override, auto-detected from extension if None. - :return: The path to the saved report. - :raises ValueError: If file type is unsupported. 
- """ - if path is None: - path = Path.cwd() - elif not isinstance(path, Path): - path = Path(path) - - if path.is_dir(): - path = path / GenerativeBenchmarksReport.DEFAULT_FILE - - path.parent.mkdir(parents=True, exist_ok=True) - path_suffix = path.suffix.lower()[1:] - model_dict = self.model_dump() - - if (type_ or path_suffix) == "json": - save_str = json.dumps(model_dict) - elif (type_ or path_suffix) in ["yaml", "yml"]: - save_str = yaml.dump(model_dict) - else: - raise ValueError(f"Unsupported file type: {type_} for {path}.") - - with path.open("w") as file: - file.write(save_str) - - return path diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index cacadc94..1e92c7a9 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -15,17 +15,17 @@ from rich.padding import Padding from rich.text import Text -from guidellm.benchmark.objects import ( - GenerativeBenchmark, - GenerativeBenchmarksReport, - GenerativeMetrics, -) from guidellm.benchmark.profile import ( AsyncProfile, ConcurrentProfile, SweepProfile, ThroughputProfile, ) +from guidellm.benchmark.schemas import ( + GenerativeBenchmark, + GenerativeBenchmarksReport, + GenerativeMetrics, +) from guidellm.presentation import UIDataBuilder from guidellm.presentation.injector import create_report from guidellm.settings import settings diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py index 87a9a2be..8564afde 100644 --- a/src/guidellm/benchmark/profile.py +++ b/src/guidellm/benchmark/profile.py @@ -55,7 +55,7 @@ from guidellm.utils import PydanticClassRegistryMixin if TYPE_CHECKING: - from guidellm.benchmark.objects import Benchmark + from guidellm.benchmark.schemas import Benchmark __all__ = [ "AsyncProfile", @@ -665,9 +665,9 @@ def next_strategy( return SynchronousStrategy() if prev_strategy.type_ == "synchronous": - self.synchronous_rate = ( - prev_benchmark.metrics.requests_per_second.successful.mean - ) + self.synchronous_rate = prev_benchmark.get_request_metrics_sample()[ + "request_throughput" + ] return ThroughputStrategy( max_concurrency=self.max_concurrency, @@ -675,9 +675,9 @@ def next_strategy( ) if prev_strategy.type_ == "throughput": - self.throughput_rate = ( - prev_benchmark.metrics.requests_per_second.successful.mean - ) + self.throughput_rate = prev_benchmark.get_request_metrics_sample()[ + "request_throughput" + ] if self.synchronous_rate <= 0 and self.throughput_rate <= 0: raise RuntimeError( "Invalid rates in sweep; aborting. 
" diff --git a/src/guidellm/benchmark/progress.py b/src/guidellm/benchmark/progress.py index f93b3a83..558def67 100644 --- a/src/guidellm/benchmark/progress.py +++ b/src/guidellm/benchmark/progress.py @@ -16,9 +16,7 @@ from __future__ import annotations -import asyncio from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, AsyncIterator, Iterable from dataclasses import dataclass from datetime import datetime from typing import Any, Generic, Literal @@ -37,21 +35,16 @@ TimeRemainingColumn, ) -from guidellm.benchmark.aggregator import AggregatorState -from guidellm.benchmark.objects import BenchmarkT, GenerativeBenchmark from guidellm.benchmark.profile import Profile -from guidellm.scheduler import ( - SchedulerState, - SchedulingStrategy, - StrategyType, +from guidellm.benchmark.schemas import ( + BenchmarkT, + EstimatedBenchmarkState, + GenerativeBenchmark, ) +from guidellm.scheduler import SchedulerState, SchedulingStrategy, StrategyType from guidellm.utils import Colors, format_value_display -__all__ = [ - "BenchmarkerProgress", - "BenchmarkerProgressGroup", - "GenerativeConsoleBenchmarkerProgress", -] +__all__ = ["BenchmarkerProgress", "GenerativeConsoleBenchmarkerProgress"] class BenchmarkerProgress(Generic[BenchmarkT], ABC): @@ -63,106 +56,15 @@ class BenchmarkerProgress(Generic[BenchmarkT], ABC): enable/disable functionality for conditional progress tracking. """ - def __init__(self, enabled: bool = True): + def __init__(self): """ Initialize progress tracker. :param enabled: Whether to enable progress tracking and display. """ - self._enabled = enabled self.profile: Profile = None self.current_strategy: SchedulingStrategy = None - @property - def enabled(self) -> bool: - """ - :return: Whether progress tracking is currently enabled. - """ - return self._enabled - - @enabled.setter - def enabled(self, value: bool) -> None: - """ - :param value: True to enable progress tracking, False to disable. - :raises RuntimeError: If called after progress run has started. - """ - if self.profile is not None: - raise RuntimeError( - "Cannot change enabled state after __call__ for progress run" - ) - - self._enabled = value - - def __call__( - self, - profile: Profile, - agen: AsyncIterable[ - tuple[ - AggregatorState | None, - BenchmarkT | None, - SchedulingStrategy, - SchedulerState | None, - ] - ], - ) -> AsyncIterator[ - tuple[ - AggregatorState | None, - BenchmarkT | None, - SchedulingStrategy, - SchedulerState | None, - ] - ]: - """ - Track progress through benchmark execution pipeline. - - Wraps the provided async generator to monitor benchmark progress, - calling appropriate lifecycle hooks based on execution state. - - :param profile: Benchmark profile configuration. - :param agen: Async generator yielding benchmark execution updates. - :return: Async iterator forwarding original updates with progress tracking. 
- """ - - async def aiterator() -> AsyncIterator[ - tuple[ - AggregatorState | None, - BenchmarkT | None, - SchedulingStrategy, - SchedulerState | None, - ] - ]: - self.profile = profile - if self.enabled: - await self.on_initialize(profile) - - async for aggregator_update, benchmark, strategy, scheduler_state in agen: - if self.enabled: - await self.on_raw_update( - profile, - aggregator_update, - benchmark, - strategy, - scheduler_state, - ) - - if self.current_strategy != strategy: - self.current_strategy = strategy - await self.on_benchmark_start(strategy) - elif benchmark is not None: - await self.on_benchmark_complete(benchmark) - self.current_strategy = None - else: - await self.on_benchmark_update( - aggregator_update, scheduler_state - ) - - yield aggregator_update, benchmark, strategy, scheduler_state - - if self.enabled: - await self.on_finalize() - - return aiterator() - @abstractmethod async def on_initialize(self, profile: Profile): """ @@ -181,12 +83,12 @@ async def on_benchmark_start(self, strategy: SchedulingStrategy): @abstractmethod async def on_benchmark_update( - self, aggregator_update: AggregatorState, scheduler_state: SchedulerState + self, estimated_state: EstimatedBenchmarkState, scheduler_state: SchedulerState ): """ Handle benchmark execution progress update. - :param aggregator_update: Current benchmark metrics and statistics. + :param estimated_state: Current benchmark metrics and statistics. :param scheduler_state: Current scheduler execution state. """ @@ -202,153 +104,6 @@ async def on_benchmark_complete(self, benchmark: BenchmarkT): async def on_finalize(self): """Finalize progress tracking and cleanup resources.""" - async def on_raw_update( - self, - profile: Profile, - aggregator_update: AggregatorState | None, - benchmark: BenchmarkT | None, - strategy: SchedulingStrategy, - scheduler_state: SchedulerState | None, - ): - """ - Handle raw benchmark execution update. - - Optional hook for accessing all execution state updates. Default - implementation does nothing. - - :param profile: Benchmark profile configuration. - :param aggregator_update: Current benchmark metrics and statistics. - :param benchmark: Completed benchmark if available. - :param strategy: Current scheduling strategy. - :param scheduler_state: Current scheduler execution state. - """ - - -class BenchmarkerProgressGroup(BenchmarkerProgress[BenchmarkT]): - """ - Composite progress handler that manages multiple progress instances. - - Distributes progress events to all contained progress instances, enabling - parallel progress tracking through multiple channels (e.g., console display - and file logging). - - :param instances: Collection of progress handlers to manage. - :param enabled: Whether the group is active. - """ - - def __init__( - self, - instances: ( - Iterable[BenchmarkerProgress[BenchmarkT]] - | list[BenchmarkerProgress[BenchmarkT]] - ), - enabled: bool = True, - ): - """ - Initialize progress group with handler instances. - - :param instances: Progress handler instances to coordinate. - :param enabled: Whether to enable the progress group. - """ - self.instances: list[BenchmarkerProgress[BenchmarkT]] = list(instances) - super().__init__(enabled=enabled) - - @property - def enabled(self) -> bool: - """Whether the progress group is currently enabled.""" - return self._enabled - - @enabled.setter - def enabled(self, value: bool): - """ - Set enabled state for group and all contained instances. - - :param value: New enabled state. 
- """ - self._enabled = value - for instance in self.instances: - instance.enabled = value - - async def on_initialize(self, profile: Profile): - """ - Initialize all progress handler instances. - - :param profile: Benchmark profile configuration. - """ - await asyncio.gather( - *[child.on_initialize(profile) for child in self.instances] - ) - - async def on_benchmark_start(self, strategy: SchedulingStrategy): - """ - Notify all handlers of benchmark strategy start. - - :param strategy: Scheduling strategy being executed. - """ - await asyncio.gather( - *[child.on_benchmark_start(strategy) for child in self.instances] - ) - - async def on_benchmark_update( - self, aggregator_update: AggregatorState, scheduler_state: SchedulerState - ): - """ - Distribute benchmark updates to all handlers. - - :param aggregator_update: Current benchmark metrics and statistics. - :param scheduler_state: Current scheduler execution state. - """ - await asyncio.gather( - *[ - child.on_benchmark_update(aggregator_update, scheduler_state) - for child in self.instances - ] - ) - - async def on_benchmark_complete(self, benchmark: BenchmarkT): - """ - Notify all handlers of benchmark completion. - - :param benchmark: Completed benchmark results. - """ - await asyncio.gather( - *[child.on_benchmark_complete(benchmark) for child in self.instances] - ) - - async def on_finalize(self): - """Finalize all progress handler instances.""" - await asyncio.gather(*[child.on_finalize() for child in self.instances]) - - async def on_raw_update( - self, - profile: Profile, - aggregator_update: AggregatorState | None, - benchmark: BenchmarkT | None, - strategy: SchedulingStrategy, - scheduler_state: SchedulerState | None, - ): - """ - Distribute raw updates to all handlers. - - :param profile: Benchmark profile configuration. - :param aggregator_update: Current benchmark metrics and statistics. - :param benchmark: Completed benchmark if available. - :param strategy: Current scheduling strategy. - :param scheduler_state: Current scheduler execution state. - """ - await asyncio.gather( - *[ - child.on_raw_update( - profile, - aggregator_update, - benchmark, - strategy, - scheduler_state, - ) - for child in self.instances - ] - ) - class GenerativeConsoleBenchmarkerProgress( BenchmarkerProgress[GenerativeBenchmark], Live @@ -361,14 +116,14 @@ class GenerativeConsoleBenchmarkerProgress( bars in a structured console interface. """ - def __init__(self, enabled: bool = True, display_scheduler_stats: bool = False): + def __init__(self, display_scheduler_stats: bool = False): """ Initialize console progress display. :param enabled: Whether to enable progress tracking and display. :param display_scheduler_stats: Whether to display scheduler statistics. """ - BenchmarkerProgress.__init__(self, enabled=enabled) + BenchmarkerProgress.__init__(self) Live.__init__( self, refresh_per_second=4, @@ -432,7 +187,9 @@ async def on_benchmark_start(self, strategy: SchedulingStrategy): self._sync_run_progress() async def on_benchmark_update( - self, aggregator_update: AggregatorState | None, scheduler_state: SchedulerState + self, + aggregator_update: EstimatedBenchmarkState | None, + scheduler_state: SchedulerState, ): """ Update display with current benchmark progress. 
@@ -545,7 +302,9 @@ def start_benchmark(self, strategy: SchedulingStrategy): ) def update_benchmark( - self, aggregator_update: AggregatorState, scheduler_state: SchedulerState + self, + aggregator_update: EstimatedBenchmarkState, + scheduler_state: SchedulerState, ): self.benchmark_task_states[self.current_index].update( aggregator_update, scheduler_state @@ -800,71 +559,75 @@ def start(self, strategy: SchedulingStrategy): self.strategy_type = strategy.type_ def update( - self, aggregator_update: AggregatorState, scheduler_state: SchedulerState + self, + estimated_state: EstimatedBenchmarkState, + scheduler_state: SchedulerState, ): self.progress = ( (1.0 - scheduler_state.remaining_fraction) if scheduler_state.remaining_fraction is not None else 0.0 ) - status: Literal["in_warmup", "in_progress", "in_cooldown"] | None = ( - "in_progress" # Need to handle requests_in_* isn't in aggregator_update - ) - if aggregator_update.get("requests_in_warmup"): - status = "in_warmup" - elif aggregator_update.get("requests_in_cooldown"): - status = "in_cooldown" self._update_processing_states( - benchmark_status=status, + benchmark_status=estimated_state.get_metric( + group=EstimatedBenchmarkState.benchmark_state_group, + key="status", + default=None, + ), start_time=scheduler_state.start_time, successful_requests=scheduler_state.successful_requests, cancelled_requests=scheduler_state.cancelled_requests, errored_requests=scheduler_state.errored_requests, ) self._update_request_stats( - request_concurrency=aggregator_update.get_metric( - key="requests", type_="avg", prefix="completed" + request_concurrency=estimated_state.get_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="concurrency_requests", ), - requests_per_second=aggregator_update.get_metric( - key="requests", - type_="rate", - prefix="completed", + requests_per_second=estimated_state.get_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="completed_requests_per_second", ), - request_latency=aggregator_update.get_metric( - key="request_latency", type_="avg", prefix="completed" + request_latency=estimated_state.get_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="completed_request_latency", ), ) self._update_token_stats( - output_tokens=aggregator_update.get_metric( - key="output_tokens", type_="avg", prefix="completed" + output_tokens=estimated_state.get_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="completed_output_tokens_total", ), - output_tokens_rate=aggregator_update.get_metric( - key="output_tokens", type_="rate" + output_tokens_rate=estimated_state.get_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="completed_output_tokens", ), - prompt_tokens=aggregator_update.get_metric( - key="prompt_tokens", type_="avg", prefix="completed" + prompt_tokens=estimated_state.get_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="completed_input_tokens_total", ), - total_tokens_rate=aggregator_update.get_metric( - key="total_tokens", type_="rate" + total_tokens_rate=estimated_state.get_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="completed_total_tokens", ), - time_to_first_token=( - aggregator_update.get_metric(key="time_to_first_token", type_="avg") + time_to_first_token=estimated_state.get_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="completed_time_to_first_token", ), - inter_token_latency=( - aggregator_update.get_metric(key="inter_token_latency", type_="avg") + 
inter_token_latency=estimated_state.get_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="completed_inter_token_latency", ), ) - if aggregator_update.get("updated_scheduler_stats"): + if estimated_state.get("updated_scheduler_stats"): self._update_system_stats( - request_targeted_start_delay=( - aggregator_update.get_metric( - key="request_targeted_start_delay", type_="avg", default=0.0 - ) + request_targeted_start_delay=estimated_state.get_metric( + group=EstimatedBenchmarkState.scheduler_state_group, + key="request_targeted_start_delay", ), - queued_time=( - aggregator_update.get_metric( - key="queued_time", type_="avg", default=0.0 - ) + queued_time=estimated_state.get_metric( + group=EstimatedBenchmarkState.scheduler_state_group, + key="queued_time", ), scheduler_overheads_time=0.0, # Need to add up metrics here ) diff --git a/src/guidellm/benchmark/scenario.py b/src/guidellm/benchmark/scenario.py index 73a9a050..59cdef27 100644 --- a/src/guidellm/benchmark/scenario.py +++ b/src/guidellm/benchmark/scenario.py @@ -9,11 +9,11 @@ import yaml from loguru import logger -from pydantic import BeforeValidator, Field, PositiveFloat, PositiveInt, SkipValidation +from pydantic import BeforeValidator, Field, PositiveFloat, PositiveInt from guidellm.backends import Backend, BackendType from guidellm.benchmark.profile import Profile, ProfileType -from guidellm.benchmark.types import AggregatorInputT, DataInputT, ProcessorInputT +from guidellm.benchmark.types import ProcessorInputT from guidellm.scheduler import StrategyType from guidellm.utils import StandardBaseModel @@ -108,11 +108,7 @@ class Config: # types like PreTrainedTokenizerBase arbitrary_types_allowed = True - data: Annotated[ - DataInputT, - # BUG: See https://github.com/pydantic/pydantic/issues/9541 - SkipValidation, - ] + data: Any profile: StrategyType | ProfileType | Profile rate: Annotated[list[PositiveFloat] | None, BeforeValidator(parse_float_list)] = ( None @@ -128,7 +124,6 @@ class Config: data_args: dict[str, Any] | None = None data_sampler: Literal["random"] | None = None # Aggregators configuration - add_aggregators: AggregatorInputT | None = None warmup: Annotated[float | None, Field(gt=0, le=1)] = None cooldown: Annotated[float | None, Field(gt=0, le=1)] = None request_samples: PositiveInt | None = 20 diff --git a/src/guidellm/benchmark/schemas.py b/src/guidellm/benchmark/schemas.py new file mode 100644 index 00000000..62ae5b0e --- /dev/null +++ b/src/guidellm/benchmark/schemas.py @@ -0,0 +1,1392 @@ +""" +Benchmark data models and metrics for performance measurement and analysis. + +Provides comprehensive data structures for capturing, storing, and analyzing +benchmark results from scheduler executions. Includes timing measurements, +token statistics, and performance metrics for generative AI workloads. + +Classes: + BenchmarkSchedulerStats: Scheduler timing and performance statistics. + BenchmarkMetrics: Core benchmark metrics and distributions. + BenchmarkRequestStats: Individual request processing statistics. + Benchmark: Base benchmark result container with generic metrics. + GenerativeRequestStats: Request statistics for generative AI workloads. + GenerativeMetrics: Comprehensive metrics for generative benchmarks. + GenerativeBenchmark: Complete generative benchmark results and analysis. + GenerativeBenchmarksReport: Container for multiple benchmark results. + +Type Variables: + BenchmarkMetricsT: Generic benchmark metrics type. + BenchmarkRequestStatsT: Generic request statistics type. 
+ BenchmarkT: Generic benchmark container type. +""" + +from __future__ import annotations + +import json +import random +import time +import uuid +from abc import ABC, abstractmethod +from collections.abc import Iterable +from pathlib import Path +from typing import Any, ClassVar, Literal, TypeVar, cast + +import yaml +from pydantic import Field, computed_field + +from guidellm.benchmark.profile import Profile +from guidellm.scheduler import ( + BackendInterface, + Environment, + SchedulerState, + SchedulingStrategy, +) +from guidellm.schemas import ( + GenerationRequest, + GenerationResponse, + GenerativeRequestStats, + RequestInfo, +) +from guidellm.schemas.request import UsageMetrics +from guidellm.utils import ( + InfoMixin, + StandardBaseDict, + StandardBaseModel, + StatusBreakdown, + StatusDistributionSummary, +) +from guidellm.utils.pydantic_utils import StandardBaseDict + +__all__ = [ + "Benchmark", + "BenchmarkArgs", + "BenchmarkSchedulerStats", + "BenchmarkT", + "BenchmarkerDict", + "EstimatedBenchmarkState", + "GenerativeAudioMetricsSummary", + "GenerativeBenchmark", + "GenerativeBenchmarksReport", + "GenerativeImageMetricsSummary", + "GenerativeMetrics", + "GenerativeMetricsSummary", + "GenerativeTextMetricsSummary", + "GenerativeVideoMetricsSummary", + "SchedulerDict", +] + + +class EstimatedBenchmarkState(dict[str, Any]): + benchmark_state_group: ClassVar[Literal["benchmark_state"]] = "benchmark_state" + benchmark_metrics_group: ClassVar[Literal["benchmark_metrics"]] = ( + "benchmark_metrics" + ) + scheduler_state_group: ClassVar[Literal["scheduler_state"]] = "scheduler_state" + + def get_metric( + self, + group: str, + key: str, + default: int | float | None = None, + ) -> int | float | None: + return self.get(f"{group}_{key}", default) + + def set_metric( + self, + group: str, + key: str, + value: bool | int | float | None, + start_val: bool | int | float | None = None, + ) -> bool | int | float | None: + if value is None: + return None + + if start_val is not None: + value -= start_val + self[f"{group}_{key}"] = value + + return value + + def add_avg_metric( + self, + group: str, + key: str, + value: bool | int | float | None, + start_val: bool | int | float | None = 0.0, + count: int | None = 1, + ): + if value is None or count is None: + return + + if start_val is not None: + value -= start_val + + total_key = f"{group}_{key}_total" + count_key = f"{group}_{key}_count" + self[total_key] = self.get(total_key, 0) + value + self[count_key] = self.get(count_key, 0) + count + + average = self[total_key] / self[count_key] + self.set_metric( + group=group, + key=key, + value=average, + ) + + def add_avg_rate_metric( + self, + group: str, + key: str, + value: bool | int | float | None, + start_val: bool | int | float | None = 0.0, + start_time: float | None = None, + end_time: float | None = None, + numerator_type: Literal["avg", "total", "count"] = "total", + ): + if value is None: + return + + self.add_avg_metric( + group=group, + key=key, + value=value, + start_val=start_val, + ) + start_time_key = f"{group}_{key}_start_time" + if self.get(start_time_key) is None: + if start_time is None: + start_time = time.time() + self[start_time_key] = start_time + else: + self[start_time_key] = start_time or self[start_time_key] + + end_time = end_time or time.time() + elapsed_time = end_time - self[start_time_key] + + if elapsed_time > 0: + numerator_key = ( + f"{group}_{key}_{numerator_type}" + if numerator_type != "avg" + else f"{group}_{key}" + ) + rate = self[numerator_key] / 
elapsed_time + self.set_metric( + group=group, + key=f"{key}_per_second", + value=rate, + ) + + def add_time_averaged_metric( + self, + group: str, + key: str, + value: bool | int | float | None, + recorded_time: float | None = None, + ): + if value is None: + return + + if recorded_time is None: + recorded_time = time.time() + + time_avg_numerator_key = f"{group}_{key}_time_avg_numerator" + time_avg_denominator_key = f"{group}_{key}_time_avg_denominator" + last_recorded_time_key = f"{group}_{key}_last_recorded_time" + + if last_recorded_time_key not in self: + self[last_recorded_time_key] = recorded_time + self[time_avg_numerator_key] = value + self[time_avg_denominator_key] = 0.0 + else: + time_delta = recorded_time - self[last_recorded_time_key] + self[time_avg_numerator_key] += value * time_delta + self[time_avg_denominator_key] += time_delta + self[last_recorded_time_key] = recorded_time + + if self[time_avg_denominator_key] > 0: + average = self[time_avg_numerator_key] / self[time_avg_denominator_key] + else: + average = value + + self.set_metric( + group=group, + key=key, + value=average, + ) + + +class BenchmarkArgs(StandardBaseDict): + run_id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Unique identifier for the benchmark run", + ) + run_index: int = Field(default=0, description="Index of the benchmark run") + sample_requests: int | None = Field( + default=20, + description="Number of requests to sample and keep in the final benchmark for metrics", + ) + warmup: int | float | None = Field( + default=None, description="Warmup time before benchmarking starts" + ) + cooldown: int | float | None = Field( + default=None, description="Cooldown time after benchmarking ends" + ) + prefer_response_metrics: bool = Field( + default=True, + description="Whether to prefer response metrics over request metrics", + ) + + def is_in_warmup( + self, request_info: RequestInfo, scheduler_state: SchedulerState + ) -> bool: + if self.warmup is not None and 0 < self.warmup < 1: + # Percentage-based warmup + return ( + scheduler_state.remaining_fraction is not None + and scheduler_state.remaining_fraction > (1 - self.warmup) + ) + + if self.warmup is not None and self.warmup > 1: + # Count/time-based warmup + if scheduler_state.processed_requests < self.warmup: + return True + + current_time = request_info.timings.targeted_start + return ( + current_time is not None + and (current_time - scheduler_state.start_time) < self.warmup + ) + + return False + + def is_in_cooldown( + self, request_info: RequestInfo, scheduler_state: SchedulerState + ) -> bool: + if self.cooldown is not None and 0 < self.cooldown < 1: + # Percentage-based cooldown + return ( + scheduler_state.remaining_fraction is not None + and scheduler_state.remaining_fraction < self.cooldown + ) + + if self.cooldown is not None and self.cooldown > 1: + # Count/time-based cooldown + if ( + scheduler_state.remaining_requests is not None + and scheduler_state.remaining_requests <= self.cooldown + ): + return True + + current_time = ( + request_info.timings.resolve_end or request_info.timings.targeted_start + ) + return ( + current_time is not None + and scheduler_state.remaining_duration is not None + and scheduler_state.remaining_duration < self.cooldown + ) + + return False + + +class Benchmark(ABC): + @abstractmethod + def get_run_metrics_sample( + self, + ) -> dict[Literal["start_time", "end_time", "duration"], float]: ... 
+ + @abstractmethod + def get_request_metrics_sample( + self, + ) -> dict[ + Literal[ + "request_count", + "request_latency", + "request_throughput", + "request_concurrency", + ], + float, + ]: ... + + @classmethod + @abstractmethod + def update_estimate( + cls, + args: BenchmarkArgs, + state: EstimatedBenchmarkState, + response: Any, + request: Any, + request_info: RequestInfo, + scheduler_state: SchedulerState, + ): ... + + @classmethod + @abstractmethod + def compile( + cls, + args: BenchmarkArgs, + estimated_state: EstimatedBenchmarkState, + scheduler_state: SchedulerState, + profile: Profile, + requests: Iterable, + backend: BackendInterface, + environment: Environment, + strategy: SchedulingStrategy, + constraints: dict[str, dict[str, Any]], + ) -> Any: ... + + +BenchmarkT = TypeVar("BenchmarkT", bound=Benchmark) + + +class BenchmarkSchedulerStats(StandardBaseDict): + """Scheduler timing and performance statistics.""" + + group_name: ClassVar[Literal["scheduler_stats"]] = "scheduler_stats" + + start_time: float = Field( + description="Unix timestamp when the benchmark run started" + ) + end_time: float = Field(description="Unix timestamp when the benchmark run ended") + requests_made: StatusBreakdown[int, int, int, int] = Field( + description="Request counts by status: successful, incomplete, errored, total" + ) + queued_time_avg: float = Field( + description="Avg time requests spent in the queue (seconds)" + ) + worker_resolve_start_delay_avg: float = Field( + description="Avg delay before worker begins resolving req after dequeue (sec)" + ) + worker_resolve_time_avg: float = Field( + description="Avg time for worker to resolve requests (seconds)" + ) + worker_resolve_end_delay_avg: float = Field( + description="Avg delay after request end till worker resolves (seconds)" + ) + finalized_delay_avg: float = Field( + description="Avg delay after resolve til finalized with in scheduler (sec)" + ) + worker_targeted_start_delay_avg: float = Field( + description="Avg delay from targeted start to actual worker start (seconds)" + ) + request_start_delay_avg: float = Field( + description="Avg delay after resolve til request start (seconds)" + ) + request_time_avg: float = Field(description="Avg request processing time (seconds)") + request_targeted_start_delay_avg: float = Field( + description="Avg delay from targeted start to actual request start" + ) + + @classmethod + def update_estimate(cls, state: EstimatedBenchmarkState, request_info: RequestInfo): + state.set_metric(group=cls.group_name, key="updated", value=True) + state.add_avg_metric( + group=cls.group_name, + key="queued_time", + value=request_info.timings.dequeued, + start_val=request_info.timings.queued, + ) + state.add_avg_metric( + group=cls.group_name, + key="worker_resolve_start_delay", + value=request_info.timings.resolve_start, + start_val=request_info.timings.scheduled_at, + ) + state.add_avg_metric( + group=cls.group_name, + key="worker_resolve_time", + value=request_info.timings.resolve_end, + start_val=request_info.timings.resolve_start, + ) + state.add_avg_metric( + group=cls.group_name, + key="worker_resolve_end_delay", + value=request_info.timings.request_end, + start_val=request_info.timings.resolve_end, + ) + state.add_avg_metric( + group=cls.group_name, + key="finalized_delay", + value=request_info.timings.finalized, + start_val=request_info.timings.resolve_end, + ) + state.add_avg_metric( + group=cls.group_name, + key="worker_targeted_start_delay", + value=request_info.timings.resolve_start, + 
start_val=request_info.timings.targeted_start, + ) + state.add_avg_metric( + group=cls.group_name, + key="request_start_delay", + value=request_info.timings.request_start, + start_val=request_info.timings.resolve_start, + ) + state.add_avg_metric( + group=cls.group_name, + key="request_time", + value=request_info.timings.request_end, + start_val=request_info.timings.request_start, + ) + state.add_avg_metric( + group=cls.group_name, + key="request_targeted_start_delay", + value=request_info.timings.request_start, + start_val=request_info.timings.targeted_start, + ) + + @classmethod + def compile( + cls, estimated_state: EstimatedBenchmarkState, scheduler_state: SchedulerState + ) -> BenchmarkSchedulerStats: + return BenchmarkSchedulerStats( + start_time=scheduler_state.start_time, + end_time=scheduler_state.end_time or scheduler_state.start_time, + requests_made=StatusBreakdown[int, int, int, int]( + successful=scheduler_state.successful_requests, + incomplete=scheduler_state.cancelled_requests, + errored=scheduler_state.errored_requests, + total=( + scheduler_state.successful_requests + + scheduler_state.cancelled_requests + + scheduler_state.errored_requests + ), + ), + queued_time_avg=cast( + "float", + estimated_state.get_metric( + group=cls.group_name, key="queued_time", default=-1.0 + ), + ), + worker_resolve_start_delay_avg=cast( + "float", + estimated_state.get_metric( + group=cls.group_name, key="worker_resolve_start_delay", default=-1.0 + ), + ), + worker_resolve_time_avg=cast( + "float", + estimated_state.get_metric( + group=cls.group_name, key="worker_resolve_time", default=-1.0 + ), + ), + worker_resolve_end_delay_avg=cast( + "float", + estimated_state.get_metric( + group=cls.group_name, key="worker_resolve_end_delay", default=-1.0 + ), + ), + finalized_delay_avg=cast( + "float", + estimated_state.get_metric( + group=cls.group_name, key="finalized_delay", default=-1.0 + ), + ), + worker_targeted_start_delay_avg=cast( + "float", + estimated_state.get_metric( + group=cls.group_name, + key="worker_targeted_start_delay", + default=-1.0, + ), + ), + request_start_delay_avg=cast( + "float", + estimated_state.get_metric( + group=cls.group_name, key="request_start_delay", default=-1.0 + ), + ), + request_time_avg=cast( + "float", + estimated_state.get_metric( + group=cls.group_name, key="request_time", default=-1.0 + ), + ), + request_targeted_start_delay_avg=cast( + "float", + estimated_state.get_metric( + group=cls.group_name, + key="request_targeted_start_delay", + default=-1.0, + ), + ), + ) + + +class GenerativeMetricsSummary(StandardBaseDict): + input: StatusDistributionSummary = Field(description="") + input_per_second: StatusDistributionSummary = Field(description="") + input_concurrency: StatusDistributionSummary = Field(description="") + + output: StatusDistributionSummary = Field(description="") + output_per_second: StatusDistributionSummary = Field(description="") + output_concurrency: StatusDistributionSummary = Field(description="") + + total: StatusDistributionSummary = Field(description="") + total_per_second: StatusDistributionSummary = Field(description="") + total_concurrency: StatusDistributionSummary = Field(description="") + + @classmethod + def compile( + cls, + request_types: list[Literal["successful", "incomplete", "error"]], + request_times: list[tuple[float, float]], + input_values: list[int | float], + output_values: list[int | float], + ) -> GenerativeMetricsSummary: + total_values = [ + input_val + output_val + for input_val, output_val in 
zip(input_values, output_values, strict=False) + ] + + return GenerativeMetricsSummary( + input=StatusDistributionSummary.from_values( + value_types=request_types, + values=input_values, + ), + input_per_second=StatusDistributionSummary.from_request_times( + request_types=request_types, + requests=request_times, + distribution_type="rate", + weights=input_values, + ), + input_concurrency=StatusDistributionSummary.from_request_times( + request_types=request_types, + requests=request_times, + distribution_type="concurrency", + weights=input_values, + ), + output=StatusDistributionSummary.from_values( + value_types=request_types, + values=output_values, + ), + output_per_second=StatusDistributionSummary.from_request_times( + request_types=request_types, + requests=request_times, + distribution_type="rate", + weights=output_values, + ), + output_concurrency=StatusDistributionSummary.from_request_times( + request_types=request_types, + requests=request_times, + distribution_type="concurrency", + weights=output_values, + ), + total=StatusDistributionSummary.from_values( + value_types=request_types, + values=total_values, + ), + total_per_second=StatusDistributionSummary.from_request_times( + request_types=request_types, + requests=request_times, + distribution_type="rate", + weights=total_values, + ), + total_concurrency=StatusDistributionSummary.from_request_times( + request_types=request_types, + requests=request_times, + distribution_type="concurrency", + weights=total_values, + ), + ) + + +class GenerativeTextMetricsSummary(StandardBaseDict): + tokens: GenerativeMetricsSummary = Field(description="") + words: GenerativeMetricsSummary = Field(description="") + characters: GenerativeMetricsSummary = Field(description="") + + @classmethod + def compile( + cls, + request_types: list[Literal["successful", "incomplete", "error"]], + request_times: list[tuple[float, float]], + input_metrics: list[UsageMetrics], + output_metrics: list[UsageMetrics], + ) -> GenerativeTextMetricsSummary: + return GenerativeTextMetricsSummary( + tokens=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.text_tokens or 0 for metrics in input_metrics], + output_values=[metrics.text_tokens or 0 for metrics in output_metrics], + ), + words=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.text_words or 0 for metrics in input_metrics], + output_values=[metrics.text_words or 0 for metrics in output_metrics], + ), + characters=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[ + metrics.text_characters or 0 for metrics in input_metrics + ], + output_values=[ + metrics.text_characters or 0 for metrics in output_metrics + ], + ), + ) + + +class GenerativeImageMetricsSummary(StandardBaseDict): + tokens: GenerativeMetricsSummary = Field(description="") + images: GenerativeMetricsSummary = Field(description="") + pixels: GenerativeMetricsSummary = Field(description="") + bytes: GenerativeMetricsSummary = Field(description="") + + @classmethod + def compile( + cls, + request_types: list[Literal["successful", "incomplete", "error"]], + request_times: list[tuple[float, float]], + input_metrics: list[UsageMetrics], + output_metrics: list[UsageMetrics], + ) -> GenerativeImageMetricsSummary: + return GenerativeImageMetricsSummary( + tokens=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + 
input_values=[metrics.image_tokens or 0 for metrics in input_metrics], + output_values=[metrics.image_tokens or 0 for metrics in output_metrics], + ), + images=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.image_count or 0 for metrics in input_metrics], + output_values=[metrics.image_count or 0 for metrics in output_metrics], + ), + pixels=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.image_pixels or 0 for metrics in input_metrics], + output_values=[metrics.image_pixels or 0 for metrics in output_metrics], + ), + bytes=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.image_bytes or 0 for metrics in input_metrics], + output_values=[metrics.image_bytes or 0 for metrics in output_metrics], + ), + ) + + +class GenerativeVideoMetricsSummary(StandardBaseDict): + tokens: GenerativeMetricsSummary = Field(description="") + frames: GenerativeMetricsSummary = Field(description="") + seconds: GenerativeMetricsSummary = Field(description="") + bytes: GenerativeMetricsSummary = Field(description="") + + @classmethod + def compile( + cls, + request_types: list[Literal["successful", "incomplete", "error"]], + request_times: list[tuple[float, float]], + input_metrics: list[UsageMetrics], + output_metrics: list[UsageMetrics], + ) -> GenerativeVideoMetricsSummary: + return GenerativeVideoMetricsSummary( + tokens=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.video_tokens or 0 for metrics in input_metrics], + output_values=[metrics.video_tokens or 0 for metrics in output_metrics], + ), + frames=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.video_frames or 0 for metrics in input_metrics], + output_values=[metrics.video_frames or 0 for metrics in output_metrics], + ), + seconds=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.video_seconds or 0 for metrics in input_metrics], + output_values=[ + metrics.video_seconds or 0 for metrics in output_metrics + ], + ), + bytes=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.video_bytes or 0 for metrics in input_metrics], + output_values=[metrics.video_bytes or 0 for metrics in output_metrics], + ), + ) + + +class GenerativeAudioMetricsSummary(StandardBaseDict): + tokens: GenerativeMetricsSummary = Field(description="") + samples: GenerativeMetricsSummary = Field(description="") + seconds: GenerativeMetricsSummary = Field(description="") + bytes: GenerativeMetricsSummary = Field(description="") + + @classmethod + def compile( + cls, + request_types: list[Literal["successful", "incomplete", "error"]], + request_times: list[tuple[float, float]], + input_metrics: list[UsageMetrics], + output_metrics: list[UsageMetrics], + ) -> GenerativeAudioMetricsSummary: + return GenerativeAudioMetricsSummary( + tokens=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.audio_tokens or 0 for metrics in input_metrics], + output_values=[metrics.audio_tokens or 0 for metrics in output_metrics], + ), + samples=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + 
input_values=[metrics.audio_samples or 0 for metrics in input_metrics], + output_values=[ + metrics.audio_samples or 0 for metrics in output_metrics + ], + ), + seconds=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.audio_seconds or 0 for metrics in input_metrics], + output_values=[ + metrics.audio_seconds or 0 for metrics in output_metrics + ], + ), + bytes=GenerativeMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_values=[metrics.audio_bytes or 0 for metrics in input_metrics], + output_values=[metrics.audio_bytes or 0 for metrics in output_metrics], + ), + ) + + +class GenerativeMetrics(StandardBaseDict): + """Comprehensive metrics for generative AI benchmarks.""" + + # Request stats + requests_per_second: StatusDistributionSummary = Field( + description="Distribution of requests per second across benchmark execution" + ) + request_concurrency: StatusDistributionSummary = Field( + description="Distribution of concurrent request counts during execution" + ) + request_latency: StatusDistributionSummary = Field( + description="Distribution of request latencies for completed requests" + ) + + # General token stats + prompt_token_count: StatusDistributionSummary = Field( + description="Distribution of prompt token counts by request status" + ) + output_token_count: StatusDistributionSummary = Field( + description="Distribution of output token counts by request status" + ) + total_token_count: StatusDistributionSummary = Field( + description="Distribution of total token counts by request status" + ) + time_to_first_token_ms: StatusDistributionSummary = Field( + description="Distribution of first token latencies in milliseconds" + ) + time_per_output_token_ms: StatusDistributionSummary = Field( + description="Distribution of average time per output token in milliseconds" + ) + inter_token_latency_ms: StatusDistributionSummary = Field( + description="Distribution of inter-token latencies in milliseconds" + ) + output_tokens_per_second: StatusDistributionSummary = Field( + description="Distribution of output token generation rates" + ) + tokens_per_second: StatusDistributionSummary = Field( + description="Distribution of total token throughput including prompt and output" + ) + + # Domain specific stats + text: GenerativeTextMetricsSummary = Field(description="") + image: GenerativeImageMetricsSummary = Field(description="") + video: GenerativeVideoMetricsSummary = Field(description="") + audio: GenerativeAudioMetricsSummary = Field(description="") + + @classmethod + def update_estimate( + cls, + state: EstimatedBenchmarkState, + response: GenerationResponse | None, + request: GenerationRequest, + request_info: RequestInfo, + scheduler_state: SchedulerState, + ): + # Always track concurrency + state.add_time_averaged_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="concurrency_requests", + value=scheduler_state.processing_requests, + ) + + if request_info.status not in {"completed", "errored", "cancelled"}: + return + + state.set_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="updated", + value=True, + ) + start_time = scheduler_state.start_time + end_time = request_info.timings.request_end or request_info.timings.resolve_end + duration = end_time - start_time if end_time else None + + for prefix in (request_info.status, "total"): + requests_count = ( + scheduler_state.successful_requests + if prefix == "completed" + else 
scheduler_state.errored_requests + if prefix == "errored" + else scheduler_state.cancelled_requests + if prefix == "cancelled" + else scheduler_state.processed_requests + ) + + # Request stats + state.set_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key=f"{prefix}_requests", + value=requests_count, + ) + state.set_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key=f"{prefix}_requests_per_second", + value=requests_count / duration if duration else None, + ) + state.add_avg_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key=f"{prefix}_request_latency", + value=( + request_info.timings.request_end or request_info.timings.resolve_end + ), + start_val=( + request_info.timings.request_start + or request_info.timings.resolve_start + ), + ) + + # Input/output token stats + state.add_avg_rate_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="input_tokens", + value=(response.input_metrics.total_tokens if response else None) + or request.input_metrics.total_tokens, + ) + state.add_avg_rate_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="input_text_tokens", + value=(response.input_metrics.text_tokens if response else None) + or request.input_metrics.text_tokens, + ) + state.add_avg_rate_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="input_images", + value=(response.input_metrics.image_count if response else None) + or request.input_metrics.image_count, + ) + state.add_avg_rate_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="input_video_frames", + value=(response.input_metrics.video_frames if response else None) + or request.input_metrics.video_frames, + ) + state.add_avg_rate_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="input_audio_seconds", + value=request.input_metrics.audio_seconds if request else None, + ) + state.add_avg_rate_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="output_tokens", + value=(response.output_metrics.total_tokens if response else None) + or request.output_metrics.total_tokens, + ) + output_tokens = ( + response.output_metrics.total_tokens if response else None + ) or request.output_metrics.total_tokens + state.add_avg_rate_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key="total_tokens", + value=output_tokens, + ) + + # General stats + state.add_avg_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key=f"{prefix}_time_to_first_token", + value=request_info.timings.first_iteration, + start_val=request_info.timings.request_start + or request_info.timings.resolve_start, + ) + state.add_avg_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key=f"{prefix}_inter_token_latency", + value=request_info.timings.last_iteration, + start_val=request_info.timings.first_iteration, + count=output_tokens - 1 + if output_tokens and output_tokens > 1 + else None, + ) + state.add_avg_metric( + group=EstimatedBenchmarkState.benchmark_metrics_group, + key=f"{prefix}_time_per_output_token", + value=( + request_info.timings.request_end or request_info.timings.resolve_end + ), + start_val=( + request_info.timings.first_iteration + or request_info.timings.request_start + or request_info.timings.resolve_start + ), + count=output_tokens, + ) + + @classmethod + def compile( + cls, + completed: list[GenerativeRequestStats], + errored: list[GenerativeRequestStats], + incomplete: list[GenerativeRequestStats], + ) -> GenerativeMetrics: + requests = 
completed + errored + incomplete + request_types = cast( + "list[Literal['successful', 'error', 'incomplete']]", + ["successful"] * len(completed) + + ["error"] * len(errored) + + ["incomplete"] * len(incomplete), + ) + request_times = [ + ( + req.info.timings.request_start or req.info.timings.resolve_start or 0, + req.info.timings.request_end or req.info.timings.resolve_end or 0, + ) + for req in requests + ] + input_metrics = [req.input_metrics for req in requests] + output_metrics = [req.output_metrics for req in requests] + + return GenerativeMetrics( + # Request stats + requests_per_second=StatusDistributionSummary.from_request_times( + request_types=request_types, + requests=request_times, + distribution_type="rate", + ), + request_concurrency=StatusDistributionSummary.from_request_times( + request_types=request_types, + requests=request_times, + distribution_type="concurrency", + ), + request_latency=StatusDistributionSummary.from_values( + value_types=request_types, + values=[req.request_latency or 0.0 for req in requests], + ), + # General token stats + prompt_token_count=StatusDistributionSummary.from_values( + value_types=request_types, + values=[float(req.prompt_tokens or 0) for req in requests], + ), + output_token_count=StatusDistributionSummary.from_values( + value_types=request_types, + values=[float(req.output_tokens or 0) for req in requests], + ), + total_token_count=StatusDistributionSummary.from_values( + value_types=request_types, + values=[float(req.total_tokens or 0) for req in requests], + ), + time_to_first_token_ms=StatusDistributionSummary.from_values( + value_types=request_types, + values=[req.time_to_first_token_ms or 0.0 for req in requests], + ), + time_per_output_token_ms=StatusDistributionSummary.from_values( + value_types=request_types, + values=[req.time_per_output_token_ms or 0.0 for req in requests], + ), + inter_token_latency_ms=StatusDistributionSummary.from_values( + value_types=request_types, + values=[req.inter_token_latency_ms or 0.0 for req in requests], + ), + output_tokens_per_second=StatusDistributionSummary.from_values( + value_types=request_types, + values=[req.output_tokens_per_second or 0.0 for req in requests], + ), + tokens_per_second=StatusDistributionSummary.from_values( + value_types=request_types, + values=[req.tokens_per_second or 0.0 for req in requests], + ), + # Domain-specific stats + text=GenerativeTextMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_metrics=input_metrics, + output_metrics=output_metrics, + ), + image=GenerativeImageMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_metrics=input_metrics, + output_metrics=output_metrics, + ), + video=GenerativeVideoMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_metrics=input_metrics, + output_metrics=output_metrics, + ), + audio=GenerativeAudioMetricsSummary.compile( + request_types=request_types, + request_times=request_times, + input_metrics=input_metrics, + output_metrics=output_metrics, + ), + ) + + +class SchedulerDict(StandardBaseDict): + """Scheduler configuration and execution state dictionary.""" + + strategy: SchedulingStrategy + constraints: dict[str, dict[str, Any]] + state: SchedulerState + + +class BenchmarkerDict(StandardBaseDict): + """Benchmarker configuration and component settings dictionary.""" + + args: BenchmarkArgs + profile: Profile + requests: dict[str, Any] + backend: dict[str, Any] + environment: dict[str, Any] + + 
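Every block in GenerativeMetrics.compile follows the same recipe: line up one status label and one (start, end) tuple per request, then hand scalar per-request values to StatusDistributionSummary.from_values and the timing tuples to from_request_times with either a "rate" or "concurrency" distribution type. A small sketch of that alignment with made-up timings; the import path is an assumption, since the import block is not shown in this hunk.

from guidellm.utils import StatusDistributionSummary  # assumed import location

request_types = ["successful", "successful", "error"]
request_times = [(0.0, 1.2), (0.5, 1.9), (1.0, 1.1)]  # (start, end) per request, seconds
output_tokens = [128.0, 96.0, 0.0]                    # one scalar per request

latency = StatusDistributionSummary.from_values(
    value_types=request_types,
    values=[end - start for start, end in request_times],
)
output_per_second = StatusDistributionSummary.from_request_times(
    request_types=request_types,
    requests=request_times,
    distribution_type="rate",
    weights=output_tokens,  # weighted rate: tokens/sec rather than requests/sec
)
concurrency = StatusDistributionSummary.from_request_times(
    request_types=request_types,
    requests=request_times,
    distribution_type="concurrency",
)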
+class GenerativeBenchmark(Benchmark, StandardBaseDict): + """Complete generative AI benchmark results with specialized metrics.""" + + group_name: ClassVar[Literal["generative_benchmark"]] = "generative_benchmark" + + type_: Literal["generative_benchmark"] = "generative_benchmark" # type: ignore[assignment] + id_: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Unique identifier for this benchmark execution", + ) + run_id: str = Field( + description="Identifier for the benchmarker run containing this benchmark" + ) + run_index: int = Field( + description="Sequential index of this benchmark within the benchmarker run" + ) + scheduler: SchedulerDict = Field( + description="Scheduler configuration and execution state" + ) + benchmarker: BenchmarkerDict = Field( + description="Benchmarker configuration and component settings" + ) + run_stats: BenchmarkSchedulerStats = Field( + description="Scheduler timing and performance statistics" + ) + start_time: float = Field( + default=-1.0, description="Unix timestamp when the first request was initiated" + ) + end_time: float = Field( + default=-1.0, description="Unix timestamp when the last request completed" + ) + + def get_run_metrics_sample( + self, + ) -> dict[Literal["start_time", "end_time", "duration"], float]: + return { + "start_time": self.start_time, + "end_time": self.end_time, + "duration": self.duration, + } + + def get_request_metrics_sample( + self, + ) -> dict[ + Literal[ + "request_count", + "request_latency", + "request_throughput", + "request_concurrency", + ], + float, + ]: + return { + "request_count": self.request_totals.successful, + "request_latency": self.metrics.request_latency.successful.mean, + "request_throughput": self.metrics.requests_per_second.successful.mean, + "request_concurrency": self.metrics.request_concurrency.successful.mean, + } + + @computed_field # type: ignore[misc] + @property + def duration(self) -> float: + """ + Benchmark execution duration in seconds. + + :return: Time elapsed from first request start to last request completion. 
+ """ + return self.end_time - self.start_time + + metrics: GenerativeMetrics = Field( + description="Performance metrics and statistical distributions" + ) + request_totals: StatusBreakdown[int, int, int, int] = Field( + description="Request counts by status: successful, incomplete, errored, total" + ) + requests: StatusBreakdown[ + list[GenerativeRequestStats], + list[GenerativeRequestStats], + list[GenerativeRequestStats], + None, + ] = Field( + description="Request details grouped by status: successful, incomplete, errored" + ) + + @classmethod + def update_estimate( + cls, + args: BenchmarkArgs, + state: EstimatedBenchmarkState, + response: GenerationResponse | None, + request: GenerationRequest, + request_info: RequestInfo, + scheduler_state: SchedulerState, + ): + if ( + request_info.status == "cancelled" + and request_info.timings.resolve_start is None + ): + # Cancelled requests that never started should be ignored + return + + # Update child metric groups + BenchmarkSchedulerStats.update_estimate(state, request_info) + GenerativeMetrics.update_estimate( + state, response, request, request_info, scheduler_state + ) + + # Store requests and sampling info, update counts + if "requests_completed" not in state: + state["requests_completed"] = [] + state["samples_completed"] = [] + state["requests_errored"] = [] + state["samples_errored"] = [] + state["requests_incomplete"] = [] + state["samples_incomplete"] = [] + in_warmup = state.set_metric( + group=EstimatedBenchmarkState.benchmark_state_group, + key="in_warmup", + value=args.is_in_warmup(request_info, scheduler_state), + ) + in_cooldown = state.set_metric( + group=EstimatedBenchmarkState.benchmark_state_group, + key="in_cooldown", + value=args.is_in_cooldown(request_info, scheduler_state), + ) + state[f"{EstimatedBenchmarkState.benchmark_state_group}_status"] = ( + "in_cooldown" + if in_cooldown + else "in_warmup" + if in_warmup + else "in_progress" + ) + + if ( + request_info.status not in {"completed", "errored", "cancelled"} + or in_warmup + or in_cooldown + ): + # Must be fully resolved to be added + return + + state.set_metric( + group=EstimatedBenchmarkState.benchmark_state_group, + key="updated", + value=True, + ) + + if response is None: + response = GenerationResponse( + request_id=request.request_id, request_args=str(request.arguments) + ) + + stats = response.compile_stats( + request, request_info, args.prefer_response_metrics + ) + + # Determine status and get corresponding lists + if request_info.status == "completed": + requests_list = state["requests_completed"] + samples_list = state["samples_completed"] + elif request_info.status == "errored": + requests_list = state["requests_errored"] + samples_list = state["samples_errored"] + else: # cancelled (incomplete) + requests_list = state["requests_incomplete"] + samples_list = state["samples_incomplete"] + + # Add to requests list + requests_list.append(stats) + current_index = len(requests_list) - 1 + + # Handle request sampling logic + if args.sample_requests is None: + # No sampling, add index to samples list + samples_list.append(current_index) + elif args.sample_requests > 0 and len(samples_list) < args.sample_requests: + # Space in samples list, add index + samples_list.append(current_index) + elif ( + args.sample_requests > 0 + and (replace_index := random.randrange(len(requests_list))) + < args.sample_requests + ): + # No space, adding based on reservoir sampling + samples_list[replace_index] = current_index + # Sampling set to 0, don't keep any requests + + 
@classmethod + def compile( + cls, + args: BenchmarkArgs, + estimated_state: EstimatedBenchmarkState, + scheduler_state: SchedulerState, + profile: Profile, + requests: Iterable, + backend: BackendInterface, + environment: Environment, + strategy: SchedulingStrategy, + constraints: dict[str, dict[str, Any]], + ) -> GenerativeBenchmark: + return GenerativeBenchmark( + run_id=args.run_id, + run_index=args.run_index, + scheduler=SchedulerDict( + strategy=strategy, + constraints={ + key: InfoMixin.extract_from_obj(val) + for key, val in constraints.items() + }, + state=scheduler_state, + ), + benchmarker=BenchmarkerDict( + args=args, + profile=profile, + requests=InfoMixin.extract_from_obj(requests), + backend=backend.info, + environment=environment.info, + ), + run_stats=BenchmarkSchedulerStats.compile(estimated_state, scheduler_state), + start_time=scheduler_state.start_time or -1.0, + end_time=scheduler_state.end_time or -1.0, + metrics=GenerativeMetrics.compile( + completed=estimated_state.get("requests_completed", []), + errored=estimated_state.get("requests_errored", []), + incomplete=estimated_state.get("requests_incomplete", []), + ), + request_totals=StatusBreakdown[int, int, int, int]( + successful=len(estimated_state.get("requests_completed", [])), + incomplete=len(estimated_state.get("requests_incomplete", [])), + errored=len(estimated_state.get("requests_errored", [])), + total=( + len(estimated_state.get("requests_completed", [])) + + len(estimated_state.get("requests_incomplete", [])) + + len(estimated_state.get("requests_errored", [])) + ), + ), + requests=StatusBreakdown[ + list[GenerativeRequestStats], + list[GenerativeRequestStats], + list[GenerativeRequestStats], + None, + ]( + successful=estimated_state.get("requests_completed", []), + incomplete=estimated_state.get("requests_incomplete", []), + errored=estimated_state.get("requests_errored", []), + total=None, + ), + ) + + +class GenerativeBenchmarksReport(StandardBaseModel): + """Container for multiple benchmark results with load/save functionality.""" + + DEFAULT_FILE: ClassVar[str] = "benchmarks.json" + + @staticmethod + def load_file( + path: str | Path, type_: Literal["json", "yaml"] | None = None + ) -> GenerativeBenchmarksReport: + """ + Load a report from a file. + + :param path: The path to load the report from. + :param type_: File type override, auto-detected from extension if None. + :return: The loaded report. + :raises ValueError: If file type is unsupported. + """ + path = Path(path) if not isinstance(path, Path) else path + + if path.is_dir(): + path = path / GenerativeBenchmarksReport.DEFAULT_FILE + + path.parent.mkdir(parents=True, exist_ok=True) + path_suffix = path.suffix.lower()[1:] + + with path.open("r") as file: + if (type_ or path_suffix) == "json": + model_dict = json.loads(file.read()) + elif (type_ or path_suffix) in ["yaml", "yml"]: + model_dict = yaml.safe_load(file) + else: + raise ValueError(f"Unsupported file type: {type_} for {path}.") + + return GenerativeBenchmarksReport.model_validate(model_dict) + + benchmarks: list[GenerativeBenchmark] = Field( + description="The list of completed benchmarks contained within the report.", + default_factory=list, + ) + + def save_file( + self, path: str | Path | None, type_: Literal["json", "yaml"] | None = None + ) -> Path: + """ + Save the report to a file. + + :param path: The path to save the report to. + :param type_: File type override, auto-detected from extension if None. + :return: The path to the saved report. 
+ :raises ValueError: If file type is unsupported. + """ + if path is None: + path = Path.cwd() + elif not isinstance(path, Path): + path = Path(path) + + if path.is_dir(): + path = path / GenerativeBenchmarksReport.DEFAULT_FILE + + path.parent.mkdir(parents=True, exist_ok=True) + path_suffix = path.suffix.lower()[1:] + model_dict = self.model_dump() + + if (type_ or path_suffix) == "json": + save_str = json.dumps(model_dict) + elif (type_ or path_suffix) in ["yaml", "yml"]: + save_str = yaml.dump(model_dict) + else: + raise ValueError(f"Unsupported file type: {type_} for {path}.") + + with path.open("w") as file: + file.write(save_str) + + return path diff --git a/src/guidellm/benchmark/types.py b/src/guidellm/benchmark/types.py index 1ef65a68..983e3189 100644 --- a/src/guidellm/benchmark/types.py +++ b/src/guidellm/benchmark/types.py @@ -1,44 +1,15 @@ from __future__ import annotations -from collections.abc import Iterable from pathlib import Path from typing import Any -from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict -from transformers import ( # type: ignore[import] - PreTrainedTokenizerBase, -) +from transformers import PreTrainedTokenizerBase # type: ignore[import] from typing_extensions import TypeAliasType -from guidellm.benchmark.aggregator import ( - Aggregator, - CompilableAggregator, -) -from guidellm.benchmark.output import ( - GenerativeBenchmarkerOutput, -) -from guidellm.benchmark.progress import BenchmarkerProgress - -__all__ = [ - "AggregatorInputT", - "DataInputT", - "OutputFormatT", - "ProcessorInputT", - "ProgressInputT", -] +from guidellm.benchmark.output import GenerativeBenchmarkerOutput +__all__ = ["OutputFormatT", "ProcessorInputT"] -DataInputT = TypeAliasType( - "DataInputT", - Iterable[str] - | Iterable[dict[str, Any]] - | Dataset - | DatasetDict - | IterableDataset - | IterableDatasetDict - | str - | Path, -) OutputFormatT = TypeAliasType( "OutputFormatT", @@ -49,12 +20,3 @@ ) ProcessorInputT = TypeAliasType("ProcessorInputT", str | Path | PreTrainedTokenizerBase) - -ProgressInputT = TypeAliasType( - "ProgressInputT", tuple[str, ...] 
| list[str] | list[BenchmarkerProgress] -) - -AggregatorInputT = TypeAliasType( - "AggregatorInputT", - dict[str, str | dict[str, Any] | Aggregator | CompilableAggregator], -) diff --git a/src/guidellm/data/__init__.py b/src/guidellm/data/__init__.py index 8a48204e..0bff1b64 100644 --- a/src/guidellm/data/__init__.py +++ b/src/guidellm/data/__init__.py @@ -1,4 +1,28 @@ -""" -Required for python < 3.12 -https://docs.python.org/3/library/importlib.resources.html#importlib.resources.files -""" +from .collators import GenerativeRequestCollator +from .deserializers import ( + DataNotSupportedError, + DatasetDeserializer, + DatasetDeserializerFactory, +) +from .loaders import DataLoader, DatasetsIterator +from .preprocessors import ( + DataDependentPreprocessor, + DatasetPreprocessor, + PreprocessorRegistry, +) +from .processor import ProcessorFactory +from .schemas import GenerativeDatasetColumnType + +__all__ = [ + "DataDependentPreprocessor", + "DataLoader", + "DataNotSupportedError", + "DatasetDeserializer", + "DatasetDeserializerFactory", + "DatasetPreprocessor", + "DatasetsIterator", + "GenerativeDatasetColumnType", + "GenerativeRequestCollator", + "PreprocessorRegistry", + "ProcessorFactory", +] diff --git a/src/guidellm/data/collators.py b/src/guidellm/data/collators.py new file mode 100644 index 00000000..f9e1ade4 --- /dev/null +++ b/src/guidellm/data/collators.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from guidellm.schemas import GenerationRequest + +__all__ = ["GenerativeRequestCollator"] + + +class GenerativeRequestCollator: + def __call__(self, batch: list) -> GenerationRequest: + if len(batch) != 1: + raise NotImplementedError( + f"Batch size greater than 1 is not currently supported. " + f"Got batch size: {len(batch)}" + ) + + return batch[0] diff --git a/src/guidellm/data/deserializers/__init__.py b/src/guidellm/data/deserializers/__init__.py new file mode 100644 index 00000000..1062f2b7 --- /dev/null +++ b/src/guidellm/data/deserializers/__init__.py @@ -0,0 +1,53 @@ +from .deserializer import ( + DataNotSupportedError, + DatasetDeserializer, + DatasetDeserializerFactory, +) +from .file import ( + ArrowFileDatasetDeserializer, + CSVFileDatasetDeserializer, + DBFileDatasetDeserializer, + HDF5FileDatasetDeserializer, + JSONFileDatasetDeserializer, + ParquetFileDatasetDeserializer, + TarFileDatasetDeserializer, + TextFileDatasetDeserializer, +) +from .huggingface import HuggingFaceDatasetDeserializer +from .memory import ( + InMemoryCsvDatasetDeserializer, + InMemoryDictDatasetDeserializer, + InMemoryDictListDatasetDeserializer, + InMemoryItemListDatasetDeserializer, + InMemoryJsonStrDatasetDeserializer, +) +from .synthetic import ( + SyntheticTextDatasetConfig, + SyntheticTextDatasetDeserializer, + SyntheticTextGenerator, + SyntheticTextPrefixBucketConfig, +) + +__all__ = [ + "ArrowFileDatasetDeserializer", + "CSVFileDatasetDeserializer", + "DBFileDatasetDeserializer", + "DataNotSupportedError", + "DatasetDeserializer", + "DatasetDeserializerFactory", + "HDF5FileDatasetDeserializer", + "HuggingFaceDatasetDeserializer", + "InMemoryCsvDatasetDeserializer", + "InMemoryDictDatasetDeserializer", + "InMemoryDictListDatasetDeserializer", + "InMemoryItemListDatasetDeserializer", + "InMemoryJsonStrDatasetDeserializer", + "JSONFileDatasetDeserializer", + "ParquetFileDatasetDeserializer", + "SyntheticTextDatasetConfig", + "SyntheticTextDatasetDeserializer", + "SyntheticTextGenerator", + "SyntheticTextPrefixBucketConfig", + "TarFileDatasetDeserializer", + 
"TextFileDatasetDeserializer", +] diff --git a/src/guidellm/data/deserializers/deserializer.py b/src/guidellm/data/deserializers/deserializer.py new file mode 100644 index 00000000..cb362710 --- /dev/null +++ b/src/guidellm/data/deserializers/deserializer.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import contextlib +from collections.abc import Callable +from typing import Any, Protocol, Union, runtime_checkable + +from datasets import Dataset, IterableDataset +from transformers import PreTrainedTokenizerBase + +from guidellm.data.utils import resolve_dataset_split +from guidellm.utils import RegistryMixin + +__all__ = [ + "DataNotSupportedError", + "DatasetDeserializer", + "DatasetDeserializerFactory", +] + + +class DataNotSupportedError(Exception): + """Exception raised when data format is not supported by deserializer.""" + + +@runtime_checkable +class DatasetDeserializer(Protocol): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: ... + + +class DatasetDeserializerFactory( + RegistryMixin[Union["type[DatasetDeserializer]", DatasetDeserializer]], +): + @classmethod + def deserialize( + cls, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int = 42, + type_: str | None = None, + resolve_split: bool = True, + select_columns: list[str] | None = None, + remove_columns: list[str] | None = None, + **data_kwargs: dict[str, Any], + ) -> Dataset | IterableDataset: + dataset = None + + if type_ is None: + for deserializer in cls.registered_objects(): + deserializer_fn: DatasetDeserializer = ( + deserializer() if isinstance(deserializer, type) else deserializer + ) + + with contextlib.suppress(DataNotSupportedError): + dataset = deserializer_fn( + data=data, + processor_factory=processor_factory, + random_seed=random_seed, + **data_kwargs, + ) + elif deserializer := cls.get_registered_object(type_) is not None: + deserializer_fn: DatasetDeserializer = ( + deserializer() if isinstance(deserializer, type) else deserializer + ) + + dataset = deserializer_fn( + data=data, + processor_factory=processor_factory, + random_seed=random_seed, + **data_kwargs, + ) + + if dataset is None: + raise DataNotSupportedError( + f"No suitable deserializer found for data {data} " + f"with kwargs {data_kwargs} and type_ {type_}." 
+ ) + + if resolve_split: + dataset = resolve_dataset_split(dataset) + + if select_columns is not None or remove_columns is not None: + column_names = dataset.column_names or list(next(iter(dataset)).keys()) + if select_columns is not None: + remove_columns = [ + col for col in column_names if col not in select_columns + ] + + dataset = dataset.remove_columns(remove_columns) + + return dataset diff --git a/src/guidellm/data/deserializers/file.py b/src/guidellm/data/deserializers/file.py new file mode 100644 index 00000000..d57403db --- /dev/null +++ b/src/guidellm/data/deserializers/file.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import pandas as pd +from datasets import Dataset, load_dataset +from transformers import PreTrainedTokenizerBase + +from guidellm.data.deserializers.deserializer import ( + DataNotSupportedError, + DatasetDeserializer, + DatasetDeserializerFactory, +) + +__all__ = [ + "ArrowFileDatasetDeserializer", + "CSVFileDatasetDeserializer", + "DBFileDatasetDeserializer", + "HDF5FileDatasetDeserializer", + "JSONFileDatasetDeserializer", + "ParquetFileDatasetDeserializer", + "TarFileDatasetDeserializer", + "TextFileDatasetDeserializer", +] + + +@DatasetDeserializerFactory.register("text_file") +class TextFileDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) # Ignore unused args format errors + + if ( + not isinstance(data, (str, Path)) + or not (path := Path(data)).exists() + or not path.is_file() + or path.suffix.lower() not in {".txt", ".text"} + ): + raise DataNotSupportedError( + "Unsupported data for TextFileDatasetDeserializer, " + f"expected str or Path to a local .txt or .text file, got {data}" + ) + + with path.open() as file: + lines = file.readlines() + + return Dataset.from_dict({"text": lines}, **data_kwargs) + + +@DatasetDeserializerFactory.register("csv_file") +class CSVFileDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) + if ( + not isinstance(data, (str, Path)) + or not (path := Path(data)).exists() + or not path.is_file() + or path.suffix.lower() != ".csv" + ): + raise DataNotSupportedError( + "Unsupported data for CSVFileDatasetDeserializer, " + f"expected str or Path to a local .csv file, got {data}" + ) + + return load_dataset("csv", data_files=str(path), **data_kwargs) + + +@DatasetDeserializerFactory.register("json_file") +class JSONFileDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) + if ( + not isinstance(data, (str, Path)) + or not (path := Path(data)).exists() + or not path.is_file() + or path.suffix.lower() not in {".json", ".jsonl"} + ): + raise DataNotSupportedError( + f"Unsupported data for JSONFileDatasetDeserializer, " + f"expected str or Path to a local .json or .jsonl file, got {data}" + ) + + return load_dataset("json", data_files=str(path), **data_kwargs) + + +@DatasetDeserializerFactory.register("parquet_file") +class 
ParquetFileDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) + if ( + not isinstance(data, (str, Path)) + or not (path := Path(data)).exists() + or not path.is_file() + or path.suffix.lower() != ".parquet" + ): + raise DataNotSupportedError( + f"Unsupported data for ParquetFileDatasetDeserializer, " + f"expected str or Path to a local .parquet file, got {data}" + ) + + return load_dataset("parquet", data_files=str(path), **data_kwargs) + + +@DatasetDeserializerFactory.register("arrow_file") +class ArrowFileDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) + if ( + not isinstance(data, (str, Path)) + or not (path := Path(data)).exists() + or not path.is_file() + or path.suffix.lower() != ".arrow" + ): + raise DataNotSupportedError( + f"Unsupported data for ArrowFileDatasetDeserializer, " + f"expected str or Path to a local .arrow file, got {data}" + ) + + return load_dataset("arrow", data_files=str(path), **data_kwargs) + + +@DatasetDeserializerFactory.register("hdf5_file") +class HDF5FileDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) + if ( + not isinstance(data, (str, Path)) + or not (path := Path(data)).exists() + or not path.is_file() + or path.suffix.lower() not in {".hdf5", ".h5"} + ): + raise DataNotSupportedError( + f"Unsupported data for HDF5FileDatasetDeserializer, " + f"expected str or Path to a local .hdf5 or .h5 file, got {data}" + ) + + return Dataset.from_pandas(pd.read_hdf(str(path)), **data_kwargs) + + +@DatasetDeserializerFactory.register("db_file") +class DBFileDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) + if ( + not isinstance(data, (str, Path)) + or not (path := Path(data)).exists() + or not path.is_file() + or path.suffix.lower() != ".db" + ): + raise DataNotSupportedError( + f"Unsupported data for DBFileDatasetDeserializer, " + f"expected str or Path to a local .db file, got {data}" + ) + + return Dataset.from_sql(con=str(path), **data_kwargs) + + +@DatasetDeserializerFactory.register("tar_file") +class TarFileDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) + if ( + not isinstance(data, (str, Path)) + or not (path := Path(data)).exists() + or not path.is_file() + or path.suffix.lower() != ".tar" + ): + raise DataNotSupportedError( + f"Unsupported data for TarFileDatasetDeserializer, " + f"expected str or Path to a local .tar file, got {data}" + ) + + return load_dataset("webdataset", data_files=str(path), **data_kwargs) diff --git a/src/guidellm/data/deserializers/huggingface.py b/src/guidellm/data/deserializers/huggingface.py new file mode 100644 
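Each file deserializer above follows the same contract: check that data is a local path with the expected suffix, raise DataNotSupportedError otherwise so the factory can keep probing other deserializers, and return a datasets object. A hypothetical extension for tab-separated files written against the same Protocol; the class name and registry key below are invented for illustration and do not exist in this diff.

from __future__ import annotations

from collections.abc import Callable
from pathlib import Path
from typing import Any

from datasets import load_dataset
from transformers import PreTrainedTokenizerBase

from guidellm.data.deserializers import (
    DataNotSupportedError,
    DatasetDeserializer,
    DatasetDeserializerFactory,
)


@DatasetDeserializerFactory.register("tsv_file")  # hypothetical registry key
class TSVFileDatasetDeserializer(DatasetDeserializer):
    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ):
        _ = (processor_factory, random_seed)  # unused, kept for the Protocol signature
        if (
            not isinstance(data, (str, Path))
            or not (path := Path(data)).exists()
            or not path.is_file()
            or path.suffix.lower() != ".tsv"
        ):
            raise DataNotSupportedError(
                f"Expected str or Path to a local .tsv file, got {data}"
            )
        # Tab-separated files load through the csv builder with a custom delimiter.
        return load_dataset("csv", data_files=str(path), delimiter="\t", **data_kwargs)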
index 00000000..e356043a --- /dev/null +++ b/src/guidellm/data/deserializers/huggingface.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from collections.abc import Callable +from pathlib import Path +from typing import Any + +from datasets import ( + Dataset, + DatasetDict, + IterableDataset, + IterableDatasetDict, + load_dataset, + load_from_disk, +) +from transformers import PreTrainedTokenizerBase + +from guidellm.data.deserializers.deserializer import ( + DataNotSupportedError, + DatasetDeserializer, + DatasetDeserializerFactory, +) + +__all__ = ["HuggingFaceDatasetDeserializer"] + + +@DatasetDeserializerFactory.register("huggingface") +class HuggingFaceDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) + + if isinstance( + data, (Dataset, IterableDataset, DatasetDict, IterableDatasetDict) + ): + return data + + load_error = None + + if ( + isinstance(data, (str, Path)) + and (path := Path(data)).exists() + and ((path.is_file() and path.suffix == ".py") or path.is_dir()) + ): + # Handle python script or nested python script in a directory + try: + return load_dataset(str(data), **data_kwargs) + except Exception as err: # noqa: BLE001 + load_error = err + + if ( + isinstance(data, (str, Path)) + and (path := Path(data)).exists() + and path.is_dir() + ): + # Handle local dataset directory + try: + return load_from_disk(str(data), **data_kwargs) + except Exception as err: # noqa: BLE001 + load_error = err + + try: + # Handle dataset identifier from the Hugging Face Hub + return load_dataset(str(data), **data_kwargs) + except Exception as err: # noqa: BLE001 + load_error = err + + not_supported = DataNotSupportedError( + "Unsupported data for HuggingFaceDatasetDeserializer, " + "expected Dataset, IterableDataset, DatasetDict, IterableDatasetDict, " + "str or Path to a local dataset directory or a local .py dataset script, " + f"got {data} and HF load error: {load_error}" + ) + + if load_error is not None: + raise not_supported from load_error + else: + raise not_supported diff --git a/src/guidellm/data/deserializers/memory.py b/src/guidellm/data/deserializers/memory.py new file mode 100644 index 00000000..6f8888ec --- /dev/null +++ b/src/guidellm/data/deserializers/memory.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +import contextlib +import csv +import json +from collections.abc import Callable +from io import StringIO +from typing import Any, cast + +from datasets import Dataset +from transformers import PreTrainedTokenizerBase + +from guidellm.data.deserializers.deserializer import ( + DataNotSupportedError, + DatasetDeserializer, + DatasetDeserializerFactory, +) + +__all__ = [ + "InMemoryCsvDatasetDeserializer", + "InMemoryDictDatasetDeserializer", + "InMemoryDictListDatasetDeserializer", + "InMemoryItemListDatasetDeserializer", + "InMemoryJsonStrDatasetDeserializer", +] + + +@DatasetDeserializerFactory.register("in_memory_dict") +class InMemoryDictDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) # Ignore unused args format errors + + if ( + not data + or not isinstance(data, dict) + or not all( + isinstance(key, str) and isinstance(val, list) + for key, 
val in data.items() + ) + ): + raise DataNotSupportedError( + f"Unsupported data for InMemoryDictDatasetDeserializer, " + f"expected dict[str, list], got {data}" + ) + + rows = len(list(data.values())[0]) + if not all(len(val) == rows for val in data.values()): + raise DataNotSupportedError( + "All lists in the data dictionary must have the same length, " + f"expected {rows} for all keys {list(data.keys())}" + ) + + return Dataset.from_dict(data, **data_kwargs) + + +@DatasetDeserializerFactory.register("in_memory_dict_list") +class InMemoryDictListDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) # Ignore unused args format errors + + if ( + not data + or not isinstance(data, list) + or not all(isinstance(item, dict) for item in data) + or not all(isinstance(key, str) for item in data for key in item) + ): + raise DataNotSupportedError( + f"Unsupported data for InMemoryDictListDatasetDeserializer, " + f"expected list of dicts, got {data}" + ) + + data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data) + first_keys = set(data[0].keys()) + for index, item in enumerate(data): + if set(item.keys()) != first_keys: + raise DataNotSupportedError( + f"All dictionaries must have the same keys. " + f"Expected keys: {first_keys}, " + f"got keys at index {index}: {set(item.keys())}" + ) + + # Convert list of dicts to dict of lists + result_dict = {key: [] for key in first_keys} + for item in data: + for key, value in item.items(): + result_dict[key].append(value) + + return Dataset.from_dict(result_dict, **data_kwargs) + + +@DatasetDeserializerFactory.register("in_memory_item_list") +class InMemoryItemListDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + _ = (processor_factory, random_seed) # Ignore unused args format errors + + primitive_types = (str, int, float, bool, type(None)) + if ( + not data + or not isinstance(data, list) + or not all(isinstance(item, primitive_types) for item in data) + ): + raise DataNotSupportedError( + f"Unsupported data for InMemoryItemListDatasetDeserializer, " + f"expected list of primitive items, got {data}" + ) + + column_name = data_kwargs.pop("column_name", "data") + + return Dataset.from_dict({column_name: data}, **data_kwargs) + + +@DatasetDeserializerFactory.register("in_memory_json_str") +class InMemoryJsonStrDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + if ( + isinstance(data, str) + and (json_str := data.strip()) + and ( + (json_str.startswith("{") and json_str.endswith("}")) + or (json_str.startswith("[") and json_str.endswith("]")) + ) + ): + with contextlib.suppress(Exception): + parsed = json.loads(data) + + for deserializer in [ + InMemoryDictDatasetDeserializer, + InMemoryDictListDatasetDeserializer, + InMemoryItemListDatasetDeserializer, + ]: + with contextlib.suppress(DataNotSupportedError): + return deserializer()( + parsed, data_kwargs, processor_factory, random_seed + ) + + raise DataNotSupportedError( + f"Unsupported data for InMemoryJsonStrDatasetDeserializer, " + f"expected JSON string with a list or 
dict of items, got {data}" + ) + + +@DatasetDeserializerFactory.register("in_memory_csv_str") +class InMemoryCsvDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> dict[str, list]: + if ( + isinstance(data, str) + and (csv_str := data.strip()) + and len(csv_str.split("\n")) > 0 + ): + with contextlib.suppress(Exception): + csv_buffer = StringIO(data) + reader = csv.DictReader(csv_buffer) + rows = list(reader) + + return InMemoryDictListDatasetDeserializer()( + rows, processor_factory, random_seed, **data_kwargs + ) + + raise DataNotSupportedError( + f"Unsupported data for InMemoryCsvDatasetDeserializer, " + f"expected CSV string, got {type(data)}" + ) diff --git a/src/guidellm/data/deserializers/synthetic.py b/src/guidellm/data/deserializers/synthetic.py new file mode 100644 index 00000000..d9e415c6 --- /dev/null +++ b/src/guidellm/data/deserializers/synthetic.py @@ -0,0 +1,346 @@ +from __future__ import annotations + +import math +from collections.abc import Callable, Iterator +from pathlib import Path +from random import Random +from typing import Any + +import yaml +from datasets import Features, IterableDataset, Value +from faker import Faker +from pydantic import ConfigDict, Field, model_validator +from transformers import PreTrainedTokenizerBase + +from guidellm.data.deserializers.deserializer import ( + DataNotSupportedError, + DatasetDeserializer, + DatasetDeserializerFactory, +) +from guidellm.utils import IntegerRangeSampler, StandardBaseModel + +__all__ = [ + "SyntheticTextDatasetConfig", + "SyntheticTextDatasetDeserializer", + "SyntheticTextGenerator", + "SyntheticTextPrefixBucketConfig", +] + + +class SyntheticTextPrefixBucketConfig(StandardBaseModel): + bucket_weight: int = Field( + description="Weight of this bucket in the overall distribution.", + gt=0, + default=100, + ) + prefix_count: int = Field( + description="The number of unique prefixes to generate for this bucket.", + ge=1, + default=1, + ) + prefix_tokens: int = Field( + description="The number of prefix tokens per-prompt for this bucket.", + ge=0, + default=0, + ) + + +class SyntheticTextDatasetConfig(StandardBaseModel): + model_config = ConfigDict( + extra="allow", + ) + + prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field( + description="Buckets for the prefix tokens distribution.", + default=None, + ) + prompt_tokens: int = Field( + description="The average number of text tokens generated for prompts.", + gt=0, + ) + prompt_tokens_stdev: int | None = Field( + description="The standard deviation of the tokens generated for prompts.", + gt=0, + default=None, + ) + prompt_tokens_min: int | None = Field( + description="The minimum number of text tokens generated for prompts.", + gt=0, + default=None, + ) + prompt_tokens_max: int | None = Field( + description="The maximum number of text tokens generated for prompts.", + gt=0, + default=None, + ) + output_tokens: int = Field( + description="The average number of text tokens generated for outputs.", + gt=0, + ) + output_tokens_stdev: int | None = Field( + description="The standard deviation of the tokens generated for outputs.", + gt=0, + default=None, + ) + output_tokens_min: int | None = Field( + description="The minimum number of text tokens generated for outputs.", + gt=0, + default=None, + ) + output_tokens_max: int | None = Field( + description="The maximum number of text tokens generated for 
outputs.", + gt=0, + default=None, + ) + source: str = Field( + description="The source of the text data to be used for generation.", + default="data:prideandprejudice.txt.gz", + ) + + @model_validator(mode="after") + def check_prefix_options(self) -> SyntheticTextDatasetConfig: + prefix_count = self.__pydantic_extra__.get("prefix_count", None) # type: ignore[attr-defined] + prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None) # type: ignore[attr-defined] + if prefix_count is not None or prefix_tokens is not None: + if self.prefix_buckets: + raise ValueError( + "prefix_buckets is mutually exclusive" + " with prefix_count and prefix_tokens" + ) + + self.prefix_buckets = [ + SyntheticTextPrefixBucketConfig( + prefix_count=prefix_count or 1, + prefix_tokens=prefix_tokens or 0, + ) + ] + + return self + + +class SyntheticTextGenerator: + def __init__( + self, + config: SyntheticTextDatasetConfig, + processor: PreTrainedTokenizerBase, + random_seed: int = 42, + ): + self.config = config + self.processor = processor + self.random_seed = random_seed + + def __iter__(self) -> Iterator[dict[str, Any]]: + samples_generated = 0 + + faker = Faker() + faker.seed_instance(self.random_seed) + prompt_tokens_sampler = iter( + IntegerRangeSampler( + average=self.config.prompt_tokens, + variance=self.config.prompt_tokens_stdev, + min_value=self.config.prompt_tokens_min, + max_value=self.config.prompt_tokens_max, + random_seed=self.random_seed, + ) + ) + output_tokens_sampler = iter( + IntegerRangeSampler( + average=self.config.output_tokens, + variance=self.config.output_tokens_stdev, + min_value=self.config.output_tokens_min, + max_value=self.config.output_tokens_max, + random_seed=self.random_seed + 1, # ensure diff dist from prompts + ) + ) + + # Create a shared prefix if specified + rand = Random(self.random_seed + 3) + prefix_iter = self._create_prefix_iter(faker, rand) + + while True: + prompt_tokens_count = next(prompt_tokens_sampler) + output_tokens_count = next(output_tokens_sampler) + + yield { + "prefix": next(prefix_iter), + "prompt": self._create_prompt( + prompt_tokens_count, faker, f"{samples_generated} " + ), + "prompt_tokens_count": prompt_tokens_count, + "output_tokens_count": output_tokens_count, + } + samples_generated += 1 + + def _create_prompt( + self, prompt_tokens_count: int, faker: Faker, unique: str = "" + ) -> str: + prompt_token_ids = [] + avg_chars_per_token = 5 + margin_of_safety = 1.5 + attempts = 0 + + while len(prompt_token_ids) < prompt_tokens_count: + attempts += 1 + num_chars = ( + prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts + ) + text = unique + faker.text(max_nb_chars=num_chars) + prompt_token_ids = self.processor.encode(text) + + return self.processor.decode( + prompt_token_ids[:prompt_tokens_count], skip_special_tokens=True + ) + + def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]: + if not self.config.prefix_buckets: + while True: + yield "" + + # Increase weights to ensure an integer number of samples per per-prefix + least_common_prefix_count = math.lcm( + *(bucket.prefix_count for bucket in self.config.prefix_buckets) + ) + unnorm_weights = [ + least_common_prefix_count * bucket.bucket_weight // bucket.prefix_count + for bucket in self.config.prefix_buckets + ] + # Use GCD to reduce the weights to smallest integer ratio + common_divisor = math.gcd(*unnorm_weights) + + # Create prefix list maintaining the correct distribution + prefixes = [] + for bucket, weight in zip( + self.config.prefix_buckets, 
unnorm_weights, strict=False + ): + bucket_prefixes = [ + self._create_prompt(bucket.prefix_tokens, faker) + for _ in range(bucket.prefix_count) + ] + sample_count = weight // common_divisor + prefixes.extend(bucket_prefixes * sample_count) + + while True: + yield rand.choice(prefixes) + + +@DatasetDeserializerFactory.register("synthetic_text") +class SyntheticTextDatasetDeserializer(DatasetDeserializer): + def __call__( + self, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int, + **data_kwargs: dict[str, Any], + ) -> IterableDataset: + # Config file pathways, deserialize and call self again + if (config := self._load_config_file(data)) is not None: + return self(config, processor_factory, random_seed, **data_kwargs) + + # Config str pathways, deserialize and call self again + if (config := self._load_config_str(data)) is not None: + return self(config, processor_factory, random_seed, **data_kwargs) + + if not isinstance(data, SyntheticTextDatasetConfig): + raise DataNotSupportedError( + "Unsupported data for SyntheticTextDatasetDeserializer, " + "expected SyntheticTextDatasetConfig, str or Path to a config file, " + f"got {data}" + ) + + return IterableDataset.from_generator( + SyntheticTextGenerator, + gen_kwargs={ + "config": data, + "processor": processor_factory(), + "random_seed": random_seed, + }, + features=Features( + { + "prefix": Value("string"), + "prompt": Value("string"), + "prompt_tokens_count": Value("int32"), + "output_tokens_count": Value("int32"), + } + ), + ) + + def _load_config_file(self, data: Any) -> SyntheticTextDatasetConfig | None: + if (not isinstance(data, str) and not isinstance(data, Path)) or ( + not Path(data).is_file() + ): + return None + + data_path = Path(data) if isinstance(data, str) else data + error = None + + if Path(data).is_file() and data_path.suffix.lower() == ".json": + try: + return SyntheticTextDatasetConfig.model_validate_json( + data_path.read_text() + ) + except Exception as err: # noqa: BLE001 + error = err + + if Path(data).is_file() and data_path.suffix.lower() in { + ".yaml", + ".yml", + ".config", + }: + try: + return SyntheticTextDatasetConfig.model_validate( + yaml.safe_load(data_path.read_text()) + ) + except Exception as err: # noqa: BLE001 + error = err + + err_message = ( + f"Unsupported file {data_path} for " + f"SyntheticTextDatasetDeserializer, expected .json, " + f".yaml, .yml, or .config" + ) + + if error is not None: + err_message += f" with error: {error}" + raise DataNotSupportedError(err_message) from error + raise DataNotSupportedError(err_message) + + def _load_config_str(self, data: str) -> SyntheticTextDatasetConfig | None: + if not isinstance(data, str): + return None + + data_str = data.strip() + error = None + + if (data_str.startswith("{") and data_str.endswith("}")) or ( + data_str.startswith("[") and data_str.endswith("]") + ): + try: + return SyntheticTextDatasetConfig.model_validate_json(data_str) + except Exception as err: # noqa: BLE001 + error = err + + if data_str.count("=") > 1: + # key=value pairs separated by commas + try: + config_dict = {} + items = data_str.split(",") + for item in items: + key, value = item.split("=") + config_dict[key.strip()] = ( + int(value.strip()) + if value.strip().isnumeric() + else value.strip() + ) + + return SyntheticTextDatasetConfig.model_validate(config_dict) + except Exception as err: # noqa: BLE001 + error = err + + err_message = ( + "Unsupported string data for SyntheticTextDatasetDeserializer, " + f"expected JSON or 
key-value pairs, got {data}" + ) + if error is not None: + err_message += f" with error: {error}" + raise DataNotSupportedError(err_message) from error + raise DataNotSupportedError(err_message) diff --git a/src/guidellm/data/loaders.py b/src/guidellm/data/loaders.py new file mode 100644 index 00000000..fcdea15d --- /dev/null +++ b/src/guidellm/data/loaders.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import contextlib +from collections.abc import Callable, Iterator +from typing import Any, Literal + +import torch +from torch.utils.data import Sampler +from torch.utils.data.dataloader import DataLoader as PyTorchDataLoader +from torch.utils.data.dataset import IterableDataset as TorchIterableDataset +from transformers import PreTrainedTokenizerBase + +from guidellm.data.deserializers import DatasetDeserializerFactory +from guidellm.data.preprocessors import DataDependentPreprocessor, DatasetPreprocessor +from guidellm.logger import logger + +__all__ = ["DataLoader", "DatasetsIterator"] + + +class DatasetsIterator(TorchIterableDataset): + def __init__( + self, + data: list[Any], + data_args: list[dict[str, Any]] | None, + data_samples: int, + processor_factory: Callable[[], PreTrainedTokenizerBase], + preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor], + random_seed: int, + ): + if not data or not isinstance(data, list): + raise ValueError(f"Data must be a non-empty list, got {data}.") + + if not data_args: + data_args = [{} for _ in data] + + if len(data) != len(data_args): + raise ValueError( + f"Length of data ({len(data)}) must match length of data_args " + f"({len(data_args)})." + ) + + self.datasets = [] + for datum, data_kwargs in zip(data, data_args, strict=False): + self.datasets.append( + DatasetDeserializerFactory.deserialize( + data=datum, + processor_factory=processor_factory, + random_seed=random_seed, + **data_kwargs, + ) + ) + self.preprocessors = preprocessors + for preprocessor in self.preprocessors: + if isinstance(preprocessor, DataDependentPreprocessor): + preprocessor.setup_data( + datasets=self.datasets, + data_args=data_args, + ) + self.precache: list[Any] | None = ( + list(self.generator(data_samples)) if data_samples else None + ) + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + worker_modulus = worker_info.num_workers if worker_info is not None else 1 + worker_index = worker_info.id if worker_info is not None else 0 + + if self.precache is not None: + for index, item in enumerate(self.precache): + if (index + worker_index) % worker_modulus == 0: + yield item + else: + yield from self.generator(modulus=worker_modulus, offset=worker_index) + + def generator( + self, + max_items: int | None = None, + modulus: int | None = None, + offset: int | None = None, + ) -> Iterator[Any]: + gen_count = 0 + + with contextlib.suppress(StopIteration): + dataset_iters = [iter(dataset) for dataset in self.datasets] + + while max_items is None or gen_count < max_items: + try: + row = { + "items": [next(dataset_iter) for dataset_iter in dataset_iters] + } + gen_count += 1 + + if ( + modulus is not None + and offset is not None + and (gen_count % modulus) != offset + ): + continue + + for preprocessor in self.preprocessors: + row = preprocessor(row) + yield row + except Exception as err: + logger.error(f"Skipping data row due to error: {err}") + gen_count -= 1 + + if max_items is not None and gen_count < max_items: + raise ValueError( + f"Requested {max_items} samples, but only {gen_count} " + "available from the provided 
datasets." + ) + + +class DataLoader(PyTorchDataLoader): + def __init__( + self, + data: list[Any], + data_args: list[dict[str, Any]] | None, + data_samples: int, + processor_factory: Callable[[], PreTrainedTokenizerBase], + preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor], + collator: Callable, + sampler: Sampler[int] | Literal["shuffle"] | None = None, + num_workers: int | None = 1, + random_seed: int = 42, + **kwargs: Any, + ): + iterator = DatasetsIterator( + data=data, + data_args=data_args, + data_samples=data_samples, + processor_factory=processor_factory, + preprocessors=preprocessors, + random_seed=random_seed, + ) + + super().__init__( + dataset=iterator, + batch_size=1, + shuffle=sampler == "shuffle", + sampler=sampler if sampler != "shuffle" else None, + collate_fn=collator, + num_workers=num_workers or 0, + **kwargs, + ) diff --git a/src/guidellm/data/preprocessors/__init__.py b/src/guidellm/data/preprocessors/__init__.py new file mode 100644 index 00000000..664e196b --- /dev/null +++ b/src/guidellm/data/preprocessors/__init__.py @@ -0,0 +1,25 @@ +from .formatters import ( + GenerativeAudioTranscriptionRequestFormatter, + GenerativeAudioTranslationRequestFormatter, + GenerativeChatCompletionsRequestFormatter, + GenerativeTextCompletionsRequestFormatter, +) +from .mappers import GenerativeColumnMapper +from .preprocessor import ( + DataDependentPreprocessor, + DatasetPreprocessor, + PreprocessorRegistry, +) + +__all__ = [ + "ColumnMapper", + "ColumnMapperRegistry", + "DataDependentPreprocessor", + "DatasetPreprocessor", + "GenerativeAudioTranscriptionRequestFormatter", + "GenerativeAudioTranslationRequestFormatter", + "GenerativeChatCompletionsRequestFormatter", + "GenerativeColumnMapper", + "GenerativeTextCompletionsRequestFormatter", + "PreprocessorRegistry", +] diff --git a/src/guidellm/data/preprocessors/formatters.py b/src/guidellm/data/preprocessors/formatters.py new file mode 100644 index 00000000..ce0e46fc --- /dev/null +++ b/src/guidellm/data/preprocessors/formatters.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +from typing import Any + +from guidellm.data.preprocessors.preprocessor import ( + DatasetPreprocessor, + PreprocessorRegistry, +) +from guidellm.data.schemas import GenerativeDatasetColumnType +from guidellm.data.utils import encode_audio, encode_image, encode_video, text_stats +from guidellm.schemas import GenerationRequest, GenerationRequestArguments, UsageMetrics + +__all__ = [ + "GenerativeAudioTranscriptionRequestFormatter", + "GenerativeAudioTranslationRequestFormatter", + "GenerativeChatCompletionsRequestFormatter", + "GenerativeTextCompletionsRequestFormatter", +] + + +@PreprocessorRegistry.register("text_completions") +class GenerativeTextCompletionsRequestFormatter(DatasetPreprocessor): + def __init__( + self, + model: str, + extras: dict[str, Any] | GenerationRequestArguments | None = None, + stream: bool = True, + max_tokens: int | None = None, + max_completion_tokens: int | None = None, + ): + self.model: str | None = model + self.extras = ( + GenerationRequestArguments(**extras) + if extras and isinstance(extras, dict) + else extras + ) + self.stream: bool = stream + self.max_tokens: int | None = max_tokens or max_completion_tokens + + def __call__( + self, columns: dict[GenerativeDatasetColumnType, list[Any]] + ) -> GenerationRequest: + arguments: GenerationRequestArguments = GenerationRequestArguments(body={}) + input_metrics = UsageMetrics() + output_metrics = UsageMetrics() + + # Add model + if self.model is 
not None: + arguments.body["model"] = self.model + + # Configure streaming + if self.stream: + arguments.stream = True + arguments.body["stream"] = True + + # Handle output tokens + if output_tokens := sum( + count for count in columns.get("output_tokens_count_column", []) if count + ): + output_metrics.text_tokens = output_tokens + arguments.body["max_tokens"] = output_tokens + arguments.body["stop"] = None + arguments.body["ignore_eos"] = True + elif self.max_tokens is not None: + arguments.body["max_tokens"] = self.max_tokens + + # Handle prompt tokens + if prompt_tokens := sum( + count for count in columns.get("prompt_tokens_count_column", []) if count + ): + input_metrics.text_tokens = prompt_tokens + + # Apply extra arguments + if self.extras: + arguments.model_combine(self.extras) + + # Build prompt + prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre) + text = "".join(txt for txt in columns.get("text_column", []) if txt) + if prefix or text: + arguments.body["prompt"] = prefix + text + stats = text_stats(arguments.body["prompt"]) + input_metrics.text_characters = stats.get("num_chars") + input_metrics.text_words = stats.get("num_words") + + return GenerationRequest( + request_type="text_completions", + arguments=arguments, + input_metrics=input_metrics, + output_metrics=output_metrics, + ) + + +@PreprocessorRegistry.register("chat_completions") +class GenerativeChatCompletionsRequestFormatter(DatasetPreprocessor): + def __init__( + self, + model: str, + extras: dict[str, Any] | GenerationRequestArguments | None = None, + stream: bool = True, + max_tokens: int | None = None, + max_completion_tokens: int | None = None, + encode_kwargs: dict[str, Any] | None = None, + ): + self.model = model + self.extras = ( + GenerationRequestArguments(**extras) + if extras and isinstance(extras, dict) + else extras + ) + self.stream = stream + self.max_completion_tokens = max_tokens or max_completion_tokens + self.encode_image_kwargs = ( + encode_kwargs.get("image", {}) if encode_kwargs else {} + ) + self.encode_video_kwargs = ( + encode_kwargs.get("video", {}) if encode_kwargs else {} + ) + self.encode_audio_kwargs = ( + encode_kwargs.get("audio", {}) if encode_kwargs else {} + ) + + def __call__( + self, columns: dict[GenerativeDatasetColumnType, list[Any]] + ) -> GenerationRequest: + arguments = GenerationRequestArguments(body={}) + input_metrics = UsageMetrics() + output_metrics = UsageMetrics() + + # Add model + if self.model is not None: + arguments.body["model"] = self.model + + # Configure streaming + if self.stream: + arguments.stream = True + arguments.body.update( + {"stream": True, "stream_options": {"include_usage": True}} + ) + + # Handle output tokens + if output_tokens := sum( + count for count in columns.get("output_tokens_count_column", []) if count + ): + output_metrics.text_tokens = output_tokens + arguments.body.update( + { + "max_completion_tokens": output_tokens, + "stop": None, + "ignore_eos": True, + } + ) + elif self.max_completion_tokens is not None: + arguments.body["max_completion_tokens"] = self.max_completion_tokens + + # Handle prompt tokens + if prompt_tokens := sum( + count for count in columns.get("prompt_tokens_count_column", []) if count + ): + input_metrics.text_tokens = prompt_tokens + + # Apply extra arguments + if self.extras: + arguments.model_combine(self.extras) + + # Build messages + arguments.body["messages"] = [] + + for prefix in columns.get("prefix_column", []): + if not prefix: + continue + + stats = text_stats(prefix) + if 
(num_chars := stats.get("num_chars")) is not None: + input_metrics.text_characters = ( + input_metrics.text_characters or 0 + ) + num_chars + if (num_words := stats.get("num_words")) is not None: + input_metrics.text_words = (input_metrics.text_words or 0) + num_words + + arguments.body["messages"].append({"role": "system", "content": prefix}) + + for text in columns.get("text_column", []): + if not text: + continue + + stats = text_stats(text) + if (num_chars := stats.get("num_chars")) is not None: + input_metrics.text_characters = ( + input_metrics.text_characters or 0 + ) + num_chars + if (num_words := stats.get("num_words")) is not None: + input_metrics.text_words = (input_metrics.text_words or 0) + num_words + + arguments.body["messages"].append( + {"role": "user", "content": [{"type": "text", "text": text}]} + ) + + for image in columns.get("image_column", []): + if not image: + continue + + image_dict = encode_image(image, **self.encode_image_kwargs) + if (image_pixels := image_dict.get("image_pixels")) is not None: + input_metrics.image_pixels = ( + input_metrics.image_pixels or 0 + ) + image_pixels + if (image_bytes := image_dict.get("image_bytes")) is not None: + input_metrics.image_bytes = ( + input_metrics.image_bytes or 0 + ) + image_bytes + + arguments.body["messages"].append( + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": image_dict.get("image")} + ], + } + ) + + for video in columns.get("video_column", []): + if not video: + continue + + video_dict = encode_video(video, **self.encode_video_kwargs) + if (video_frames := video_dict.get("video_frames")) is not None: + input_metrics.video_frames = ( + input_metrics.video_frames or 0 + ) + video_frames + if (video_seconds := video_dict.get("video_seconds")) is not None: + input_metrics.video_seconds = ( + input_metrics.video_seconds or 0.0 + ) + video_seconds + if (video_bytes := video_dict.get("video_bytes")) is not None: + input_metrics.video_bytes = ( + input_metrics.video_bytes or 0 + ) + video_bytes + + arguments.body["messages"].append( + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": video_dict.get("video")} + ], + } + ) + + for audio in columns.get("audio_column", []): + if not audio: + continue + + audio_dict = encode_audio(audio, b64encode=True, **self.encode_audio_kwargs) + if (audio_samples := audio_dict.get("audio_samples")) is not None: + input_metrics.audio_samples = ( + input_metrics.audio_samples or 0 + ) + audio_samples + if (audio_seconds := audio_dict.get("audio_seconds")) is not None: + input_metrics.audio_seconds = ( + input_metrics.audio_seconds or 0.0 + ) + audio_seconds + if (audio_bytes := audio_dict.get("audio_bytes")) is not None: + input_metrics.audio_bytes = ( + input_metrics.audio_bytes or 0 + ) + audio_bytes + + arguments.body["messages"].append( + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": audio_dict.get("audio"), + "format": audio_dict.get("format"), + }, + } + ], + } + ) + + return GenerationRequest( + request_type="chat_completions", + arguments=arguments, + input_metrics=input_metrics, + output_metrics=output_metrics, + ) + + +@PreprocessorRegistry.register("audio_transcriptions") +class GenerativeAudioTranscriptionRequestFormatter(DatasetPreprocessor): + def __init__( + self, + model: str, + extras: dict[str, Any] | GenerationRequestArguments | None = None, + stream: bool = True, + encode_kwargs: dict[str, Any] | None = None, + ): + self.model = model + self.extras = ( + 
GenerationRequestArguments(**extras) + if extras and isinstance(extras, dict) + else extras + ) + self.stream = stream + self.encode_audio_kwargs = encode_kwargs or {} + + def __call__( # noqa: C901 + self, columns: dict[GenerativeDatasetColumnType, list[Any]] + ) -> GenerationRequest: + arguments = GenerationRequestArguments(body={}, files={}) + input_metrics = UsageMetrics() + output_metrics = UsageMetrics() + + # Add model + if self.model is not None: + arguments.body["model"] = self.model + + # Configure streaming + if self.stream: + arguments.stream = True + arguments.body["stream"] = True + + # Handle output tokens + if output_tokens := sum( + count for count in columns.get("output_tokens_count_column", []) if count + ): + output_metrics.text_tokens = output_tokens + + # Handle prompt tokens (for audio duration tracking) + if prompt_tokens := sum( + count for count in columns.get("prompt_tokens_count_column", []) if count + ): + input_metrics.text_tokens = prompt_tokens + + # Apply extra arguments + if self.extras: + arguments.model_combine(self.extras) + + # Build audio input + audio_columns = columns.get("audio_column", []) + if len(audio_columns) != 1: + raise ValueError( + f"GenerativeAudioTranscriptionRequestFormatter expects exactly " + f"one audio column, but got {len(audio_columns)}." + ) + + audio_dict = encode_audio( + audio_columns[0], b64encode=False, **self.encode_audio_kwargs + ) + input_metrics.audio_samples = audio_dict.get("audio_samples") + input_metrics.audio_seconds = audio_dict.get("audio_seconds") + input_metrics.audio_bytes = audio_dict.get("audio_bytes") + + arguments.files = { + "file": ( + audio_dict.get("file_name", "audio_input"), + audio_dict.get("audio"), + audio_dict.get("mimetype"), + ) + } + + # Build prompt + prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre) + text = "".join(txt for txt in columns.get("text_column", []) if txt) + if prefix or text: + arguments.body["prompt"] = prefix + text + stats = text_stats(arguments.body["prompt"]) + input_metrics.text_characters = stats.get("num_chars") + input_metrics.text_words = stats.get("num_words") + + return GenerationRequest( + request_type="audio_transcriptions", + arguments=arguments, + input_metrics=input_metrics, + output_metrics=output_metrics, + ) + + +@PreprocessorRegistry.register("audio_translations") +class GenerativeAudioTranslationRequestFormatter( + GenerativeAudioTranscriptionRequestFormatter +): + def __call__( + self, columns: dict[GenerativeDatasetColumnType, list[Any]] + ) -> GenerationRequest: + result = super().__call__(columns) + result.request_type = "audio_translations" + return result diff --git a/src/guidellm/data/preprocessors/mappers.py b/src/guidellm/data/preprocessors/mappers.py new file mode 100644 index 00000000..0783103b --- /dev/null +++ b/src/guidellm/data/preprocessors/mappers.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import Any, ClassVar, cast + +from datasets import Dataset, IterableDataset + +from guidellm.data.preprocessors.preprocessor import ( + DataDependentPreprocessor, + PreprocessorRegistry, +) +from guidellm.data.schemas import GenerativeDatasetColumnType + +__all__ = ["GenerativeColumnMapper"] + + +@PreprocessorRegistry.register("generative_column_mapper") +class GenerativeColumnMapper(DataDependentPreprocessor): + defaults: ClassVar[dict[str, list[str]]] = { + "prompt_tokens_count_column": ["prompt_tokens_count", "input_tokens_count"], + "output_tokens_count_column": 
[ + "output_tokens_count", + "completion_tokens_count", + ], + "prefix_column": [ + "system_prompt", + "system", + "prefix", + ], + "text_column": [ + "prompt", + "instruction", + "question", + "input", + "context", + "content", + "conversation", + "turn", + "text", + ], + "image_column": [ + "image", + "picture", + "photo", + "img", + ], + "video_column": [ + "video", + "clip", + "movie", + "footage", + "mp4", + "mov", + "avi", + ], + "audio_column": [ + "audio", + "sound", + "voice", + "speech", + "wav", + "mp3", + ], + } + + @classmethod + def datasets_default_mappings( + cls, datasets: list[Dataset | IterableDataset] + ) -> dict[GenerativeDatasetColumnType, list[tuple[int, str]]]: + mappings: dict[GenerativeDatasetColumnType, list[tuple[int, str]]] = ( + defaultdict(list) + ) + + for index, dataset in enumerate(datasets): + dataset_columns = dataset.column_names or list(next(iter(dataset)).keys()) + + for column_type in cls.defaults: + if column_type in mappings: + continue + + type_names = [ + variant + for name in cls.defaults.get(column_type, []) + for plural in [name, f"{name}s", f"{name}es"] + for variant in [ + plural, + plural.lower(), + plural.upper(), + plural.capitalize(), + ] + ] + + for name in type_names: + if name in dataset_columns: + key = cast("GenerativeDatasetColumnType", column_type) + mappings[key].append((index, name)) + break + + return mappings + + @classmethod + def datasets_mappings( + cls, + datasets: list[Dataset | IterableDataset], + input_mappings: dict[GenerativeDatasetColumnType, str | list[str]], + ) -> dict[GenerativeDatasetColumnType, list[tuple[int, str]]]: + mappings: dict[GenerativeDatasetColumnType, list[tuple[int, str]]] = ( + defaultdict(list) + ) + datasets_named_indices = { + ( + dataset.info.dataset_name + if dataset.info and dataset.info.dataset_name + else index + ): index + for index, dataset in enumerate(datasets) + } + datasets_columns = { + index: dataset.column_names or list(next(iter(dataset)).keys()) + for index, dataset in enumerate(datasets) + } + + # Parse out user mappings that were passed in and validate them + # Must be in the format of: + # {: []} + # where can be a single string or list of strings + # and each string can be any of: + # - a column name (assumes the first dataset was intended) + # - . where is the dataset index + # - . where is the dataset name + for column_type, names in input_mappings.items(): + mappings[column_type] = [] + for name in names if isinstance(names, list) else [names]: + if "." in name: + dataset, column_name = name.split(".", 1) + dataset_index = ( + int(dataset) + if dataset.isdigit() + else datasets_named_indices.get(dataset) + ) + else: + dataset_index = 0 + column_name = name + + if dataset_index is None or dataset_index >= len(datasets): + raise ValueError( + f"Dataset '{name}' not found in datasets: " + f"{datasets_named_indices}." + ) + if column_name not in datasets_columns[dataset_index]: + raise ValueError( + f"Column '{column_name}' not found in dataset " + f"'{datasets[dataset_index]}' " + f"columns: {datasets_columns[dataset_index]}." 
+                    )
+                mappings[column_type].append((dataset_index, column_name))
+
+        return mappings
+
+    def __init__(
+        self,
+        column_mappings: dict[GenerativeDatasetColumnType, str | list[str]]
+        | None = None,
+    ):
+        self.input_mappings = column_mappings
+        self.datasets_column_mappings: (
+            dict[GenerativeDatasetColumnType, list[tuple[int, str]]] | None
+        ) = None
+
+    def __call__(
+        self, row: dict[str, Any]
+    ) -> dict[GenerativeDatasetColumnType, list[Any]]:
+        if self.datasets_column_mappings is None:
+            raise ValueError("GenerativeColumnMapper not set up with data.")
+
+        items = cast("dict[int, dict[str, Any]]", row.pop("items"))
+        mapped: dict[GenerativeDatasetColumnType, list[Any]] = defaultdict(list)
+
+        for column_type, column_mappings in self.datasets_column_mappings.items():
+            for (
+                dataset_index,
+                dataset_column,
+            ) in column_mappings:
+                mapped[column_type].append(items[dataset_index][dataset_column])
+
+        return dict(mapped)
+
+    def setup_data(
+        self,
+        datasets: list[Dataset | IterableDataset],
+        data_args: list[dict[str, Any]],
+    ):
+        _ = data_args  # Unused for this mapper
+        self.datasets_column_mappings = (
+            self.datasets_default_mappings(datasets)
+            if self.input_mappings is None
+            else self.datasets_mappings(datasets, self.input_mappings)
+        )
diff --git a/src/guidellm/data/preprocessors/preprocessor.py b/src/guidellm/data/preprocessors/preprocessor.py
new file mode 100644
index 00000000..eefb53d3
--- /dev/null
+++ b/src/guidellm/data/preprocessors/preprocessor.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+from typing import Any, Protocol, Union, runtime_checkable
+
+from datasets import Dataset, IterableDataset
+
+from guidellm.utils import RegistryMixin
+
+__all__ = ["DataDependentPreprocessor", "DatasetPreprocessor", "PreprocessorRegistry"]
+
+
+@runtime_checkable
+class DatasetPreprocessor(Protocol):
+    def __call__(self, item: dict[str, Any]) -> dict[str, Any]: ...
+
+
+@runtime_checkable
+class DataDependentPreprocessor(DatasetPreprocessor, Protocol):
+    def setup_data(
+        self,
+        datasets: list[Dataset | IterableDataset],
+        data_args: list[dict[str, Any]],
+    ): ...
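+
+# Illustrative usage sketch (comment only): any callable that maps a row dict to
+# a row dict satisfies DatasetPreprocessor, and it can be made discoverable by
+# decorating it with PreprocessorRegistry.register, the same way the request
+# formatters in formatters.py register themselves. The name and behavior below
+# are hypothetical examples, not part of this module:
+#
+#     @PreprocessorRegistry.register("strip_whitespace")
+#     class StripWhitespacePreprocessor:
+#         def __call__(self, item: dict[str, Any]) -> dict[str, Any]:
+#             # Trim surrounding whitespace from every string value in the row.
+#             return {
+#                 key: value.strip() if isinstance(value, str) else value
+#                 for key, value in item.items()
+#             }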
+ + +class PreprocessorRegistry( + RegistryMixin[Union[DataDependentPreprocessor, type[DataDependentPreprocessor]]] +): + pass diff --git a/src/guidellm/data/prideandprejudice.txt.gz b/src/guidellm/data/prideandprejudice.txt.gz deleted file mode 100644 index 8c7a1072..00000000 Binary files a/src/guidellm/data/prideandprejudice.txt.gz and /dev/null differ diff --git a/src/guidellm/data/processor.py b/src/guidellm/data/processor.py new file mode 100644 index 00000000..645683c4 --- /dev/null +++ b/src/guidellm/data/processor.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Any + +from transformers import ( # type: ignore[import] + AutoTokenizer, + PreTrainedTokenizerBase, +) + +__all__ = ["ProcessorFactory"] + + +class ProcessorFactory: + def __init__( + self, + processor: str | PreTrainedTokenizerBase, + processor_args: dict[str, Any] | None = None, + ) -> None: + self.processor = processor + self.processor_args = processor_args or {} + + def __call__(self) -> PreTrainedTokenizerBase: + if isinstance(self.processor, PreTrainedTokenizerBase): + return self.processor + else: + self.processor = AutoTokenizer.from_pretrained( + self.processor, + **(self.processor_args or {}), + ) + return self.processor diff --git a/src/guidellm/data/schemas.py b/src/guidellm/data/schemas.py new file mode 100644 index 00000000..c4421e07 --- /dev/null +++ b/src/guidellm/data/schemas.py @@ -0,0 +1,13 @@ +from typing import Literal + +__all__ = ["GenerativeDatasetColumnType"] + +GenerativeDatasetColumnType = Literal[ + "prompt_tokens_count_column", + "output_tokens_count_column", + "prefix_column", + "text_column", + "image_column", + "video_column", + "audio_column", +] diff --git a/src/guidellm/data/utils/__init__.py b/src/guidellm/data/utils/__init__.py new file mode 100644 index 00000000..cd257898 --- /dev/null +++ b/src/guidellm/data/utils/__init__.py @@ -0,0 +1,22 @@ +from .dataset import DEFAULT_SPLITS, resolve_dataset_split +from .functions import ( + encode_audio, + encode_image, + encode_video, + get_file_format, + is_url, + resize_image, + text_stats, +) + +__all__ = [ + "DEFAULT_SPLITS", + "encode_audio", + "encode_image", + "encode_video", + "get_file_format", + "is_url", + "resize_image", + "resolve_dataset_split", + "text_stats", +] diff --git a/src/guidellm/data/utils/dataset.py b/src/guidellm/data/utils/dataset.py new file mode 100644 index 00000000..9656c1a7 --- /dev/null +++ b/src/guidellm/data/utils/dataset.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from typing import Literal + +from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict + +__all__ = ["DEFAULT_SPLITS", "resolve_dataset_split"] + + +DEFAULT_SPLITS: dict[Literal["train", "calib", "val", "test"], list[str]] = { + "train": [ + "train", + "training", + "train_set", + "training_set", + "train_dataset", + "training_dataset", + "train_data", + "training_data", + "pretrain", + "pretrain_set", + "pretrain_dataset", + "pretrain_data", + "pretraining", + ], + "calib": [ + "calibration", + "calib", + "cal", + "calibration_set", + "calib_set", + "cal_set", + "calibration_dataset", + "calib_dataset", + "cal_set", + "calibration_data", + "calib_data", + "cal_data", + ], + "val": [ + "validation", + "val", + "valid", + "validation_set", + "val_set", + "validation_dataset", + "val_dataset", + "validation_data", + "val_data", + "dev", + "dev_set", + "dev_dataset", + "dev_data", + ], + "test": [ + "test", + "testing", + "test_set", + "testing_set", + "test_dataset", + 
"testing_dataset", + "test_data", + "testing_data", + "eval", + "eval_set", + "eval_dataset", + "eval_data", + ], +} + + +def resolve_dataset_split( + dataset: Dataset | IterableDataset | DatasetDict | IterableDatasetDict, + split: str | None = None, +) -> Dataset | IterableDataset: + if split is not None and isinstance(dataset, (DatasetDict, IterableDatasetDict)): + if split in dataset: + return dataset[split] + + raise ValueError(f"Requested split '{split}' not found in dataset: {dataset}.") + elif split is not None: + raise ValueError( + f"Requested split '{split}' but dataset has no splits: {dataset}." + ) + + if isinstance(dataset, (Dataset, IterableDataset)): + return dataset + + for _, default_splits in DEFAULT_SPLITS.items(): + for default_split in default_splits: + if default_split in dataset: + return dataset[default_split] + + return dataset[list(dataset.keys())[0]] diff --git a/src/guidellm/data/utils/functions.py b/src/guidellm/data/utils/functions.py new file mode 100644 index 00000000..e11c5cb8 --- /dev/null +++ b/src/guidellm/data/utils/functions.py @@ -0,0 +1,385 @@ +from __future__ import annotations + +import base64 +import io +from pathlib import Path +from typing import Any, Literal + +import httpx +import librosa +import numpy as np +import soundfile +from PIL import Image as PILImage +from pydub import AudioSegment +from torch import Tensor + +__all__ = [ + "encode_audio", + "encode_image", + "encode_video", + "get_file_format", + "is_url", + "resize_image", + "text_stats", +] + + +def is_url(text: Any) -> bool: + return isinstance(text, str) and text.startswith(("http://", "https://")) + + +def text_stats( + text: str, +) -> dict[Literal["type", "text", "num_chars", "num_words"], str | int]: + """Compute basic text statistics.""" + num_chars = len(text) + num_words = len(text.split()) + + return { + "type": "text", + "text": text, + "num_chars": num_chars, + "num_words": num_words, + } + + +def encode_image( + image: bytes | str | Path | np.ndarray | PILImage.Image, + width: int | None = None, + height: int | None = None, + max_size: int | None = None, + max_width: int | None = None, + max_height: int | None = None, + encode_type: Literal["base64", "url"] | None = "base64", +) -> dict[Literal["type", "image", "image_pixels", "image_bytes"], str | int | None]: + """ + Input image types: + - bytes: raw image bytes, decoded with Pillow + - str: file path on disk, url, or already base64 encoded image string + - pathlib.Path: file path on disk + - np.ndarray: image array, decoded with Pillow + - PIL.Image.Image: Pillow image + - datasets.Image: HuggingFace datasets Image object + + max_size: maximum size of the longest edge of the image + max_width: maximum width of the image + max_height: maximum height of the image + + encode_type: None to return the supported format + (url for url, base64 string for others) + "base64" to return base64 encoded string (or download URL and encode) + "url" to return url (only if input is url, otherwise fails) + + Returns a str of either: + - image url + - "data:image/{type};base64, {data}" string + """ + if isinstance(image, str) and is_url(image): + if encode_type == "base64": + response = httpx.get(image) + response.raise_for_status() + return encode_image( + image=response.content, + max_size=max_size, + max_width=max_width, + max_height=max_height, + encode_type="base64", + ) + + if any([width, height, max_size, max_width, max_height]): + raise ValueError(f"Cannot resize image {image} when encode_type is 'url'") + + return { + "type": 
"image_url", + "image": image, + "image_pixels": None, + "image_bytes": None, + } + + decoded_image: PILImage.Image + + if isinstance(image, bytes): + decoded_image = PILImage.open(io.BytesIO(image)) + elif isinstance(image, str) and image.startswith("data:image/"): + _, encoded = image.split(",", 1) + image_data = base64.b64decode(encoded) + decoded_image = PILImage.open(io.BytesIO(image_data)) + elif isinstance(image, str | Path): + decoded_image = PILImage.open(image) + elif isinstance(image, np.ndarray): + decoded_image = PILImage.fromarray(image) + elif isinstance(image, PILImage.Image): + decoded_image = image + else: + raise ValueError(f"Unsupported image type: {type(image)} for {image}") + + output_image = resize_image( + decoded_image, + width=width, + height=height, + max_width=max_width, + max_height=max_height, + max_size=max_size, + ) + if output_image.mode != "RGB": + output_image = output_image.convert("RGB") + + buffer = io.BytesIO() + output_image.save(buffer, format="JPEG") + image_bytes = buffer.getvalue() + image_base64 = base64.b64encode(image_bytes).decode("utf-8") + + return { + "type": "image_base64", + "image": f"data:image/jpeg;base64,{image_base64}", + "image_pixels": output_image.width * output_image.height, + "image_bytes": len(image_bytes), + } + + +def resize_image( + image: PILImage.Image, + width: int | None = None, + height: int | None = None, + max_width: int | None = None, + max_height: int | None = None, + max_size: int | None = None, +) -> PILImage.Image: + if not isinstance(image, PILImage.Image): + raise ValueError(f"Unsupported image type: {type(image)}") + + if width is not None and height is not None: + return image.resize((width, height), PILImage.Resampling.BILINEAR) + + orig_w, orig_h = image.size + aspect = orig_w / orig_h + + if width is not None: + target_w = width + target_h = round(width / aspect) + elif height is not None: + target_h = height + target_w = round(height * aspect) + else: + target_w, target_h = orig_w, orig_h + + # Normalize max_size → max_width/max_height + if max_size is not None: + max_width = max_width or max_size + max_height = max_height or max_size + + # Apply max constraints (preserve aspect ratio) + if max_width or max_height: + scale_w = max_width / target_w if max_width else 1.0 + scale_h = max_height / target_h if max_height else 1.0 + scale = min(scale_w, scale_h, 1.0) # never upscale + target_w = round(target_w * scale) + target_h = round(target_h * scale) + + if (target_w, target_h) != (orig_w, orig_h): + image = image.resize((target_w, target_h), PILImage.Resampling.BILINEAR) + + return image + + +def encode_video( + video: bytes | str | Path, + encode_type: Literal["base64", "url"] | None = "base64", +) -> dict[ + Literal["type", "video", "video_frames", "video_seconds", "video_bytes"], + str | int | float | None, +]: + """ + Input video types: + - bytes: raw video bytes + - str: file path on disk, url, or already base64 encoded video string + - pathlib.Path: file path on disk + - datasets.Video: HuggingFace datasets Video object + + encode_type: None to return the supported format + (url for url, base64 string for others) + "base64" to return base64 encoded string (or download URL and encode) + "url" to return url (only if input is url, otherwise fails) + + Returns a str of either: + - video url + - "data:video/{type};base64, {data}" string + """ + if isinstance(video, str) and is_url(video): + if encode_type == "base64": + response = httpx.get(video) + response.raise_for_status() + return 
encode_video(video=response.content, encode_type="base64")
+
+        return {
+            "type": "video_url",
+            "video": video,
+            "video_frames": None,
+            "video_seconds": None,
+            "video_bytes": None,
+        }
+
+    if isinstance(video, str) and video.startswith("data:video/"):
+        data_str = video.split(",", 1)[1]
+
+        return {
+            "type": "video_base64",
+            "video": video,
+            "video_frames": None,
+            "video_seconds": None,
+            "video_bytes": len(data_str) * 3 // 4,  # base64 to bytes
+        }
+
+    if isinstance(video, str | Path):
+        path = Path(video)
+        video_bytes = path.read_bytes()
+        video_format = get_file_format(path)
+    elif isinstance(video, bytes):
+        video_bytes = video
+        video_format = "unknown"
+    else:
+        raise ValueError(f"Unsupported video type: {type(video)} for {video}")
+
+    video_base64 = base64.b64encode(video_bytes).decode("utf-8")
+
+    return {
+        "type": "video_base64",
+        "video": f"data:video/{video_format};base64,{video_base64}",
+        "video_frames": None,
+        "video_seconds": None,
+        "video_bytes": len(video_bytes),
+    }
+
+
+def encode_audio(
+    audio: Any,
+    b64encode: bool,
+    sample_rate: int = 16000,
+    file_name: str = "audio.wav",
+    encode_sample_rate: int = 16000,
+    max_duration: float | None = None,
+    mono: bool = True,
+    audio_format: str = "mp3",
+    bitrate: str = "64k",
+) -> dict[
+    Literal[
+        "type",
+        "audio",
+        "format",
+        "mimetype",
+        "audio_samples",
+        "audio_seconds",
+        "audio_bytes",
+    ],
+    str | int | float | None,
+]:
+    if isinstance(audio, dict):
+        sample_rate = audio.get("sample_rate", audio.get("sampling_rate", sample_rate))
+        if "data" not in audio and "url" not in audio:
+            raise ValueError(
+                f"Audio dict must contain either 'data' or 'url' keys, got {audio}"
+            )
+        return encode_audio(
+            audio=audio.get("data") or audio.get("url"),
+            b64encode=b64encode,
+            sample_rate=sample_rate,
+            encode_sample_rate=encode_sample_rate,
+            max_duration=max_duration,
+            mono=mono,
+            audio_format=audio_format,
+            bitrate=bitrate,
+        )
+
+    audio_numpy: np.ndarray
+
+    if hasattr(audio, "get_samples_played_in_range"):
+        # HF datasets Audio object
+        audio_samples = audio.get_samples_played_in_range(
+            start_seconds=0.0,
+            stop_seconds=(
+                None
+                if max_duration is None
+                else min(max_duration, audio.metadata.duration_seconds_from_header)
+            ),
+        )
+        audio_numpy = np.array(audio_samples.data)
+    elif isinstance(audio, Tensor):
+        audio_numpy = audio.numpy()
+    elif isinstance(audio, str | Path):
+        if is_url(audio):
+            response = httpx.get(audio)
+            response.raise_for_status()
+            audio_stream = response.content
+            file_name = get_file_name(audio)
+        else:
+            if not Path(audio).exists():
+                raise ValueError(f"Audio file does not exist: {audio}")
+            file_name = get_file_name(audio)
+            audio_stream = Path(audio).read_bytes()
+
+        audio_numpy, sample_rate = soundfile.read(
+            io.BytesIO(audio_stream), dtype="float32"
+        )
+    elif isinstance(audio, bytes):
+        audio_numpy, sample_rate = soundfile.read(io.BytesIO(audio), dtype="float32")
+    elif isinstance(audio, np.ndarray):
+        audio_numpy = audio
+    else:
+        raise ValueError(f"Unsupported audio type: {type(audio)}")
+
+    if sample_rate != encode_sample_rate:
+        audio_numpy = librosa.resample(
+            audio_numpy.astype(np.float32),
+            orig_sr=sample_rate,
+            target_sr=encode_sample_rate,
+        )
+        sample_rate = encode_sample_rate
+
+    audio_numpy = librosa.to_mono(audio_numpy)
+
+    if (
+        max_duration is not None
+        and max_duration > 0
+        and (max_samples := int(max_duration * sample_rate)) < len(audio_numpy)
+    ):
+        audio_numpy = audio_numpy[:max_samples]
+
+    audio_buffer = io.BytesIO()
+
+    if audio_format.lower() == "mp3":
+        wav =
io.BytesIO() + soundfile.write(wav, audio_numpy, sample_rate, format="WAV", subtype="PCM_16") + wav.seek(0) + + sound = AudioSegment.from_wav(wav) + sound.export(audio_buffer, format="mp3", bitrate=bitrate) + else: + soundfile.write(audio_buffer, audio, sample_rate, format=audio_format.upper()) + + audio_buffer.seek(0) + decoded_audio = audio_buffer.read() + + return { + "type": "audio_base64" if b64encode else "audio_file", + "audio": ( + base64.b64encode(decoded_audio).decode("utf-8") + if b64encode + else decoded_audio + ), + "file_name": file_name, + "format": audio_format, + "mimetype": f"audio/{audio_format}", + "audio_samples": len(audio_numpy), + "audio_seconds": len(audio_numpy) / sample_rate, + "audio_bytes": len(decoded_audio), + } + + +def get_file_name(path: Path | str) -> str: + """Get file name from path.""" + return Path(path).name + + +def get_file_format(path: Path | str) -> str: + """Get file format from path extension.""" + suffix = Path(path).suffix.lower() + return suffix[1:] if suffix.startswith(".") else "unknown" diff --git a/src/guidellm/dataset/__init__.py b/src/guidellm/dataset/__init__.py deleted file mode 100644 index b90b72ff..00000000 --- a/src/guidellm/dataset/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .creator import ColumnInputTypes, DatasetCreator -from .entrypoints import load_dataset -from .file import FileDatasetCreator -from .hf_datasets import HFDatasetsCreator -from .in_memory import InMemoryDatasetCreator -from .synthetic import ( - SyntheticDatasetConfig, - SyntheticDatasetCreator, - SyntheticTextItemsGenerator, -) - -__all__ = [ - "ColumnInputTypes", - "DatasetCreator", - "FileDatasetCreator", - "HFDatasetsCreator", - "InMemoryDatasetCreator", - "SyntheticDatasetConfig", - "SyntheticDatasetCreator", - "SyntheticTextItemsGenerator", - "load_dataset", -] diff --git a/src/guidellm/dataset/creator.py b/src/guidellm/dataset/creator.py deleted file mode 100644 index fe712c23..00000000 --- a/src/guidellm/dataset/creator.py +++ /dev/null @@ -1,213 +0,0 @@ -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any, Literal - -from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict -from transformers import PreTrainedTokenizerBase # type: ignore[import] - -__all__ = ["ColumnInputTypes", "DatasetCreator"] - -ColumnInputTypes = Literal[ - "prompt_column", - "text_column", - "prompt_tokens_count_column", - "output_tokens_count_column", -] - - -class DatasetCreator(ABC): - DEFAULT_SPLITS_TRAIN = [ - "train", - "training", - "train_set", - "training_set", - "train_dataset", - "training_dataset", - "train_data", - "training_data", - "pretrain", - "pretrain_set", - "pretrain_dataset", - "pretrain_data", - "pretraining", - ] - DEFAULT_SPLITS_CALIB = [ - "calibration", - "calib", - "cal", - "calibration_set", - "calib_set", - "cal_set", - "calibration_dataset", - "calib_dataset", - "cal_set", - "calibration_data", - "calib_data", - "cal_data", - ] - DEFAULT_SPLITS_VAL = [ - "validation", - "val", - "valid", - "validation_set", - "val_set", - "validation_dataset", - "val_dataset", - "validation_data", - "val_data", - "dev", - "dev_set", - "dev_dataset", - "dev_data", - ] - DEFAULT_SPLITS_TEST = [ - "test", - "testing", - "test_set", - "testing_set", - "test_dataset", - "testing_dataset", - "test_data", - "testing_data", - "eval", - "eval_set", - "eval_dataset", - "eval_data", - ] - DEFAULT_SPLITS_DATASET: dict[str, str] = {} - - @classmethod - def create( - cls, - data: Any, - data_args: dict[str, Any] | 
None, - processor: str | Path | PreTrainedTokenizerBase | None, - processor_args: dict[str, Any] | None, - random_seed: int = 42, - split_pref_order: list[str] | None = None, - ) -> tuple[Dataset | IterableDataset, dict[ColumnInputTypes, str]]: - if not cls.is_supported(data, data_args): - raise ValueError(f"Unsupported data type: {type(data)} given for {data}. ") - - split = cls.extract_args_split(data_args) - column_mappings = cls.extract_args_column_mappings(data_args) - dataset = cls.handle_create( - data, data_args, processor, processor_args, random_seed - ) - - if isinstance(dataset, DatasetDict | IterableDatasetDict): - dataset = cls.extract_dataset_split(dataset, split, split_pref_order) - - if not isinstance(dataset, Dataset | IterableDataset): - raise ValueError( - f"Unsupported data type: {type(dataset)} given for {dataset}." - ) - - return dataset, column_mappings - - @classmethod - def extract_args_split(cls, data_args: dict[str, Any] | None) -> str: - split = "auto" - - if data_args and "split" in data_args: - split = data_args["split"] - del data_args["split"] - - return split - - @classmethod - def extract_args_column_mappings( - cls, - data_args: dict[str, Any] | None, - ) -> dict[ColumnInputTypes, str]: - columns: dict[ColumnInputTypes, str] = {} - - if data_args: - if "prompt_column" in data_args: - columns["prompt_column"] = data_args["prompt_column"] - del data_args["prompt_column"] - - if "prompt_tokens_count_column" in data_args: - columns["prompt_tokens_count_column"] = data_args[ - "prompt_tokens_count_column" - ] - del data_args["prompt_tokens_count_column"] - - if "output_tokens_count_column" in data_args: - columns["output_tokens_count_column"] = data_args[ - "output_tokens_count_column" - ] - del data_args["output_tokens_count_column"] - - return columns - - @classmethod - def extract_dataset_name( - cls, dataset: Dataset | IterableDataset | DatasetDict | IterableDatasetDict - ) -> str | None: - if isinstance(dataset, DatasetDict | IterableDatasetDict): - dataset = dataset[list(dataset.keys())[0]] - - if isinstance(dataset, Dataset | IterableDataset): - if not hasattr(dataset, "info") or not hasattr( - dataset.info, "dataset_name" - ): - return None - - return dataset.info.dataset_name - - raise ValueError(f"Unsupported data type: {type(dataset)} given for {dataset}.") - - @classmethod - def extract_dataset_split( - cls, - dataset: DatasetDict | IterableDatasetDict, - specified_split: Literal["auto"] | str = "auto", - split_pref_order: Literal["auto"] | list[str] | None = "auto", - ) -> Dataset | IterableDataset: - if not isinstance(dataset, DatasetDict | IterableDatasetDict): - raise ValueError( - f"Unsupported data type: {type(dataset)} given for {dataset}." - ) - - if specified_split != "auto": - if specified_split not in dataset: - raise ValueError( - f"Split {specified_split} not found in dataset {dataset}." - ) - - return dataset[specified_split] - - dataset_name = cls.extract_dataset_name(dataset) - - if dataset_name and dataset_name in cls.DEFAULT_SPLITS_DATASET: - return dataset[cls.DEFAULT_SPLITS_DATASET[dataset_name]] - - if split_pref_order == "auto": - split_pref_order = [ - *cls.DEFAULT_SPLITS_TEST, - *cls.DEFAULT_SPLITS_VAL, - *cls.DEFAULT_SPLITS_CALIB, - *cls.DEFAULT_SPLITS_TRAIN, - ] - - for test_split in split_pref_order or []: - if test_split in dataset: - return dataset[test_split] - - return dataset[list(dataset.keys())[0]] - - @classmethod - @abstractmethod - def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: ... 
- - @classmethod - @abstractmethod - def handle_create( - cls, - data: Any, - data_args: dict[str, Any] | None, - processor: str | Path | PreTrainedTokenizerBase | None, - processor_args: dict[str, Any] | None, - random_seed: int, - ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: ... diff --git a/src/guidellm/dataset/entrypoints.py b/src/guidellm/dataset/entrypoints.py deleted file mode 100644 index 1da2222a..00000000 --- a/src/guidellm/dataset/entrypoints.py +++ /dev/null @@ -1,42 +0,0 @@ -from pathlib import Path -from typing import Any - -from datasets import Dataset, IterableDataset -from transformers import PreTrainedTokenizerBase # type: ignore[import] - -from guidellm.dataset.creator import ColumnInputTypes -from guidellm.dataset.file import FileDatasetCreator -from guidellm.dataset.hf_datasets import HFDatasetsCreator -from guidellm.dataset.in_memory import InMemoryDatasetCreator -from guidellm.dataset.synthetic import SyntheticDatasetCreator - -__all__ = ["load_dataset"] - - -def load_dataset( - data: Any, - data_args: dict[str, Any] | None, - processor: str | Path | PreTrainedTokenizerBase | None, - processor_args: dict[str, Any] | None, - random_seed: int = 42, - split_pref_order: list[str] | None = None, -) -> tuple[Dataset | IterableDataset, dict[ColumnInputTypes, str]]: - creators = [ - InMemoryDatasetCreator, - SyntheticDatasetCreator, - FileDatasetCreator, - HFDatasetsCreator, - ] - - for creator in creators: - if creator.is_supported(data, data_args): - return creator.create( - data, - data_args, - processor, - processor_args, - random_seed, - split_pref_order, - ) - - raise ValueError(f"Unsupported data type: {type(data)} given for {data}. ") diff --git a/src/guidellm/dataset/file.py b/src/guidellm/dataset/file.py deleted file mode 100644 index 718cb46f..00000000 --- a/src/guidellm/dataset/file.py +++ /dev/null @@ -1,92 +0,0 @@ -from pathlib import Path -from typing import Any - -import pandas as pd # type: ignore[import] -from datasets import ( - Dataset, - DatasetDict, - IterableDataset, - IterableDatasetDict, - load_dataset, -) -from transformers import PreTrainedTokenizerBase # type: ignore[import] - -from guidellm.dataset.creator import DatasetCreator - -__all__ = ["FileDatasetCreator"] - - -class FileDatasetCreator(DatasetCreator): - SUPPORTED_TYPES = { - ".txt", - ".text", - ".csv", - ".json", - ".jsonl", - ".parquet", - ".arrow", - ".hdf5", - ".tar", - } - - @classmethod - def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003 - if isinstance(data, str | Path) and (path := Path(data)).exists(): - # local folder or py file, assume supported - return path.suffix.lower() in cls.SUPPORTED_TYPES - - return False - - @classmethod - def handle_create( - cls, - data: Any, - data_args: dict[str, Any] | None, - processor: str | Path | PreTrainedTokenizerBase | None, # noqa: ARG003 - processor_args: dict[str, Any] | None, # noqa: ARG003 - random_seed: int, # noqa: ARG003 - ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: - if not isinstance(data, str | Path): - raise ValueError(f"Unsupported data type: {type(data)} given for {data}. ") - - path = Path(data) - if not path.exists(): - raise FileNotFoundError(f"File not found: {path}") - - if not path.is_file(): - raise ValueError(f"Unsupported data type: {path} given for {path}. ") - - if path.suffix.lower() not in cls.SUPPORTED_TYPES: - raise ValueError(f"Unsupported file type: {path.suffix} given for {path}. 
") - - return cls.load_dataset(path, data_args) - - @classmethod - def load_dataset( - cls, path: Path, data_args: dict[str, Any] | None - ) -> Dataset | IterableDataset: - if path.suffix.lower() in {".txt", ".text"}: - with path.open("r") as file: - items = file.readlines() - - dataset = Dataset.from_dict({"text": items}, **(data_args or {})) - elif path.suffix.lower() == ".csv": - dataset = load_dataset("csv", data_files=str(path), **(data_args or {})) - elif path.suffix.lower() in {".json", ".jsonl"}: - dataset = load_dataset("json", data_files=str(path), **(data_args or {})) - elif path.suffix.lower() == ".parquet": - dataset = load_dataset("parquet", data_files=str(path), **(data_args or {})) - elif path.suffix.lower() == ".arrow": - dataset = load_dataset("arrow", data_files=str(path), **(data_args or {})) - elif path.suffix.lower() == ".hdf5": - dataset = Dataset.from_pandas(pd.read_hdf(str(path)), **(data_args or {})) - elif path.suffix.lower() == ".db": - dataset = Dataset.from_sql(con=str(path), **(data_args or {})) - elif path.suffix.lower() == ".tar": - dataset = load_dataset( - "webdataset", data_files=str(path), **(data_args or {}) - ) - else: - raise ValueError(f"Unsupported file type: {path.suffix} given for {path}. ") - - return dataset diff --git a/src/guidellm/dataset/hf_datasets.py b/src/guidellm/dataset/hf_datasets.py deleted file mode 100644 index d1be46c1..00000000 --- a/src/guidellm/dataset/hf_datasets.py +++ /dev/null @@ -1,62 +0,0 @@ -from pathlib import Path -from typing import Any - -from datasets import ( - Dataset, - DatasetDict, - IterableDataset, - IterableDatasetDict, - get_dataset_config_info, - load_dataset, -) -from transformers import PreTrainedTokenizerBase # type: ignore[import] - -from guidellm.dataset.creator import DatasetCreator - -__all__ = ["HFDatasetsCreator"] - - -class HFDatasetsCreator(DatasetCreator): - @classmethod - def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003 - if isinstance( - data, Dataset | DatasetDict | IterableDataset | IterableDatasetDict - ): - # base type is supported - return True - - if isinstance(data, str | Path) and (path := Path(data)).exists(): - # local folder or py file, assume supported - return path.is_dir() or path.suffix == ".py" - - if isinstance(data, str | Path): - try: - # try to load dataset - return get_dataset_config_info(data) is not None - except Exception: # noqa: BLE001, S110 - pass - - return False - - @classmethod - def handle_create( - cls, - data: Any, - data_args: dict[str, Any] | None, - processor: str | Path | PreTrainedTokenizerBase | None, # noqa: ARG003 - processor_args: dict[str, Any] | None, # noqa: ARG003 - random_seed: int, # noqa: ARG003 - ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: - if isinstance(data, str | Path): - data = load_dataset(data, **(data_args or {})) - elif data_args: - raise ValueError( - f"data_args should not be provided when data is a {type(data)}" - ) - - if isinstance( - data, Dataset | DatasetDict | IterableDataset | IterableDatasetDict - ): - return data - - raise ValueError(f"Unsupported data type: {type(data)} given for {data}. 
") diff --git a/src/guidellm/dataset/in_memory.py b/src/guidellm/dataset/in_memory.py deleted file mode 100644 index 0461948c..00000000 --- a/src/guidellm/dataset/in_memory.py +++ /dev/null @@ -1,132 +0,0 @@ -from collections.abc import Iterable -from pathlib import Path -from typing import Any - -from datasets import ( - Dataset, - DatasetDict, - IterableDataset, - IterableDatasetDict, -) -from transformers import PreTrainedTokenizerBase # type: ignore[import] - -from guidellm.dataset.creator import DatasetCreator - -__all__ = ["InMemoryDatasetCreator"] - - -class InMemoryDatasetCreator(DatasetCreator): - @classmethod - def is_supported(cls, data: Any, data_args: dict[str, Any] | None) -> bool: # noqa: ARG003 - return isinstance(data, Iterable) and not isinstance(data, str) - - @classmethod - def handle_create( - cls, - data: Any, - data_args: dict[str, Any] | None, - processor: str | Path | PreTrainedTokenizerBase | None, # noqa: ARG003 - processor_args: dict[str, Any] | None, # noqa: ARG003 - random_seed: int, # noqa: ARG003 - ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: - if not isinstance(data, Iterable): - raise TypeError( - f"Unsupported data format. Expected Iterable[Any], got {type(data)}" - ) - - if not data: - raise ValueError("Data is empty") - - if isinstance(data, dict): - # assume data is a dictionary of columns and values: {"c1": ["i1", "i2"]} - data_dict = cls.format_data_dict(data) - elif isinstance(data[0], dict): # type: ignore[index] - # assume data is a list of dictionaries: [{"c1": "i1"}, {"c1": "i2"}] - data_dict = cls.format_data_iterable_dicts(data) - else: - # assume data is a list of items with no columns: ["i1", "i2"] - data_dict = cls.format_data_iterable_values(data) - - return Dataset.from_dict(data_dict, **(data_args or {})) - - @classmethod - def format_data_dict(cls, data: dict[Any, Any]) -> dict[str, Any]: - if not isinstance(data, dict): - raise TypeError( - f"Unsupported data format. Expected Dict[str, Iterable[Any]], " - f"got {type(data)}" - ) - - if not all( - isinstance(key, str) and isinstance(val, Iterable) - for key, val in data.items() - ): - raise TypeError( - "Unsupported data format. Expected Dict[str, Iterable[Any]], " - f"got {type(data)}" - ) - - samples = len(list(data.values())[0]) - if not all(len(val) == samples for val in data.values()): - raise ValueError( - "Unsupported data format. Not all columns have the same number samples " - f"for {data}" - ) - - return data - - @classmethod - def format_data_iterable_dicts( - cls, data: Iterable[dict[Any, Any]] - ) -> dict[str, Any]: - if not isinstance(data, Iterable): - raise TypeError( - f"Unsupported data format. Expected Iterable[Dict[str, Any]], " - f"got {type(data)}" - ) - - if not all(isinstance(item, dict) for item in data): - raise TypeError( - f"Unsupported data format. Expected Iterable[Dict[str, Any]], " - f"got {type(data)}" - ) - - if not all(isinstance(key, str) for key in data[0]): # type: ignore[index] - raise TypeError( - "Unsupported data format. Expected Dict[str, Any], " - f"but one of the items had a non string column for {data}" - ) - - columns = list(data[0].keys()) # type: ignore[index] - if not all( - len(item) == len(columns) and all(key in item for key in columns) - for item in data - ): - raise ValueError( - "Unsupported data format. 
Not all items have the same columns " - f"for {data}" - ) - - data_dict: dict[str, Any] = {key: [] for key in columns} - for item in data: - for key, value in item.items(): - data_dict[key].append(value) - - return data_dict - - @classmethod - def format_data_iterable_values(cls, data: Iterable[Any]) -> dict[str, Any]: - if not isinstance(data, Iterable): - raise TypeError( - f"Unsupported data format. Expected Iterable[Iterable[Any]], " - f"got {type(data)}" - ) - - first_item = next(iter(data), None) - first_type = type(first_item) - if not all(isinstance(item, first_type) for item in data): - raise TypeError( - f"Unsupported data format. Not all types are the same for {data}" - ) - - return {"data": list(data)} diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py deleted file mode 100644 index 8a1626fe..00000000 --- a/src/guidellm/dataset/synthetic.py +++ /dev/null @@ -1,287 +0,0 @@ -import json -import random -from collections.abc import Iterable, Iterator -from itertools import cycle -from pathlib import Path -from typing import Any, Literal - -import yaml -from datasets import ( - Dataset, - DatasetDict, - IterableDataset, - IterableDatasetDict, -) -from pydantic import BaseModel, Field -from transformers import PreTrainedTokenizerBase # type: ignore[import] - -from guidellm.dataset.creator import ColumnInputTypes, DatasetCreator -from guidellm.utils import EndlessTextCreator, IntegerRangeSampler, check_load_processor - -__all__ = [ - "SyntheticDatasetConfig", - "SyntheticDatasetCreator", - "SyntheticTextItemsGenerator", -] - - -class SyntheticDatasetConfig(BaseModel): - prefix_tokens: int = Field( - description="The number of shared prefix tokens to prepend to each prompt.", - ge=0, - default=0, - ) - prompt_tokens: int = Field( - description="The average number of text tokens generated for prompts.", - gt=0, - ) - prompt_tokens_stdev: int | None = Field( - description="The standard deviation of the tokens generated for prompts.", - gt=0, - default=None, - ) - prompt_tokens_min: int | None = Field( - description="The minimum number of text tokens generated for prompts.", - gt=0, - default=None, - ) - prompt_tokens_max: int | None = Field( - description="The maximum number of text tokens generated for prompts.", - gt=0, - default=None, - ) - output_tokens: int = Field( - description="The average number of text tokens generated for outputs.", - gt=0, - ) - output_tokens_stdev: int | None = Field( - description="The standard deviation of the tokens generated for outputs.", - gt=0, - default=None, - ) - output_tokens_min: int | None = Field( - description="The minimum number of text tokens generated for outputs.", - gt=0, - default=None, - ) - output_tokens_max: int | None = Field( - description="The maximum number of text tokens generated for outputs.", - gt=0, - default=None, - ) - samples: int = Field( - description="The number of samples to generate for the dataset.", - gt=0, - default=1000, - ) - source: str = Field( - description="The source of the text data to be used for generation.", - default="data:prideandprejudice.txt.gz", - ) - - @staticmethod - def parse_str(data: str | Path) -> "SyntheticDatasetConfig": - if ( - isinstance(data, Path) - or data.strip().endswith(".config") - or data.strip().endswith(".yaml") - ): - return SyntheticDatasetConfig.parse_config_file(data) - - if data.strip().startswith("{"): - return SyntheticDatasetConfig.parse_json(data) - - if data.count("=") > 1: - return SyntheticDatasetConfig.parse_key_value_pairs(data) - - 
raise ValueError( - f"Unsupported data format. Expected JSON or key-value pairs, got {data}" - ) - - @staticmethod - def parse_json(data: str) -> "SyntheticDatasetConfig": - config_dict = json.loads(data.strip()) - - return SyntheticDatasetConfig(**config_dict) - - @staticmethod - def parse_key_value_pairs(data: str) -> "SyntheticDatasetConfig": - config_dict = {} - items = data.strip().split(",") - for item in items: - key, value = item.split("=") - config_dict[key.strip()] = ( - int(value.strip()) if value.strip().isnumeric() else value.strip() - ) - - return SyntheticDatasetConfig(**config_dict) # type: ignore[arg-type] - - @staticmethod - def parse_config_file(data: str | Path) -> "SyntheticDatasetConfig": - with Path(data).open("r") as file: - config_dict = yaml.safe_load(file) - - return SyntheticDatasetConfig(**config_dict) - - -class SyntheticTextItemsGenerator( - Iterable[ - dict[ - Literal["prompt", "prompt_tokens_count", "output_tokens_count"], - str | int, - ] - ] -): - def __init__( - self, - config: SyntheticDatasetConfig, - processor: PreTrainedTokenizerBase, - random_seed: int, - ): - self.config = config - self.processor = processor - self.random_seed = random_seed - self.text_creator = EndlessTextCreator( - data=config.source, - ) - - def __iter__( - self, - ) -> Iterator[ - dict[ - Literal["prompt", "prompt_tokens_count", "output_tokens_count"], - str | int, - ] - ]: - prompt_tokens_sampler = IntegerRangeSampler( - average=self.config.prompt_tokens, - variance=self.config.prompt_tokens_stdev, - min_value=self.config.prompt_tokens_min, - max_value=self.config.prompt_tokens_max, - random_seed=self.random_seed, - ) - output_tokens_sampler = IntegerRangeSampler( - average=self.config.output_tokens, - variance=self.config.output_tokens_stdev, - min_value=self.config.output_tokens_min, - max_value=self.config.output_tokens_max, - random_seed=self.random_seed + 1, # ensure diff dist from prompts - ) - # ensure diff distribution from output tokens - rand = random.Random(self.random_seed + 2) # noqa: S311 - unique_prefix_iter = cycle(self.processor.get_vocab().values()) - - prefix_index = rand.randint(0, len(self.text_creator.words)) - prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index) - - for _, prompt_tokens, output_tokens in zip( - range(self.config.samples), - prompt_tokens_sampler, - output_tokens_sampler, strict=False, - ): - start_index = rand.randint(0, len(self.text_creator.words)) - prompt_text = self.processor.decode( - prefix_tokens - + self._create_prompt( - prompt_tokens, start_index, next(unique_prefix_iter) - ), - skip_special_tokens=True, - ) - yield { - "prompt": prompt_text, - "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens, - "output_tokens_count": output_tokens, - } - - def _create_prompt( - self, prompt_tokens: int, start_index: int, unique_prefix: int | None = None - ) -> list[int]: - if prompt_tokens <= 0: - return [] - - left = start_index - right = start_index + 4 * prompt_tokens - start_tokens = [unique_prefix] if unique_prefix else [] - - while left < right: - mid = (left + right) // 2 - test_prompt = self.text_creator.create_text(start_index, mid - start_index) - test_tokens = start_tokens + self.processor.encode(test_prompt) - - if len(test_tokens) == prompt_tokens: - return test_tokens - elif len(test_tokens) < prompt_tokens: - left = mid + 1 - else: - right = mid - - final_text = self.text_creator.create_text(start_index, left - start_index) - return start_tokens + self.processor.encode(final_text) - - 
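The deleted SyntheticTextItemsGenerator above sizes each prompt by binary-searching a window of the source text until the encoded length matches the requested token count. Below is a minimal, self-contained sketch of that idea with a toy whitespace tokenizer standing in for the HuggingFace processor (illustrative only, not part of the patch):

def toy_encode(text: str) -> list[str]:
    # Stand-in for processor.encode(): one "token" per whitespace-separated word.
    return text.split()

def build_prompt(words: list[str], start: int, target_tokens: int) -> str:
    # Binary search the window end so the encoded prompt is ~target_tokens long.
    if target_tokens <= 0:
        return ""
    left, right = start, start + 4 * target_tokens
    while left < right:
        mid = (left + right) // 2
        candidate = " ".join(words[start:mid])
        n = len(toy_encode(candidate))
        if n == target_tokens:
            return candidate
        if n < target_tokens:
            left = mid + 1
        else:
            right = mid
    return " ".join(words[start:left])

words = ("pride and prejudice " * 200).split()
prompt = build_prompt(words, start=0, target_tokens=32)
print(len(toy_encode(prompt)))  # 32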
-class SyntheticDatasetCreator(DatasetCreator): - @classmethod - def is_supported( - cls, - data: Any, - data_args: dict[str, Any] | None, # noqa: ARG003 - ) -> bool: - if ( - isinstance(data, Path) - and data.exists() - and data.suffix in {".config", ".yaml"} - ): - return True - - if isinstance(data, str): - data_str: str = data.strip() - if ( - data_str.startswith("{") - or data_str.count("=") > 1 - or data_str.endswith((".config", ".yaml")) - ): - return True - - return False - - @classmethod - def handle_create( - cls, - data: Any, - data_args: dict[str, Any] | None, - processor: str | Path | PreTrainedTokenizerBase | None, - processor_args: dict[str, Any] | None, - random_seed: int, - ) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict: - processor = check_load_processor( - processor, - processor_args, - error_msg=( - "Processor/tokenizer required for synthetic dataset generation." - ), - ) - - config = SyntheticDatasetConfig.parse_str(data) - generator = SyntheticTextItemsGenerator(config, processor, random_seed) - items = list(generator) - - return Dataset.from_list(items, **(data_args or {})) - - @classmethod - def extract_args_column_mappings( - cls, - data_args: dict[str, Any] | None, - ) -> dict[ColumnInputTypes, str]: - data_args_columns = super().extract_args_column_mappings(data_args) - - if data_args_columns: - raise ValueError( - f"Column mappings are not supported for synthetic datasets. " - f"Got {data_args_columns}" - ) - - return { - "prompt_column": "prompt", - "prompt_tokens_count_column": "prompt_tokens_count", - "output_tokens_count_column": "output_tokens_count", - } diff --git a/src/guidellm/preprocess/dataset.py b/src/guidellm/preprocess/dataset.py index b02efec5..cacce3f5 100644 --- a/src/guidellm/preprocess/dataset.py +++ b/src/guidellm/preprocess/dataset.py @@ -11,7 +11,6 @@ from pydantic import BaseModel, Field from transformers import PreTrainedTokenizerBase -from guidellm.dataset import load_dataset as guidellm_load_dataset from guidellm.utils import IntegerRangeSampler, check_load_processor from guidellm.utils.hf_datasets import SUPPORTED_TYPES, save_dataset_to_file @@ -271,9 +270,7 @@ def process_dataset( f"Starting dataset conversion | Input: {data} | Output directory: {output_path}" ) - dataset, column_mappings = guidellm_load_dataset( - data, data_args, processor, processor_args - ) + dataset, column_mappings = None, None tokenizer = check_load_processor( processor, processor_args, diff --git a/src/guidellm/request/__init__.py b/src/guidellm/request/__init__.py deleted file mode 100644 index 85b447d6..00000000 --- a/src/guidellm/request/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from .loader import ( - GenerativeRequestLoader, - GenerativeRequestLoaderDescription, - RequestLoader, - RequestLoaderDescription, -) -from .request import GenerationRequest -from .types import RequestT, ResponseT - -__all__ = [ - "GenerationRequest", - "GenerativeRequestLoader", - "GenerativeRequestLoaderDescription", - "RequestLoader", - "RequestLoaderDescription", - "RequestT", - "ResponseT", -] diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py deleted file mode 100644 index ac34131e..00000000 --- a/src/guidellm/request/loader.py +++ /dev/null @@ -1,275 +0,0 @@ -from abc import abstractmethod -from collections.abc import Iterable, Iterator -from pathlib import Path -from typing import ( - Any, - Literal, -) - -from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict -from transformers import 
PreTrainedTokenizerBase # type: ignore[import] - -from guidellm.dataset import ColumnInputTypes, load_dataset -from guidellm.request.request import GenerationRequest -from guidellm.settings import settings -from guidellm.utils import StandardBaseModel - -__all__ = [ - "GenerativeRequestLoader", - "GenerativeRequestLoaderDescription", - "RequestLoader", - "RequestLoaderDescription", -] - - -class RequestLoaderDescription(StandardBaseModel): - type_: Literal["request_loader"] = "request_loader" - - -class RequestLoader(Iterable): - @abstractmethod - def __iter__(self) -> Iterator: ... - - @abstractmethod - def __len__(self) -> int: ... - - @property - @abstractmethod - def description(self) -> RequestLoaderDescription: ... - - -class GenerativeRequestLoaderDescription(RequestLoaderDescription): - type_: Literal["generative_request_loader"] = "generative_request_loader" # type: ignore[assignment] - data: str - data_args: dict[str, Any] | None - processor: str - processor_args: dict[str, Any] | None - - -class GenerativeRequestLoader(RequestLoader): - DEFAULT_PROMPT_COLUMNS = [ - "prompt", - "prompts", - "instruction", - "instructions", - "question", - "questions", - "input", - "inputs", - "context", - "content", - "conversation", - "conversations", - "turn", - "turns", - "text", - ] - - def __init__( - self, - data: str | Path | Iterable[str | dict[str, Any]] | Dataset | DatasetDict | \ - IterableDataset | IterableDatasetDict, - data_args: dict[str, Any] | None, - processor: str | Path | PreTrainedTokenizerBase | None, - processor_args: dict[str, Any] | None, - shuffle: bool = True, - iter_type: Literal["finite", "infinite"] = "finite", - random_seed: int = 42, - ): - self.data = data - self.data_args = data_args - dataset, args_column_mappings = load_dataset( - data, - data_args, - processor, - processor_args, - random_seed, - ) - self.dataset = dataset - self.processor = processor - self.processor_args = processor_args - self.shuffle = shuffle - self.iter_type = iter_type - self.random_seed = random_seed - - self.column_mappings = self._create_column_mappings(args_column_mappings) - self.preserve_iter_state = iter_type == "infinite" # ensure no caching requests - self._preserved_iter = None - - def __iter__(self) -> Iterator[GenerationRequest]: - scope_create_count = 0 - - while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None: - scope_create_count += 1 - - for item in dataset_iter: - yield self._create_request(item) - - self._preserved_iter = None - - def __len__(self) -> int: - if self.iter_type == "finite": - return self.num_unique_items() - - raise ValueError(f"Unable to determine length of dataset: {self.data}") - - @property - def description(self) -> GenerativeRequestLoaderDescription: - return GenerativeRequestLoaderDescription( - data=str(self.data), - data_args=self.data_args, - processor=str(self.processor), - processor_args=self.processor_args, - ) - - def num_unique_items(self, raise_err: bool = True) -> int: - try: - return len(self.dataset) - except Exception: # noqa: BLE001, S110 - pass - - dataset_size = self.dataset.info.dataset_size - if dataset_size is not None: - return dataset_size - - if raise_err: - raise ValueError("Unable to determine number of items in the dataset") - - return -1 - - def _create_column_mappings( - self, - args_column_mappings: dict[ColumnInputTypes, str], - ) -> dict[ColumnInputTypes, str]: - column_mappings: dict[ColumnInputTypes, str] = {} - - if "text_column" in args_column_mappings: - column_mappings["prompt_column"] = 
args_column_mappings["text_column"] - else: - column_mappings["prompt_column"] = self._extract_text_column() - - if "prompt_tokens_count_column" in args_column_mappings: - column_mappings["prompt_tokens_count_column"] = args_column_mappings[ - "prompt_tokens_count_column" - ] - elif prompt_tokens_count_column := self._extract_prompt_tokens_count_column(): - column_mappings["prompt_tokens_count_column"] = prompt_tokens_count_column - - if "output_tokens_count_column" in args_column_mappings: - column_mappings["output_tokens_count_column"] = args_column_mappings[ - "output_tokens_count_column" - ] - elif output_tokens_count_column := self._extract_output_tokens_count_column(): - column_mappings["output_tokens_count_column"] = output_tokens_count_column - - return column_mappings - - def _extract_text_column(self) -> str: - column_names = self._dataset_columns( - err_msg=( - "Unable to determine text column from dataset and it is required. " - "To specify the text column, set the 'text_column' key in the " - "'data_args' dictionary." - ) - ) - - if not column_names: - raise ValueError( - "Unable to determine text column from dataset and it is required. " - "To specify the text column, set the 'text_column' key in the " - "'data_args' dictionary." - ) - - if len(column_names) == 1: - return column_names[0] - - for def_column in self.DEFAULT_PROMPT_COLUMNS: - if def_column in column_names: - return def_column - - raise ValueError( - f"Unable to determine text column from dataset columns: {column_names}. " - "To specify the text column, set the 'text_column' key in the " - "'data_args' dictionary." - ) - - def _extract_prompt_tokens_count_column(self) -> str | None: - column_names = self._dataset_columns() - - if column_names and "prompt_tokens_count" in column_names: - return "prompt_tokens_count" - - if column_names and "prompt_tokens" in column_names: - return "prompt_tokens" - - return None - - def _extract_output_tokens_count_column(self) -> str | None: - column_names = self._dataset_columns() - - if column_names and "output_tokens_count" in column_names: - return "output_tokens_count" - - if column_names and "output_tokens" in column_names: - return "output_tokens" - - return None - - def _dataset_columns(self, err_msg: str | None = None) -> list[str] | None: - try: - column_names = self.dataset.column_names - - if not column_names and err_msg: - raise ValueError(f"No column names found in dataset: {self.data}") - except Exception as err: - if err_msg: - raise ValueError(err_msg) from err - - column_names = None - - return column_names - - def _get_dataset_iter( - self, scope_create_count: int - ) -> Iterator[dict[str, Any]] | None: - if scope_create_count > 0 and self.iter_type != "infinite": - return None - - if self.preserve_iter_state and self._preserved_iter is not None: - return self._preserved_iter - - dataset = ( - self.dataset - if not self.shuffle - else self.dataset.shuffle(seed=self.random_seed) - ) - - dataset_iter = iter(dataset) - - if self.preserve_iter_state: - self._preserved_iter = dataset_iter - - return dataset_iter - - def _create_request(self, item: dict[str, Any]) -> GenerationRequest: - prompt_tokens = ( - item[self.column_mappings["prompt_tokens_count_column"]] - if "prompt_tokens_count_column" in self.column_mappings - else None - ) - output_tokens = ( - item[self.column_mappings["output_tokens_count_column"]] - if "output_tokens_count_column" in self.column_mappings - else None - ) - - return GenerationRequest( - request_type=settings.preferred_route, - 
content=item[self.column_mappings["prompt_column"]], - stats=( - {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {} - ), - constraints=( - {"output_tokens": output_tokens} if output_tokens is not None else {} - ), - ) diff --git a/src/guidellm/request/request.py b/src/guidellm/request/request.py deleted file mode 100644 index 83dc40f1..00000000 --- a/src/guidellm/request/request.py +++ /dev/null @@ -1,79 +0,0 @@ -import uuid -from typing import Any, Literal - -from pydantic import Field - -from guidellm.utils import StandardBaseModel - -__all__ = ["GenerationRequest"] - - -class GenerationRequest(StandardBaseModel): - """ - A class representing a request for generation. - This class is used to encapsulate the details of a generation request, - including the request ID, type, content, parameters, statistics, and constraints. - It is designed to be used with the BackendRequestsWorker class to handle - the generation process. - - :param request_id: The unique identifier for the request. - :param request_type: The type of request (e.g., text, chat). - :param content: The content for the request to send to the backend. - If request_type is 'text', this should be a string or list of strings - which will be resolved by backend.text_completions. - If request_type is 'chat', this should be a string, - a list of (str, Dict[str, Union[str, Dict[str, str]], Path, Image]), - or Any raw content which will be resolved by backend.chat_completions. - If raw content, raw_content=True must be passed in the params. - :param params: Additional parameters for the request passed in as kwargs. - For an http backend, these are passed into the body of the request. - :param stats: Statistics for the request, such as the number of prompt tokens. - Used for tracking and reporting purposes. - :param constraints: Constraints for the request, such as the maximum number - of output tokens. Used for controlling the behavior of the backend. - """ - - request_id: str | None = Field( - default_factory=lambda: str(uuid.uuid4()), - description="The unique identifier for the request.", - ) - request_type: Literal["text_completions", "chat_completions"] = Field( - default="text_completions", - description=( - "The type of request (e.g., text, chat). " - "If request_type='text_completions', resolved by backend.text_completions. " - "If request_typ='chat_completions', resolved by backend.chat_completions." - ), - ) - content: Any = Field( - description=( - "The content for the request to send to the backend. " - "If request_type is 'text', this should be a string or list of strings " - "which will be resolved by backend.text_completions. " - "If request_type is 'chat', this should be a string, " - "a list of (str, Dict[str, Union[str, Dict[str, str]], Path, Image]), " - "or Any raw content which will be resolved by backend.chat_completions. " - "If raw content, raw_content=True must be passed in the params." - ) - ) - params: dict[str, Any] = Field( - default_factory=dict, - description=( - "Additional parameters for the request that will be passed in as kwargs. " - "For an http backend, these are passed into the body of the request. " - ), - ) - stats: dict[Literal["prompt_tokens"], int] = Field( - default_factory=dict, - description=( - "Statistics for the request, such as the number of prompt tokens. " - "Used for tracking and reporting purposes." 
- ), - ) - constraints: dict[Literal["output_tokens"], int] = Field( - default_factory=dict, - description=( - "Constraints for the request, such as the maximum number of output tokens. " - "Used for controlling the behavior of the backend." - ), - ) diff --git a/src/guidellm/request/types.py b/src/guidellm/request/types.py deleted file mode 100644 index f82493be..00000000 --- a/src/guidellm/request/types.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import TypeVar - -__all__ = [ - "RequestT", - "ResponseT", -] - - -RequestT = TypeVar("RequestT") -ResponseT = TypeVar("ResponseT") diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py index 64647424..2f5eb53f 100644 --- a/src/guidellm/scheduler/__init__.py +++ b/src/guidellm/scheduler/__init__.py @@ -12,21 +12,18 @@ UnserializableConstraintInitializer, ) from .environments import Environment, NonDistributedEnvironment -from .objects import ( +from .scheduler import Scheduler +from .schemas import ( BackendInterface, BackendT, - MeasuredRequestTimings, MultiTurnRequestT, - RequestSchedulerTimings, RequestT, ResponseT, - ScheduledRequestInfo, SchedulerMessagingPydanticRegistry, SchedulerState, SchedulerUpdateAction, SchedulerUpdateActionProgress, ) -from .scheduler import Scheduler from .strategies import ( AsyncConstantStrategy, AsyncPoissonStrategy, @@ -62,16 +59,13 @@ "MaxErrorsConstraint", "MaxGlobalErrorRateConstraint", "MaxNumberConstraint", - "MeasuredRequestTimings", "MultiTurnRequestT", "NoDelayRequestTimings", "NonDistributedEnvironment", "PoissonRateRequestTimings", "PydanticConstraintInitializer", - "RequestSchedulerTimings", "RequestT", "ResponseT", - "ScheduledRequestInfo", "ScheduledRequestTimings", "Scheduler", "SchedulerMessagingPydanticRegistry", diff --git a/src/guidellm/scheduler/constraints.py b/src/guidellm/scheduler/constraints.py index c974225a..2eb24bdb 100644 --- a/src/guidellm/scheduler/constraints.py +++ b/src/guidellm/scheduler/constraints.py @@ -16,12 +16,12 @@ from pydantic import Field, field_validator -from guidellm.scheduler.objects import ( - ScheduledRequestInfo, +from guidellm.scheduler.schemas import ( SchedulerState, SchedulerUpdateAction, SchedulerUpdateActionProgress, ) +from guidellm.schemas import RequestInfo from guidellm.settings import settings from guidellm.utils import InfoMixin, RegistryMixin, StandardBaseModel @@ -46,7 +46,7 @@ class Constraint(Protocol): """Protocol for constraint evaluation functions that control scheduler behavior.""" def __call__( - self, state: SchedulerState, request: ScheduledRequestInfo + self, state: SchedulerState, request: RequestInfo ) -> SchedulerUpdateAction: """ Evaluate constraint against scheduler state and request information. @@ -176,7 +176,7 @@ def serialize(cls, initializer: ConstraintInitializer) -> dict[str, Any]: @classmethod def deserialize( cls, initializer_dict: dict[str, Any] - ) -> SerializableConstraintInitializer: + ) -> SerializableConstraintInitializer | UnserializableConstraintInitializer: """ Deserialize constraint initializer from dictionary format. @@ -370,7 +370,7 @@ def create_constraint( def __call__( self, state: SchedulerState, # noqa: ARG002 - request: ScheduledRequestInfo, # noqa: ARG002 + request: RequestInfo, # noqa: ARG002 ) -> SchedulerUpdateAction: """ Raise error since unserializable constraints cannot be invoked. 
@@ -438,7 +438,7 @@ def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 def __call__( self, state: SchedulerState, - request_info: ScheduledRequestInfo, # noqa: ARG002 + request_info: RequestInfo, # noqa: ARG002 ) -> SchedulerUpdateAction: """ Evaluate constraint against current scheduler state and request count. @@ -556,7 +556,7 @@ def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 def __call__( self, state: SchedulerState, - request_info: ScheduledRequestInfo, # noqa: ARG002 + request_info: RequestInfo, # noqa: ARG002 ) -> SchedulerUpdateAction: """ Evaluate constraint against current scheduler state and elapsed time. @@ -670,7 +670,7 @@ def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 def __call__( self, state: SchedulerState, - request_info: ScheduledRequestInfo, # noqa: ARG002 + request_info: RequestInfo, # noqa: ARG002 ) -> SchedulerUpdateAction: """ Evaluate constraint against current error count. @@ -787,7 +787,7 @@ def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 return self.model_copy() # type: ignore[return-value] def __call__( - self, state: SchedulerState, request_info: ScheduledRequestInfo + self, state: SchedulerState, request_info: RequestInfo ) -> SchedulerUpdateAction: """ Evaluate constraint against sliding window error rate. @@ -928,7 +928,7 @@ def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 def __call__( self, state: SchedulerState, - request_info: ScheduledRequestInfo, # noqa: ARG002 + request_info: RequestInfo, # noqa: ARG002 ) -> SchedulerUpdateAction: """ Evaluate constraint against global error rate. @@ -1007,7 +1007,7 @@ def info(self) -> dict[str, Any]: def __call__( self, state: SchedulerState, - request_info: ScheduledRequestInfo, # noqa: ARG002 + request_info: RequestInfo, # noqa: ARG002 ) -> SchedulerUpdateAction: create_exceeded = state.created_requests >= self.num_requests processed_exceeded = state.processed_requests >= self.num_requests diff --git a/src/guidellm/scheduler/environments.py b/src/guidellm/scheduler/environments.py index 6234f8f6..69997e57 100644 --- a/src/guidellm/scheduler/environments.py +++ b/src/guidellm/scheduler/environments.py @@ -20,19 +20,17 @@ import time from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterable -from typing import ( - Generic, -) +from typing import Generic from guidellm.scheduler.constraints import Constraint -from guidellm.scheduler.objects import ( +from guidellm.scheduler.schemas import ( MultiTurnRequestT, RequestT, ResponseT, - ScheduledRequestInfo, SchedulerState, ) from guidellm.scheduler.strategies import SchedulingStrategy +from guidellm.schemas import RequestInfo from guidellm.settings import settings from guidellm.utils import InfoMixin @@ -93,7 +91,7 @@ async def update_run_iteration( self, response: ResponseT | None, request: RequestT, - request_info: ScheduledRequestInfo, + request_info: RequestInfo, state: SchedulerState, ): """ @@ -129,9 +127,9 @@ async def sync_run_end( self, ) -> AsyncIterator[ tuple[ - ResponseT, - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo, + ResponseT | None, + RequestT, + RequestInfo, SchedulerState, ] ]: @@ -146,10 +144,10 @@ async def sync_run_end( remote nodes in distributed environments, empty for non-distributed :raises Exception: Any errors that occurred during execution """ - ... 
+ yield None # type: ignore[misc] -class NonDistributedEnvironment(Environment): +class NonDistributedEnvironment(Environment[RequestT, ResponseT]): """ Single-node scheduler execution environment with minimal coordination overhead. @@ -162,7 +160,7 @@ class NonDistributedEnvironment(Environment): from guidellm.scheduler import ( MaxNumberConstraint, NonDistributedEnvironment, - ScheduledRequestInfo, + RequestInfo, SchedulerState, SynchronousStrategy, ) @@ -182,7 +180,7 @@ class NonDistributedEnvironment(Environment): for req in local_req: state.processed_requests += 1 await env.update_run_iteration( - f"resp_{req}", req, ScheduledRequestInfo(), state + f"resp_{req}", req, RequestInfo(), state ) async for nonlocal_req in env.sync_run_end(): state.processed_requests += 1 @@ -224,7 +222,7 @@ async def update_run_iteration( self, response: ResponseT | None, request: RequestT, - request_info: ScheduledRequestInfo, + request_info: RequestInfo, state: SchedulerState, ): """ @@ -236,7 +234,7 @@ async def update_run_iteration( :param state: Current scheduler state with metrics and progress """ - async def sync_run_error(self, err: Exception): + async def sync_run_error(self, err: Exception | list[Exception]): """ Store error for later propagation during run finalization. @@ -249,9 +247,9 @@ async def sync_run_end( self, ) -> AsyncIterator[ tuple[ - ResponseT, - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo, + ResponseT | None, + RequestT, + RequestInfo, SchedulerState, ] ]: @@ -269,5 +267,6 @@ async def sync_run_end( f"Errors occurred during execution: {self.run_errors}" ) - return - yield # needed to force generator compilation + if False: + # Force compiler to recognize as generator + yield None # type: ignore[misc] diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py index e7d8b2c6..0e19350b 100644 --- a/src/guidellm/scheduler/scheduler.py +++ b/src/guidellm/scheduler/scheduler.py @@ -13,21 +13,18 @@ from collections.abc import AsyncIterator, Iterable from typing import Any, Generic -from guidellm.scheduler.constraints import ( - Constraint, - ConstraintsInitializerFactory, -) +from guidellm.scheduler.constraints import Constraint, ConstraintsInitializerFactory from guidellm.scheduler.environments import Environment, NonDistributedEnvironment -from guidellm.scheduler.objects import ( +from guidellm.scheduler.schemas import ( BackendInterface, MultiTurnRequestT, RequestT, ResponseT, - ScheduledRequestInfo, SchedulerState, ) from guidellm.scheduler.strategies import SchedulingStrategy from guidellm.scheduler.worker_group import WorkerProcessGroup +from guidellm.schemas import RequestInfo from guidellm.utils.singleton import ThreadSafeSingletonMixin __all__ = ["Scheduler"] @@ -69,13 +66,13 @@ async def run( requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], backend: BackendInterface[RequestT, ResponseT], strategy: SchedulingStrategy, - env: Environment | None, - **constraints: dict[str, Any | dict[str, Any] | Constraint], + env: Environment[RequestT, ResponseT] | None, + **constraints: Any | dict[str, Any] | Constraint, ) -> AsyncIterator[ tuple[ ResponseT | None, RequestT, - ScheduledRequestInfo, + RequestInfo, SchedulerState, ] ]: @@ -104,7 +101,7 @@ async def run( """ with self.thread_lock: if env is None: - env = NonDistributedEnvironment() + env = NonDistributedEnvironment[RequestT, ResponseT]() worker_group: WorkerProcessGroup[RequestT, ResponseT] | None = None @@ -113,18 +110,18 @@ async def run( # and will ensure clean up before 
raising the error. try: # Setup local run parameters, sync with the environment - constraints = ConstraintsInitializerFactory.resolve_constraints( - constraints + resolved_constraints = ( + ConstraintsInitializerFactory.resolve_constraints(constraints) ) ( local_requests, local_strategy, local_constraints, - ) = await env.sync_run_params(requests, strategy, constraints) + ) = await env.sync_run_params(requests, strategy, resolved_constraints) # Setup the worker group, sync start with the environment worker_group = WorkerProcessGroup[RequestT, ResponseT]( - requests=None, + requests=local_requests, cycle_requests=local_requests, backend=backend, strategy=local_strategy, @@ -147,19 +144,20 @@ async def run( yield response, request, request_info, state except Exception as err: # noqa: BLE001 await env.sync_run_error(err) + raise err finally: # Ensure all worker processes are cleaned up for error or completion if worker_group is not None: - err = await worker_group.shutdown() + err = await worker_group.shutdown() # type: ignore[misc] if err is not None: await env.sync_run_error(err) # Ensure any errors are raised and all responses # are yielded for aggregation on the primary node async for ( - response, - request, - request_info, - state, + dist_response, + dist_request, + dist_request_info, + dist_state, ) in env.sync_run_end(): - yield response, request, request_info, state + yield dist_response, dist_request, dist_request_info, dist_state diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/schemas.py similarity index 61% rename from src/guidellm/scheduler/objects.py rename to src/guidellm/scheduler/schemas.py index e2583987..d53b55a1 100644 --- a/src/guidellm/scheduler/objects.py +++ b/src/guidellm/scheduler/schemas.py @@ -10,23 +10,20 @@ from __future__ import annotations import time -import uuid from collections.abc import AsyncIterator from typing import ( Any, - ClassVar, Generic, Literal, Protocol, TypeVar, - runtime_checkable, ) -from pydantic import Field, computed_field +from pydantic import Field from typing_extensions import TypeAliasType, TypedDict +from guidellm.schemas import RequestInfo from guidellm.utils import ( - PydanticClassRegistryMixin, RegistryMixin, StandardBaseModel, ) @@ -35,12 +32,9 @@ __all__ = [ "BackendInterface", "BackendT", - "MeasuredRequestTimings", "MultiTurnRequestT", - "RequestSchedulerTimings", "RequestT", "ResponseT", - "ScheduledRequestInfo", "SchedulerMessagingPydanticRegistry", "SchedulerState", "SchedulerUpdateAction", @@ -68,168 +62,6 @@ class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]): """ -@SchedulerMessagingPydanticRegistry.register() -class RequestSchedulerTimings(StandardBaseModel): - """ - Scheduler-level timing measurements for request lifecycle tracking. - All timestamps are expected to be in Unix time (seconds since epoch). 
- """ - - targeted_start: float | None = Field( - default=None, - description="When the request was initially targeted for execution", - ) - queued: float | None = Field( - default=None, - description="When the request was placed into the processing queue", - ) - dequeued: float | None = Field( - default=None, - description="When the request was removed from the queue for processing", - ) - scheduled_at: float | None = Field( - default=None, description="When the request was scheduled for processing" - ) - resolve_start: float | None = Field( - default=None, description="When backend resolution of the request began" - ) - resolve_end: float | None = Field( - default=None, description="When backend resolution of the request completed" - ) - finalized: float | None = Field( - default=None, - description="When the request was processed/acknowledged by the scheduler", - ) - - -@SchedulerMessagingPydanticRegistry.register() -class MeasuredRequestTimings(PydanticClassRegistryMixin["MeasuredRequestTimings"]): - """ - Base timing measurements for backend request processing. - All timestamps are expected to be in Unix time (seconds since epoch). - """ - - @classmethod - def __pydantic_schema_base_type__(cls) -> type[MeasuredRequestTimings]: - if cls.__name__ == "MeasuredRequestTimings": - return cls - - return MeasuredRequestTimings - - schema_discriminator: ClassVar[str] = "timings_type" - - timings_type: Literal["measured_request_timings"] = Field( - default="measured_request_timings", - description="Type identifier for the timing measurement", - ) - request_start: float | None = Field( - default=None, description="When the backend began processing the request" - ) - request_end: float | None = Field( - default=None, description="When the backend completed processing the request" - ) - - -@SchedulerMessagingPydanticRegistry.register() -class ScheduledRequestInfo(StandardBaseModel): - """ - Complete request information including status, timings, and metadata. - - Central data structure for tracking request lifecycle from creation through - completion, containing scheduling metadata, timing measurements, and processing - status. Used by scheduler components to coordinate request processing across - distributed worker processes. 
- - Example: - :: - from guidellm.scheduler.objects import ScheduledRequestInfo - - # Create request info with automatic ID generation - request_info = ScheduledRequestInfo() - request_info.status = "in_progress" - request_info.scheduler_timings.queued = time.time() - - # Check processing completion - if request_info.completed_at: - duration = request_info.completed_at - request_info.started_at - """ - - request_id: str = Field( - description="Unique identifier for the request", - default_factory=lambda: str(uuid.uuid4()), - ) - status: Literal[ - "queued", "pending", "in_progress", "completed", "errored", "cancelled" - ] = Field(description="Current processing status of the request", default="queued") - scheduler_node_id: int = Field( - description="ID/rank of the scheduler node handling the request", - default=-1, - ) - scheduler_process_id: int = Field( - description="ID/rank of the node's scheduler process handling the request", - default=-1, - ) - scheduler_start_time: float = Field( - description="Unix timestamp for the local time when scheduler processing began", - default=-1.0, - ) - - error: str | None = Field( - default=None, description="Error message if the request.status is 'errored'" - ) - scheduler_timings: RequestSchedulerTimings = Field( - default_factory=RequestSchedulerTimings, - description="Scheduler-level timing measurements for request lifecycle", - ) - request_timings: MeasuredRequestTimings | None = Field( - default=None, - description="Backend-specific timing measurements for request processing", - ) - - @computed_field # type: ignore[misc] - @property - def started_at(self) -> float | None: - """ - Get the effective request processing start time. - - :return: Unix timestamp when processing began, or None if not started. - """ - request_start = ( - self.request_timings.request_start if self.request_timings else None - ) - - return request_start or self.scheduler_timings.resolve_start - - @computed_field # type: ignore[misc] - @property - def completed_at(self) -> float | None: - """ - Get the effective request processing completion time. - - :return: Unix timestamp when processing completed, or None if not completed. - """ - request_end = self.request_timings.request_end if self.request_timings else None - - return request_end or self.scheduler_timings.resolve_end - - def model_copy(self, **kwargs) -> ScheduledRequestInfo: # type: ignore[override] # noqa: ARG002 - """ - Create a deep copy of the request info with copied timing objects. - - :return: New ScheduledRequestInfo instance with independent timing objects - """ - return super().model_copy( - update={ - "scheduler_timings": self.scheduler_timings.model_copy(), - "request_timings": ( - self.request_timings.model_copy() if self.request_timings else None - ), - }, - deep=False, - ) - - -@runtime_checkable class BackendInterface(Protocol, Generic[RequestT, ResponseT]): """ Abstract interface for request processing backends. @@ -295,9 +127,9 @@ async def process_shutdown(self) -> None: async def resolve( self, request: RequestT, - request_info: ScheduledRequestInfo, + request_info: RequestInfo, history: list[tuple[RequestT, ResponseT]] | None = None, - ) -> AsyncIterator[tuple[ResponseT, ScheduledRequestInfo]]: + ) -> AsyncIterator[tuple[ResponseT, RequestInfo]]: """ Process a request and yield incremental response updates. 
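The renamed scheduler schemas keep BackendInterface.resolve as an async generator yielding (response, request_info) pairs. The sketch below shows that contract with local dataclasses as stand-ins for guidellm.schemas.RequestInfo/RequestTimings so it runs on its own; the field names mirror the ones used in this diff (illustrative only, not part of the patch):

import asyncio
import time
from dataclasses import dataclass, field

@dataclass
class ToyTimings:
    request_start: float | None = None
    request_end: float | None = None

@dataclass
class ToyRequestInfo:
    request_id: str = "demo"
    status: str = "queued"
    timings: ToyTimings = field(default_factory=ToyTimings)

class EchoBackend:
    # Minimal backend satisfying the yield contract of BackendInterface.resolve.
    async def resolve(self, request, request_info, history=None):
        request_info.timings.request_start = time.time()
        await asyncio.sleep(0)  # stand-in for real backend I/O
        request_info.timings.request_end = time.time()
        yield f"echo:{request}", request_info

async def main():
    info = ToyRequestInfo()
    async for response, info in EchoBackend().resolve("hello", info):
        print(response, info.timings.request_end - info.timings.request_start)

asyncio.run(main())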
diff --git a/src/guidellm/scheduler/strategies.py b/src/guidellm/scheduler/strategies.py index 8c791671..d3e31d43 100644 --- a/src/guidellm/scheduler/strategies.py +++ b/src/guidellm/scheduler/strategies.py @@ -17,7 +17,7 @@ from pydantic import Field, PrivateAttr -from guidellm.scheduler.objects import ScheduledRequestInfo +from guidellm.schemas import RequestInfo from guidellm.utils import InfoMixin, PydanticClassRegistryMixin, StandardBaseModel __all__ = [ @@ -83,7 +83,7 @@ def next_offset(self) -> float: """ @abstractmethod - def request_completed(self, request_info: ScheduledRequestInfo): + def request_completed(self, request_info: RequestInfo): """ Handle request completion and update internal timing state. @@ -129,7 +129,7 @@ def next_offset(self) -> float: return self.offset - def request_completed(self, request_info: ScheduledRequestInfo): + def request_completed(self, request_info: RequestInfo): """ Update timing state based on the completed request. @@ -197,7 +197,7 @@ def next_offset(self) -> float: return self.offset + startup_percent * self.startup_duration - def request_completed(self, request_info: ScheduledRequestInfo): + def request_completed(self, request_info: RequestInfo): """ Handle request completion (no action needed for throughput strategy). @@ -236,7 +236,7 @@ def next_offset(self) -> float: return self.offset + interval * num_requests - def request_completed(self, request_info: ScheduledRequestInfo): + def request_completed(self, request_info: RequestInfo): """ Handle request completion (no action needed for constant rate strategy). @@ -283,7 +283,7 @@ def next_offset(self) -> float: return self.offset - def request_completed(self, request_info: ScheduledRequestInfo): + def request_completed(self, request_info: RequestInfo): """ Handle request completion (no action needed for Poisson rate strategy). @@ -331,7 +331,7 @@ def requests_limit(self) -> int | None: return None def create_request_timings( - self, local_rank: int, local_world_size: int, local_max_concurrency: int + self, local_rank: int, local_world_size: int, local_max_concurrency: int | float ) -> ScheduledRequestTimings: """ Create a timing instance to define scheduling behavior for a worker process. 
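The strategy timing classes above expose next_offset() plus a request_completed(request_info) hook, and the constant-rate variant computes offsets as offset + interval * number_of_requests. A toy, self-contained version of that offset math (illustrative only, not part of the patch):

class ToyConstantRateTimings:
    def __init__(self, rate_per_second: float, offset: float = 0.0):
        self.interval = 1.0 / rate_per_second
        self.offset = offset
        self.scheduled = 0

    def next_offset(self) -> float:
        # Each request is targeted `interval` seconds after the previous one.
        offset = self.offset + self.interval * self.scheduled
        self.scheduled += 1
        return offset

    def request_completed(self, request_info) -> None:
        # Constant-rate scheduling is open-loop: completions are ignored.
        pass

timings = ToyConstantRateTimings(rate_per_second=4.0)
print([round(timings.next_offset(), 2) for _ in range(5)])  # [0.0, 0.25, 0.5, 0.75, 1.0]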
diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 104ab418..45716b78 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -29,15 +29,14 @@ ] = False -from guidellm.scheduler.objects import ( +from guidellm.scheduler.schemas import ( BackendInterface, MultiTurnRequestT, RequestT, ResponseT, - ScheduledRequestInfo, - SchedulerMessagingPydanticRegistry, ) from guidellm.scheduler.strategies import ScheduledRequestTimings +from guidellm.schemas import RequestInfo from guidellm.utils import ( InterProcessMessaging, wait_for_sync_barrier, @@ -77,7 +76,7 @@ def __init__( tuple[ ResponseT | None, RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo, + RequestInfo, ], ], backend: BackendInterface[RequestT, ResponseT], @@ -235,8 +234,7 @@ async def _processing_startup(self): # Get messaging system ready await self.messaging.start( - receive_stop_criteria=[self.requests_generated_event], - pydantic_models=list(SchedulerMessagingPydanticRegistry.registry.values()), + receive_stop_criteria=[self.requests_generated_event] ) self.messaging_started = True @@ -289,56 +287,59 @@ async def _cancel_requests_loop(self): while True: try: request: RequestT - request_info: ScheduledRequestInfo + request_info: RequestInfo request, request_info = await self.messaging.get( timeout=self.messaging.poll_interval ) except asyncio.TimeoutError: continue - request_info.scheduler_node_id = self.messaging.worker_index + request_info.scheduler_node_id = self.messaging.worker_index or -1 request_info.error = "Request was cancelled" - request_info.scheduler_timings.resolve_end = time.time() + request_info.timings.resolve_end = time.time() self._send_update("cancelled", None, request, request_info) async def _process_next_request(self): request: RequestT | MultiTurnRequestT[RequestT] | None = None - request_info: ScheduledRequestInfo | None = None + request_info: RequestInfo | None = None response: ResponseT | None = None try: # Pull request from the queue request, request_info = await self.messaging.get() - if isinstance(request, list | tuple): + if request is None or request_info is None: + raise RuntimeError("Received invalid request or request info") + + if isinstance(request, (list, tuple)): raise NotImplementedError("Multi-turn requests are not yet supported") # Calculate targeted start and set pending state for request - request_info.scheduler_node_id = self.messaging.worker_index - request_info.scheduler_timings.dequeued = time.time() + request_info.scheduler_node_id = self.messaging.worker_index or -1 + request_info.timings.dequeued = time.time() target_start = ( request_info.scheduler_start_time + self.request_timings.next_offset() ) - request_info.scheduler_timings.targeted_start = target_start + request_info.timings.targeted_start = target_start self._send_update("pending", response, request, request_info) # Schedule the request current_time = time.time() - request_info.scheduler_timings.scheduled_at = current_time + request_info.timings.scheduled_at = current_time if target_start > current_time: await asyncio.sleep(target_start - current_time) # Adapt delay so that scheduled at reflects the sleep time - request_info.scheduler_timings.scheduled_at = target_start + request_info.timings.scheduled_at = target_start # Process the request with the backend - request_info.scheduler_timings.resolve_start = time.time() + request_info.timings.resolve_start = time.time() self._send_update("in_progress", response, request, request_info) async for 
resp, info in self.backend.resolve(request, request_info, None): response = resp request_info = info # Complete the request - request_info.scheduler_timings.resolve_end = time.time() + request_info.timings.resolve_end = time.time() self._send_update("completed", response, request, request_info) response = request = request_info = None @@ -346,13 +347,13 @@ async def _process_next_request(self): # Handle cancellation if request is not None and request_info is not None: request_info.error = "Request was cancelled" - request_info.scheduler_timings.resolve_end = time.time() + request_info.timings.resolve_end = time.time() self._send_update("cancelled", response, request, request_info) raise except Exception as exc: # noqa: BLE001 if request is not None and request_info is not None: request_info.error = str(exc) - request_info.scheduler_timings.resolve_end = time.time() + request_info.timings.resolve_end = time.time() self._send_update("errored", response, request, request_info) def _send_update( @@ -362,7 +363,7 @@ def _send_update( ], response: ResponseT | None, request: RequestT | MultiTurnRequestT[RequestT], - request_info: ScheduledRequestInfo, + request_info: RequestInfo, ): prev_status = request_info.status diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py index c1d516f1..21394668 100644 --- a/src/guidellm/scheduler/worker_group.py +++ b/src/guidellm/scheduler/worker_group.py @@ -22,19 +22,19 @@ from multiprocessing.synchronize import Barrier, Event from typing import Generic, NamedTuple +from guidellm.logger import logger from guidellm.scheduler.constraints import Constraint, RequestsExhaustedConstraint -from guidellm.scheduler.objects import ( +from guidellm.scheduler.schemas import ( BackendInterface, MultiTurnRequestT, RequestT, ResponseT, - ScheduledRequestInfo, - SchedulerMessagingPydanticRegistry, SchedulerState, SchedulerUpdateAction, ) from guidellm.scheduler.strategies import SchedulingStrategy from guidellm.scheduler.worker import WorkerProcess +from guidellm.schemas import RequestInfo from guidellm.settings import settings from guidellm.utils import ( InterProcessMessaging, @@ -98,10 +98,9 @@ def __init__( :raises ValueError: If neither requests nor cycle_requests are provided, or if cycle_requests is an Iterator rather than Iterable """ - if not requests and not cycle_requests: + if requests is None and cycle_requests is None: raise ValueError( "At least one of 'requests' or 'cycle_requests' must be provided. 
" - f"Got requests: {requests}, cycle_requests: {cycle_requests}" ) if isinstance(cycle_requests, Iterator): @@ -117,29 +116,32 @@ def __init__( self.constraints = constraints # Multiprocessing contexts and primitives, created in create_processes - self.mp_context: BaseContext = None - self.mp_manager: BaseManager = None - self.processes: list[BaseProcess] = None - self.startup_barrier: Barrier = None - self.requests_generated_event: Event = None - self.constraint_reached_event: Event = None - self.shutdown_event: Event = None - self.error_event: Event = None + self.mp_context: BaseContext | None = None + self.mp_manager: BaseManager | None = None + self.processes: list[BaseProcess] | None = None + self.startup_barrier: Barrier | None = None + self.requests_generated_event: Event | None = None + self.constraint_reached_event: Event | None = None + self.shutdown_event: Event | None = None + self.error_event: Event | None = None # Scheduler and messaging state, created in start - self.state: WorkerGroupState[ResponseT, RequestT] = None - self.messaging: InterProcessMessaging[ - tuple[ - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo, - ], - tuple[ - ResponseT | None, - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo, - SchedulerState, - ], - ] = None + self.state: WorkerGroupState[RequestT, ResponseT] | None = None + self.messaging: ( + InterProcessMessaging[ + tuple[ + RequestT | MultiTurnRequestT[RequestT], + RequestInfo, + ], + tuple[ + ResponseT | None, + RequestT | MultiTurnRequestT[RequestT], + RequestInfo, + SchedulerState, + ], + ] + | None + ) = None async def create_processes(self): """ @@ -153,19 +155,23 @@ async def create_processes(self): :raises RuntimeError: If process initialization or startup fails """ # Processes limits and params - max_conc: int = min( - self.strategy.requests_limit or math.inf, - self.backend.requests_limit or math.inf, - ) - if max_conc == math.inf: - # if concurrency not specified, use settings + max_conc: int + if ( + requests_limit := min( + self.strategy.requests_limit or math.inf, + self.backend.requests_limit or math.inf, + ) + ) != math.inf: + max_conc = int(requests_limit) + else: + # If concurrency not specified, use settings max_conc = settings.max_concurrency if max_conc <= 0: raise RuntimeError("max_concurrency resolved to 0; increase limits/config") # Calculate number of processes, ensure we don't exceed the max concurrency, # or limits from the backend, strategy, or user settings - num_processes = int( + num_processes: int = int( min( max_conc, self.strategy.processes_limit or math.inf, @@ -185,7 +191,7 @@ async def create_processes(self): ) # Initialize multiprocessing components - self.mp_context: BaseContext = get_context(settings.mp_context_type) + self.mp_context = get_context(settings.mp_context_type) self.mp_manager = self.mp_context.Manager() self.startup_barrier = self.mp_context.Barrier(num_processes + 1) self.requests_generated_event = self.mp_context.Event() @@ -280,7 +286,14 @@ async def start(self, start_time: float): :raises RuntimeError: If workers encounter errors during startup or if create_processes() was not called first """ - if not self.processes: + if ( + not self.processes + or not self.requests_generated_event + or not self.constraint_reached_event + or not self.shutdown_event + or not self.error_event + or not self.messaging + ): raise RuntimeError("create_processes() must be called before start()") stop_send_requests_event = threading.Event() @@ -304,7 +317,6 @@ async def start(self, 
start_time: float): send_stopped_event=send_requests_stopped_event, send_stop_criteria=[stop_send_requests_event], receive_stop_criteria=[self.shutdown_event], - pydantic_models=list(SchedulerMessagingPydanticRegistry.registry.values()), ) if (wait_time := start_time - time.time()) > 0: @@ -320,8 +332,8 @@ async def request_updates( ) -> AsyncIterator[ tuple[ ResponseT | None, - RequestT, - ScheduledRequestInfo, + RequestT | MultiTurnRequestT[RequestT], + RequestInfo, SchedulerState, ] ]: @@ -337,7 +349,8 @@ async def request_updates( :raises RuntimeError: If workers encounter unrecoverable errors """ while True: - if self.error_event.is_set(): + if self.error_event.is_set(): # type: ignore[union-attr] + logger.error("Error event set in WorkerProcessGroup") raise RuntimeError( "error_event is set in WorkerProcessGroup, " "indicating an error occurred in one of the worker processes." @@ -349,11 +362,11 @@ async def request_updates( request, request_info, scheduler_state, - ) = await self.messaging.get(timeout=settings.mp_poll_interval) + ) = await self.messaging.get(timeout=settings.mp_poll_interval) # type: ignore[union-attr] yield response, request, request_info, scheduler_state except asyncio.TimeoutError: - if self.shutdown_event.is_set(): + if self.shutdown_event.is_set(): # type: ignore[union-attr] # Everything yielded, exit break @@ -465,15 +478,17 @@ def __init__( num_processes=len(processes), start_time=start_time, ) - self._queued_requests = set() - self._pending_requests = set() - self._processing_requests = set() + self._queued_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set() + self._pending_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set() + self._processing_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set() def requests_generator( self, requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, - ) -> Generator[tuple[RequestT | MultiTurnRequestT[RequestT],], None, None]: + ) -> Generator[ + tuple[RequestT | MultiTurnRequestT[RequestT], RequestInfo], None, None + ]: """ Generate request-info pairs for worker processing with constraint evaluation. 
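The reworked create_processes() above resolves max concurrency as the tighter of the strategy and backend limits, falling back to settings.max_concurrency when neither is set. A pure-function sketch of that resolution, where the settings_max_concurrency parameter is a hypothetical stand-in for the settings value (illustrative only, not part of the patch):

import math

def resolve_max_concurrency(
    strategy_limit: int | None,
    backend_limit: int | None,
    settings_max_concurrency: int,
) -> int:
    # Take the tighter of the two limits; unset limits count as unbounded.
    limit = min(strategy_limit or math.inf, backend_limit or math.inf)
    if limit != math.inf:
        return int(limit)
    return settings_max_concurrency

print(resolve_max_concurrency(None, 64, 512))    # 64 (bounded by the backend)
print(resolve_max_concurrency(None, None, 512))  # 512 (settings fallback)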
@@ -486,56 +501,67 @@ def requests_generator( :return: Generator yielding (request, request_info) tuples """ - def _iter(): - if requests: + def _iter() -> Iterator[RequestT | MultiTurnRequestT[RequestT]]: + if requests is not None: yield from requests - if cycle_requests: + if cycle_requests is not None: while True: yield from cycle_requests - count = 0 - request_info: ScheduledRequestInfo = None - for request in _iter(): - count += 1 - - if hasattr(request, "request_id"): - request_id = request.request_id - elif hasattr(request, "id"): - request_id = request.id - else: - request_id = str(uuid.uuid4()) - request_info: ScheduledRequestInfo = ScheduledRequestInfo( - request_id=request_id, - status="queued", - scheduler_process_id=0, - scheduler_start_time=self.start_time, + try: + count = 0 + request_iter = _iter() + for request in request_iter: + count += 1 + + if hasattr(request, "request_id"): + request_id = request.request_id + elif hasattr(request, "id"): + request_id = request.id + else: + request_id = str(uuid.uuid4()) + request_info: RequestInfo = RequestInfo( + request_id=request_id, + status="queued", + scheduler_process_id=0, + scheduler_start_time=self.start_time, + ) + state_update = self._locked_update(request_info) + request_info.timings.queued = time.time() + + yield (request, request_info) + + if state_update.stop_queueing: + self.stop_send_requests_event.set() + return + + # Reached the end, inject a RequestsExhaustedConstraint to record + self._locked_update( + info=None, + requests_exhausted={ + "requests_exhausted": RequestsExhaustedConstraint( + num_requests=count + ) + }, ) - state_update = self._locked_update(request_info) - yield (request, request_info) - - if state_update.stop_queueing: - self.stop_send_requests_event.set() - return - - # Reached the end, inject a RequestsExhaustedConstraint to record - self._locked_update( - info=None, - requests_exhausted=RequestsExhaustedConstraint(num_requests=count), - ) - self.stop_send_requests_event.set() + self.stop_send_requests_event.set() + except Exception as err: + logger.error(f"Error generating requests: {err}") + self.error_event.set() + raise err def received_callback( self, update: tuple[ ResponseT | None, RequestT | MultiTurnRequestT, - ScheduledRequestInfo, + RequestInfo, ], ) -> tuple[ ResponseT | None, RequestT | MultiTurnRequestT, - ScheduledRequestInfo, + RequestInfo, SchedulerState, ]: """ @@ -548,31 +574,40 @@ def received_callback( :param update: Tuple containing response, request, and request info :return: Updated tuple with injected scheduler state """ - response, request, request_info = update - state_update = self._locked_update(info=request_info) + try: + response, request, request_info = update + state_update = self._locked_update(info=request_info) - # Check if we need to tell workers to stop pulling new requests - # based on no more requests sent and all requests removed from queue - if ( - state_update.state.queued_requests == 0 - and self.send_requests_stopped_event.is_set() - and not self.requests_generated_event.is_set() - ): - self.requests_generated_event.set() + # Check if we need to tell workers to stop pulling new requests + # based on no more requests sent and all requests removed from queue + if ( + state_update.state.queued_requests == 0 + and self.stop_send_requests_event.is_set() + and not self.requests_generated_event.is_set() + ): + self.requests_generated_event.set() - # Check if we need to tell workers to stop processing requests (constraints) - if state_update.stop_processing 
and not self.constraint_reached_event.is_set(): - self.constraint_reached_event.set() + # Check if we need to tell workers to stop processing requests (constraints) + if ( + state_update.stop_processing + and not self.constraint_reached_event.is_set() + ): + self.constraint_reached_event.set() - # Check if all requests have been processed and can shutdown - if ( - state_update.state.processed_requests == state_update.state.created_requests - and self.send_requests_stopped_event.is_set() - and self.requests_generated_event.is_set() - and self.constraint_reached_event.is_set() - and not self.shutdown_event.is_set() - ): - self.shutdown_event.set() + # Check if all requests have been processed and can shutdown + if ( + state_update.state.processed_requests + == state_update.state.created_requests + and self.stop_send_requests_event.is_set() + and self.requests_generated_event.is_set() + and self.constraint_reached_event.is_set() + and not self.shutdown_event.is_set() + ): + self.shutdown_event.set() + except Exception as err: + logger.error(f"Error processing received update: {err}") + self.error_event.set() + raise err return ( response, @@ -583,7 +618,7 @@ def received_callback( def _locked_update( self, - info: ScheduledRequestInfo | None = None, + info: RequestInfo | None = None, **add_constraints: dict[str, Constraint], ) -> _StateUpdate: with self._update_lock: @@ -603,7 +638,7 @@ def _locked_update( state_copy.end_processing_time is not None, ) - def _update_state_request_counts(self, info: ScheduledRequestInfo): + def _update_state_request_counts(self, info: RequestInfo): if info.status == "queued": self._queued_requests.add(info.request_id) self._state.queued_requests = len(self._queued_requests) @@ -640,7 +675,7 @@ def _update_state_request_counts(self, info: ScheduledRequestInfo): else: raise ValueError(f"Unknown request_info status {info.status} for {info}") - def _update_with_constraints(self, info: ScheduledRequestInfo): + def _update_with_constraints(self, info: RequestInfo): actions: dict[str, SchedulerUpdateAction] = { name: const(self._state, info) for name, const in self.constraints.items() } diff --git a/src/guidellm/schemas/__init__.py b/src/guidellm/schemas/__init__.py new file mode 100644 index 00000000..42268f72 --- /dev/null +++ b/src/guidellm/schemas/__init__.py @@ -0,0 +1,31 @@ +""" +Pydantic schema models for GuideLLM operations. + +Provides standardized data models and type definitions for generation requests, +responses, timing measurements, and statistics aggregation. These schemas ensure +type safety and consistent data handling across the benchmarking pipeline, +from request submission through backend processing to results compilation. +""" + +from __future__ import annotations + +from .info import RequestInfo, RequestTimings +from .request import ( + GenerationRequest, + GenerationRequestArguments, + GenerativeRequestType, + UsageMetrics, +) +from .response import GenerationResponse +from .stats import GenerativeRequestStats + +__all__ = [ + "GenerationRequest", + "GenerationRequestArguments", + "GenerationResponse", + "GenerativeRequestStats", + "GenerativeRequestType", + "RequestInfo", + "RequestTimings", + "UsageMetrics", +] diff --git a/src/guidellm/schemas/info.py b/src/guidellm/schemas/info.py new file mode 100644 index 00000000..4b5d188c --- /dev/null +++ b/src/guidellm/schemas/info.py @@ -0,0 +1,159 @@ +""" +Core data structures and interfaces for the GuideLLM scheduler system. 
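# Usage sketch for the schemas package re-exports above (requires the patched
# guidellm package to be installed; the GenerationRequest construction mirrors the
# example given later in request.py):
from guidellm.schemas import (
    GenerationRequest,
    GenerationRequestArguments,
    RequestInfo,
)

request = GenerationRequest(
    request_type="text_completions",
    arguments=GenerationRequestArguments(method="POST", body={"prompt": "Hello"}),
)
info = RequestInfo()  # request_id defaults to a fresh UUID, status to "queued"
print(request.request_id, info.status)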
+ +Provides type-safe abstractions for distributed request processing, timing +measurements, and backend interfaces for benchmarking operations. Central to +the scheduler architecture, enabling request lifecycle tracking, backend +coordination, and state management across distributed worker processes. +""" + +from __future__ import annotations + +import uuid +from typing import Literal + +from pydantic import Field, computed_field + +from guidellm.utils import StandardBaseDict, StandardBaseModel + +__all__ = ["RequestInfo", "RequestTimings"] + + +class RequestTimings(StandardBaseDict): + """ + Timing measurements for tracking request lifecycle events. + + Provides comprehensive timing data for distributed request processing, capturing + key timestamps from initial targeting through final completion. Essential for + performance analysis, SLA monitoring, and debugging request processing bottlenecks + across scheduler workers and backend systems. + """ + + targeted_start: float | None = Field( + default=None, + description="Unix timestamp when request was initially targeted for execution", + ) + queued: float | None = Field( + default=None, + description="Unix timestamp when request was placed into processing queue", + ) + dequeued: float | None = Field( + default=None, + description="Unix timestamp when request was removed from queue for processing", + ) + scheduled_at: float | None = Field( + default=None, + description="Unix timestamp when the request was scheduled for processing", + ) + resolve_start: float | None = Field( + default=None, + description="Unix timestamp when backend resolution of the request began", + ) + request_start: float | None = Field( + default=None, + description="Unix timestamp when the backend began processing the request", + ) + first_iteration: float | None = Field( + default=None, + description="Unix timestamp when the first iteration for a streaming began", + ) + last_iteration: float | None = Field( + default=None, + description="Unix timestamp when the last iteration for a streaming completed", + ) + iterations: int | None = Field( + default=None, + description="Total number of streaming update iterations performed", + ) + request_end: float | None = Field( + default=None, + description="Unix timestamp when the backend completed processing the request", + ) + resolve_end: float | None = Field( + default=None, + description="Unix timestamp when backend resolution of the request completed", + ) + finalized: float | None = Field( + default=None, + description="Unix timestamp when request was processed by the scheduler", + ) + + +class RequestInfo(StandardBaseModel): + """ + Complete information about a request in the scheduler system. + + Encapsulates all metadata, status tracking, and timing information for requests + processed through the distributed scheduler. Provides comprehensive lifecycle + tracking from initial queuing through final completion, including error handling + and node identification for debugging and performance analysis. 
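# Standalone sketch of the optional-timestamp pattern above (assumed field subset,
# not the RequestTimings class itself): timestamps are filled in as the request
# advances, and derived durations are only computed once both endpoints exist.
import time

from pydantic import BaseModel


class LifecycleTimingsSketch(BaseModel):
    queued: float | None = None
    request_start: float | None = None
    request_end: float | None = None


timings = LifecycleTimingsSketch()
timings.queued = time.time()
timings.request_start = time.time()
timings.request_end = time.time()

latency = (
    timings.request_end - timings.request_start
    if timings.request_start is not None and timings.request_end is not None
    else None
)
print(f"request latency: {latency:.6f}s" if latency is not None else "not finished")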
+ + Example: + :: + request = RequestInfo() + request.status = "in_progress" + start_time = request.started_at + completion_time = request.completed_at + """ + + request_id: str = Field( + description="Unique identifier for the request", + default_factory=lambda: str(uuid.uuid4()), + ) + status: Literal[ + "queued", "pending", "in_progress", "completed", "errored", "cancelled" + ] = Field(description="Current processing status of the request", default="queued") + scheduler_node_id: int = Field( + description="ID/rank of the scheduler node handling the request", + default=-1, + ) + scheduler_process_id: int = Field( + description="ID/rank of the node's scheduler process handling the request", + default=-1, + ) + scheduler_start_time: float = Field( + description="Unix timestamp when scheduler processing began", + default=-1, + ) + timings: RequestTimings = Field( + default_factory=RequestTimings, + description="Timing measurements for the request lifecycle", + ) + + error: str | None = Field( + default=None, description="Error message if the request status is 'errored'" + ) + + @computed_field # type: ignore[misc] + @property + def started_at(self) -> float | None: + """ + Get the effective request processing start time. + + :return: Unix timestamp when processing began, or None if not started + """ + return self.timings.request_start or self.timings.resolve_start + + @computed_field # type: ignore[misc] + @property + def completed_at(self) -> float | None: + """ + Get the effective request processing completion time. + + :return: Unix timestamp when processing completed, or None if not completed + """ + return self.timings.request_end or self.timings.resolve_end + + def model_copy(self, **_kwargs) -> RequestInfo: # type: ignore[override] # noqa: ARG002 + """ + Create a deep copy of the request info with copied timing objects. + + :param kwargs: Additional keyword arguments for model copying + :return: New RequestInfo instance with independent timing objects + """ + return super().model_copy( + update={ + "timings": self.timings.model_copy(), + }, + deep=False, + ) diff --git a/src/guidellm/schemas/request.py b/src/guidellm/schemas/request.py new file mode 100644 index 00000000..9e9189fc --- /dev/null +++ b/src/guidellm/schemas/request.py @@ -0,0 +1,216 @@ +""" +Request schema definitions for generation operations. + +Contains request models and data structures used to define and execute generation +requests across different backend services. Provides standardized interfaces for +request arguments, usage metrics tracking, and request type definitions that enable +consistent interaction with various AI generation APIs. +""" + +from __future__ import annotations + +import uuid +from typing import Any, Literal + +from pydantic import Field, computed_field + +from guidellm.utils import StandardBaseDict, StandardBaseModel + +__all__ = [ + "GenerationRequest", + "GenerationRequestArguments", + "GenerativeRequestType", + "UsageMetrics", +] + + +GenerativeRequestType = Literal[ + "text_completions", + "chat_completions", + "audio_transcriptions", + "audio_translations", +] + + +class GenerationRequestArguments(StandardBaseDict): + """ + HTTP request arguments for generation operations. + + Encapsulates all necessary HTTP request components including method, headers, + parameters, and payload data required to execute generation requests against + backend services. Supports file uploads and streaming responses. 
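# Minimal sketch (assumed pydantic v2 semantics) of the model_copy override above:
# duplicating the nested timings object keeps copies from sharing mutable timing state.
from pydantic import BaseModel, Field


class TimingsSketch(BaseModel):
    queued: float | None = None


class InfoSketch(BaseModel):
    timings: TimingsSketch = Field(default_factory=TimingsSketch)

    def model_copy(self, **_kwargs) -> "InfoSketch":
        # Shallow copy of the model, but with a fresh timings instance swapped in.
        return super().model_copy(
            update={"timings": self.timings.model_copy()}, deep=False
        )


original = InfoSketch()
clone = original.model_copy()
clone.timings.queued = 123.0
assert original.timings.queued is None  # the clone's timings are independent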
+ """ + + method: str | None = Field( + default=None, + description="The HTTP method to use for the request (e.g., 'POST', 'GET').", + ) + stream: bool | None = Field( + default=None, + description="Whether to stream the response, if applicable.", + ) + headers: dict[str, str] | None = Field( + default=None, + description="Any headers to include in the request, if applicable.", + ) + params: dict[str, Any] | None = Field( + default=None, + description="Query parameters to include in the request, if applicable.", + ) + body: dict[str, Any] | None = Field( + default=None, + description="Content to include in the main request body.", + ) + files: dict[str, Any] | None = Field( + default=None, + description="Files to include in the request, if applicable.", + ) + + def model_combine( + self, additional: GenerationRequestArguments | dict[str, Any] + ) -> GenerationRequestArguments: + """ + Merge additional request arguments into the current instance. + + Combines method and stream fields by overwriting, while merging collection + fields like headers, params, json_body, and files by extending existing values. + + :param additional: Additional arguments to merge with current instance + :return: Updated instance with merged arguments + """ + additional_dict = ( + additional.model_dump() + if isinstance(additional, GenerationRequestArguments) + else additional + ) + + for overwrite in ("method", "stream"): + if (val := additional_dict.get(overwrite)) is not None: + setattr(self, overwrite, val) + + for combine in ("headers", "params", "json_body", "files"): + if (val := additional_dict.get(combine)) is not None: + setattr(self, combine, {**getattr(self, combine, {}), **val}) + + return self + + +class UsageMetrics(StandardBaseDict): + """ + Multimodal usage metrics for generation requests. + + Tracks resource consumption across different modalities including text, images, + video, and audio. Provides granular metrics for tokens, bytes, duration, and + format-specific measurements to enable comprehensive usage monitoring and billing. + """ + + # Text stats + text_tokens: int | None = Field( + default=None, description="Number of text tokens processed/generated." + ) + text_words: int | None = Field( + default=None, description="Number of text words processed/generated." + ) + text_characters: int | None = Field( + default=None, description="Number of text characters processed/generated." + ) + + # Vision image stats + image_tokens: int | None = Field( + default=None, description="Number of image tokens processed/generated." + ) + image_count: int | None = Field( + default=None, description="Number of images processed/generated." + ) + image_pixels: int | None = Field( + default=None, description="Number of image pixels processed/generated." + ) + image_bytes: int | None = Field( + default=None, description="Number of image bytes processed/generated." + ) + + # Vision video stats + video_tokens: int | None = Field( + default=None, description="Number of video tokens processed/generated." + ) + video_frames: int | None = Field( + default=None, description="Number of video frames processed/generated." + ) + video_seconds: float | None = Field( + default=None, description="Duration of video processed/generated in seconds." + ) + video_bytes: int | None = Field( + default=None, description="Number of video bytes processed/generated." + ) + + # Audio stats + audio_tokens: int | None = Field( + default=None, description="Number of audio tokens processed/generated." 
+ ) + audio_samples: int | None = Field( + default=None, description="Number of audio samples processed/generated." + ) + audio_seconds: float | None = Field( + default=None, description="Duration of audio processed/generated in seconds." + ) + audio_bytes: int | None = Field( + default=None, description="Number of audio bytes processed/generated." + ) + + @computed_field # type: ignore[misc] + @property + def total_tokens(self) -> int | None: + """ + Calculate total tokens across all modalities. + + :return: Sum of text, image, video, and audio tokens, or None if all are None + """ + return (self.text_tokens or 0) + (self.image_tokens or 0) + ( + self.video_tokens or 0 + ) + (self.audio_tokens or 0) or None + + +class GenerationRequest(StandardBaseModel): + """ + Complete request specification for backend generation operations. + + Encapsulates all components needed to execute a generation request including + unique identification, request type specification, HTTP arguments, and input/output + usage metrics. Serves as the primary interface between the scheduler and backend + services for coordinating AI generation tasks. + + Example:: + request = GenerationRequest( + request_type="text_completions", + arguments=GenerationRequestArguments( + method="POST", + body={"prompt": "Hello world", "max_tokens": 100} + ) + ) + """ + + request_id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Unique identifier for the request.", + ) + request_type: GenerativeRequestType | str = Field( + description=( + "Type of request. If url is not provided in arguments, " + "this will be used to determine the request url." + ), + ) + arguments: GenerationRequestArguments = Field( + description=( + "Payload for the request, structured as a dictionary of arguments to pass " + "to the respective backend method. For example, can contain " + "'json', 'headers', 'files', etc." + ) + ) + input_metrics: UsageMetrics = Field( + default_factory=UsageMetrics, + description="Input statistics including counts, sizes, and durations.", + ) + output_metrics: UsageMetrics = Field( + default_factory=UsageMetrics, + description="Output statistics including counts, sizes, and durations.", + ) diff --git a/src/guidellm/schemas/response.py b/src/guidellm/schemas/response.py new file mode 100644 index 00000000..d4e53aa3 --- /dev/null +++ b/src/guidellm/schemas/response.py @@ -0,0 +1,119 @@ +""" +Backend response models for request and response handling. + +Provides standardized response models for generation operations that capture +output text, usage metrics, and compilation of request statistics. Ensures +consistent data handling and statistics aggregation across different backend +implementations. +""" + +from __future__ import annotations + +from pydantic import Field + +from guidellm.schemas.info import RequestInfo +from guidellm.schemas.request import GenerationRequest, UsageMetrics +from guidellm.schemas.stats import GenerativeRequestStats +from guidellm.utils import StandardBaseModel + +__all__ = ["GenerationResponse"] + + +class GenerationResponse(StandardBaseModel): + """ + Response model for backend generation operations. + + Captures the output and metrics from a generation request, providing structured + data for text output, token usage statistics, and compilation of detailed + request statistics for analysis and monitoring purposes. 
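# Quick illustration of the token-total idiom used in total_tokens above: missing
# modality counts are treated as zero, and an all-missing result collapses to None.
# Note that a genuine total of 0 is also reported as None by this idiom.
def total_tokens_sketch(text=None, image=None, video=None, audio=None):
    return ((text or 0) + (image or 0) + (video or 0) + (audio or 0)) or None


assert total_tokens_sketch(text=10, audio=5) == 15
assert total_tokens_sketch() is None
assert total_tokens_sketch(text=0) is None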
+ + Example: + :: + response = GenerationResponse( + request_id="req-123", + text="Generated response text", + input_metrics=UsageMetrics(token_count=50), + output_metrics=UsageMetrics(token_count=25) + ) + stats = response.compile_stats(request, info) + """ + + request_id: str = Field( + description="Unique identifier matching the original GenerationRequest." + ) + request_args: str | None = Field( + description="Arguments passed to the backend for request processing." + ) + text: str | None = Field( + default=None, + description="The generated response text.", + ) + input_metrics: UsageMetrics = Field( + default_factory=UsageMetrics, + description="Token usage statistics from the input prompt.", + ) + output_metrics: UsageMetrics = Field( + default_factory=UsageMetrics, + description="Token usage statistics from the generated output.", + ) + + def compile_stats( + self, + request: GenerationRequest, + info: RequestInfo, + prefer_response: bool = True, + ) -> GenerativeRequestStats: + """ + Compile and return comprehensive request statistics. + + Merges metrics from the request and response objects to create a complete + statistical record, with preference given to response-level metrics when + available to ensure accuracy of actual execution data. + + :param request: The original generation request containing input data + :param info: Metadata and timing information for the request execution + :param prefer_response: Whether to prefer response metrics over request + metrics when both are available + :return: A GenerativeRequestStats object containing detailed statistics + :raises ValueError: When request IDs don't match between objects + """ + if request.request_id != self.request_id: + raise ValueError("Mismatched request IDs between request and response.") + + if info.request_id != self.request_id: + raise ValueError("Mismatched request IDs between info and response.") + + if info.status != "completed": + # clear out request output metrics if the request failed since + # those are not valid + request.output_metrics = UsageMetrics() + + base_input = request.input_metrics if prefer_response else self.input_metrics + override_input = ( + self.input_metrics if prefer_response else request.input_metrics + ) + base_output = request.output_metrics if prefer_response else self.output_metrics + override_output = ( + self.output_metrics if prefer_response else request.output_metrics + ) + + input_metrics_dict = base_input.model_dump() + for key, value in override_input.model_dump().items(): + if value is not None: + input_metrics_dict[key] = value + output_metrics_dict = base_output.model_dump() + for key, value in override_output.model_dump().items(): + if value is not None: + output_metrics_dict[key] = value + + return GenerativeRequestStats( + request_id=self.request_id, + request_type=request.request_type, + request_args=str( + request.arguments.model_dump() if request.arguments else {} + ), + output=self.text, + info=info, + input_metrics=UsageMetrics(**input_metrics_dict), + output_metrics=UsageMetrics(**output_metrics_dict), + ) diff --git a/src/guidellm/schemas/stats.py b/src/guidellm/schemas/stats.py new file mode 100644 index 00000000..67f1d26c --- /dev/null +++ b/src/guidellm/schemas/stats.py @@ -0,0 +1,228 @@ +""" +Request statistics and metrics for generative AI benchmark analysis. + +Provides data structures for capturing and analyzing performance metrics from +generative AI workloads. 
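# Sketch of the merge rule in compile_stats above: start from the base side's metrics
# and let the preferred side override any field it actually reports (non-None wins).
base = {"text_tokens": 12, "audio_tokens": None, "image_tokens": 3}
preferred = {"text_tokens": 15, "audio_tokens": None, "image_tokens": None}

merged = dict(base)
for key, value in preferred.items():
    if value is not None:
        merged[key] = value

assert merged == {"text_tokens": 15, "audio_tokens": None, "image_tokens": 3}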
Contains request-level statistics including token counts, +latency measurements, and throughput calculations for text generation benchmarks. +""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import Field, computed_field + +from guidellm.schemas.info import RequestInfo +from guidellm.schemas.request import GenerativeRequestType, UsageMetrics +from guidellm.utils import StandardBaseDict + +__all__ = ["GenerativeRequestStats"] + + +class GenerativeRequestStats(StandardBaseDict): + """ + Request statistics for generative AI text generation workloads. + + Captures comprehensive performance metrics for individual generative requests, + including token counts, timing measurements, and derived performance statistics. + Provides computed properties for latency analysis, throughput calculations, + and token generation metrics essential for benchmark evaluation. + + Example: + :: + stats = GenerativeRequestStats( + request_id="req_123", + request_type="text_completion", + info=request_info, + input_metrics=input_usage, + output_metrics=output_usage + ) + throughput = stats.output_tokens_per_second + """ + + type_: Literal["generative_request_stats"] = "generative_request_stats" + request_id: str = Field(description="Unique identifier for the request") + request_type: GenerativeRequestType | str = Field( + description="Type of generative request: text or chat completion" + ) + request_args: str | None = Field( + default=None, description="Arguments passed to the backend for this request" + ) + output: str | None = Field( + description="Generated text output, if request completed successfully" + ) + info: RequestInfo = Field( + description="Metadata and timing information for the request" + ) + input_metrics: UsageMetrics = Field( + description="Usage statistics for the input prompt" + ) + output_metrics: UsageMetrics = Field( + description="Usage statistics for the generated output" + ) + + # Request stats + @computed_field # type: ignore[misc] + @property + def request_latency(self) -> float | None: + """ + End-to-end request processing latency in seconds. + + :return: Duration from request start to completion, or None if unavailable. + """ + if not self.info.timings.request_end or not self.info.timings.request_start: + return None + + return self.info.timings.request_end - self.info.timings.request_start + + # General token stats + @computed_field # type: ignore[misc] + @property + def prompt_tokens(self) -> int | None: + """ + Number of tokens in the input prompt. + + :return: Input prompt token count, or None if unavailable. + """ + return self.input_metrics.text_tokens + + @computed_field # type: ignore[misc] + @property + def input_tokens(self) -> int | None: + """ + Number of tokens in the input prompt. + + :return: Input prompt token count, or None if unavailable. + """ + return self.input_metrics.total_tokens + + @computed_field # type: ignore[misc] + @property + def output_tokens(self) -> int | None: + """ + Number of tokens in the generated output. + + :return: Generated output token count, or None if unavailable. + """ + return self.output_metrics.total_tokens + + @computed_field # type: ignore[misc] + @property + def total_tokens(self) -> int | None: + """ + Total token count including prompt and output tokens. + + :return: Sum of prompt and output tokens, or None if either is unavailable. 
+ """ + input_tokens = self.input_metrics.total_tokens + output_tokens = self.output_metrics.total_tokens + + if input_tokens is None and output_tokens is None: + return None + + return (input_tokens or 0) + (output_tokens or 0) + + @computed_field # type: ignore[misc] + @property + def time_to_first_token_ms(self) -> float | None: + """ + Time to first token generation in milliseconds. + + :return: Latency from request start to first token, or None if unavailable. + """ + if ( + not self.info.timings.first_iteration + or not self.info.timings.request_start + or self.info.timings.first_iteration == self.info.timings.last_iteration + ): + return None + + return 1000 * ( + self.info.timings.first_iteration - self.info.timings.request_start + ) + + @computed_field # type: ignore[misc] + @property + def time_per_output_token_ms(self) -> float | None: + """ + Average time per output token in milliseconds. + + Includes time for first token and all subsequent tokens. + + :return: Average milliseconds per output token, or None if unavailable. + """ + if ( + not self.info.timings.request_start + or not self.info.timings.last_iteration + or not self.output_metrics.total_tokens + ): + return None + + return ( + 1000 + * (self.info.timings.last_iteration - self.info.timings.request_start) + / self.output_metrics.total_tokens + ) + + @computed_field # type: ignore[misc] + @property + def inter_token_latency_ms(self) -> float | None: + """ + Average inter-token latency in milliseconds. + + Measures time between token generations, excluding first token. + + :return: Average milliseconds between tokens, or None if unavailable. + """ + if ( + not self.info.timings.first_iteration + or not self.info.timings.last_iteration + or not self.output_metrics.total_tokens + or self.output_metrics.total_tokens <= 1 + ): + return None + + return ( + 1000 + * (self.info.timings.last_iteration - self.info.timings.first_iteration) + / (self.output_metrics.total_tokens - 1) + ) + + @computed_field # type: ignore[misc] + @property + def tokens_per_second(self) -> float | None: + """ + Overall token throughput including prompt and output tokens. + + :return: Total tokens per second, or None if unavailable. + """ + if not (latency := self.request_latency) or self.total_tokens is None: + return None + + return self.total_tokens / latency + + @computed_field # type: ignore[misc] + @property + def output_tokens_per_second(self) -> float | None: + """ + Output token generation throughput. + + :return: Output tokens per second, or None if unavailable. + """ + if not (latency := self.request_latency) or self.output_tokens is None: + return None + + return self.output_tokens / latency + + @computed_field # type: ignore[misc] + @property + def output_tokens_per_iteration(self) -> float | None: + """ + Average output tokens generated per iteration. + + :return: Output tokens per iteration, or None if unavailable. 
+ """ + if self.output_tokens is None or not self.info.timings.iterations: + return None + + return self.output_tokens / self.info.timings.iterations diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index 702b2a9d..89312771 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -17,13 +17,9 @@ safe_getattr, safe_multiply, ) -from .hf_datasets import ( - SUPPORTED_TYPES, - save_dataset_to_file, -) -from .hf_transformers import ( - check_load_processor, -) +from .hf_datasets import SUPPORTED_TYPES, save_dataset_to_file +from .hf_transformers import check_load_processor +from .imports import json from .messaging import ( InterProcessMessaging, InterProcessMessagingManagerQueue, @@ -113,6 +109,7 @@ "format_value_display", "get_literal_vals", "is_punctuation", + "json", "load_text", "recursive_key_update", "safe_add", diff --git a/src/guidellm/utils/cli.py b/src/guidellm/utils/cli.py index 4d83526a..f049e94e 100644 --- a/src/guidellm/utils/cli.py +++ b/src/guidellm/utils/cli.py @@ -5,8 +5,10 @@ def parse_json(ctx, param, value): # noqa: ARG001 - if value is None: + if value is None or value == [None]: return None + if isinstance(value, (list, tuple)): + return [parse_json(ctx, param, val) for val in value] try: return json.loads(value) except json.JSONDecodeError as err: diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py index 916d6633..7ececef5 100644 --- a/src/guidellm/utils/encoding.py +++ b/src/guidellm/utils/encoding.py @@ -10,7 +10,6 @@ from __future__ import annotations -import json from collections.abc import Mapping from typing import Any, ClassVar, Generic, Literal, TypeVar, cast @@ -24,11 +23,11 @@ HAS_MSGPACK = False try: - from msgspec.msgpack import ( # type: ignore[import-not-found] # Optional dependency - Decoder as MsgspecDecoder, + from msgspec.msgpack import ( + Decoder as MsgspecDecoder, # type: ignore[import-not-found] # Optional dependency ) - from msgspec.msgpack import ( # type: ignore[import-not-found] # Optional dependency - Encoder as MsgspecEncoder, + from msgspec.msgpack import ( + Encoder as MsgspecEncoder, # type: ignore[import-not-found] # Optional dependency ) HAS_MSGSPEC = True @@ -36,16 +35,11 @@ MsgspecDecoder = MsgspecEncoder = None HAS_MSGSPEC = False -try: - import orjson # type: ignore[import-not-found] # Optional dependency - - HAS_ORJSON = True -except ImportError: - orjson = None - HAS_ORJSON = False from pydantic import BaseModel +from guidellm.utils.imports import json + __all__ = [ "Encoder", "EncodingTypesAlias", @@ -62,7 +56,7 @@ # Type alias for available serialization strategies SerializationTypesAlias = Literal["dict", "sequence"] | None # "Type alias for available binary encoding formats" -EncodingTypesAlias = Literal["msgpack", "msgspec"] +EncodingTypesAlias = Literal["msgpack", "msgspec"] | None class MessageEncoding(Generic[ObjT, MsgT]): @@ -510,7 +504,7 @@ def to_sequence(self, obj: Any) -> str | Any: ): payload_type = "collection_mapping" keys = ",".join(str(key) for key in obj) - payload = keys.encode() + b"|" if HAS_ORJSON else keys + "|" + payload = keys.encode() + b"|" for item in obj.values(): is_pydantic = isinstance(item, BaseModel) payload = self.pack_next_sequence( @@ -601,15 +595,7 @@ def to_sequence_pydantic(self, obj: BaseModel) -> str | bytes: class_module: str = obj.__class__.__module__ json_data = obj.__pydantic_serializer__.to_json(obj) - return ( - (class_name.encode() + b"|" + class_module.encode() + b"|" + json_data) - if HAS_ORJSON - 
else ( - class_name + "|" + class_module + "|" + json_data.decode() - if isinstance(json_data, bytes) - else json_data - ) - ) + return class_name.encode() + b"|" + class_module.encode() + b"|" + json_data def from_sequence_pydantic(self, data: str | bytes) -> BaseModel: """ @@ -643,7 +629,7 @@ def to_sequence_python(self, obj: Any) -> str | bytes: :param obj: Python object to serialize :return: JSON string or bytes representation """ - return orjson.dumps(obj) if HAS_ORJSON else json.dumps(obj) + return json.dumps(obj) def from_sequence_python(self, data: str | bytes) -> Any: """ @@ -651,13 +637,7 @@ def from_sequence_python(self, data: str | bytes) -> Any: :param data: JSON string or bytes to deserialize :return: Reconstructed Python object - :raises ImportError: If orjson is required but not available """ - if isinstance(data, bytes): - if not HAS_ORJSON: - raise ImportError("orjson is not available, cannot deserialize bytes") - return orjson.loads(data) - return json.loads(data) def pack_next_sequence( # noqa: C901, PLR0912 diff --git a/src/guidellm/utils/imports.py b/src/guidellm/utils/imports.py new file mode 100644 index 00000000..9a6b82d1 --- /dev/null +++ b/src/guidellm/utils/imports.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +try: + import orjson as json +except ImportError: + import json + + +__all__ = ["json"] diff --git a/src/guidellm/utils/messaging.py b/src/guidellm/utils/messaging.py index 4dce576d..f64aef8d 100644 --- a/src/guidellm/utils/messaging.py +++ b/src/guidellm/utils/messaging.py @@ -477,7 +477,7 @@ def __init__( self, mp_context: BaseContext | None = None, serialization: SerializationTypesAlias = "dict", - encoding: EncodingTypesAlias = None, + encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None, max_pending_size: int | None = None, max_buffer_send_size: int | None = None, max_done_size: int | None = None, @@ -668,6 +668,8 @@ def _send_messages_task_thread( # noqa: C901, PLR0912 except (culsans.QueueFull, queue.Full): pass + time.sleep(0) # Yield to other threads + def _receive_messages_task_thread( # noqa: C901 self, receive_callback: Callable[[Any], Any] | None, @@ -721,6 +723,8 @@ def _receive_messages_task_thread( # noqa: C901 except (culsans.QueueFull, queue.Full): pass + time.sleep(0) # Yield to other threads + class InterProcessMessagingManagerQueue( InterProcessMessagingQueue[SendMessageT, ReceiveMessageT] @@ -750,7 +754,7 @@ def __init__( manager: SyncManager, mp_context: BaseContext | None = None, serialization: SerializationTypesAlias = "dict", - encoding: EncodingTypesAlias = None, + encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None, max_pending_size: int | None = None, max_buffer_send_size: int | None = None, max_done_size: int | None = None, @@ -854,7 +858,7 @@ def __init__( num_workers: int, mp_context: BaseContext | None = None, serialization: SerializationTypesAlias = "dict", - encoding: EncodingTypesAlias = None, + encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None, max_pending_size: int | None = None, max_buffer_send_size: int | None = None, max_done_size: int | None = None, diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py index e4727cbd..1a1a213f 100644 --- a/src/guidellm/utils/registry.py +++ b/src/guidellm/utils/registry.py @@ -11,7 +11,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import Any, ClassVar, Generic, TypeVar, cast +from typing import ClassVar, Generic, TypeVar, cast from guidellm.utils.auto_importer import 
AutoImporterMixin @@ -65,7 +65,7 @@ class TokenProposal(RegistryMixin): :cvar registry_populated: Track whether auto-discovery has completed """ - registry: ClassVar[dict[str, Any] | None] = None + registry: ClassVar[dict[str, RegistryObjT] | None] = None registry_auto_discovery: ClassVar[bool] = False registry_populated: ClassVar[bool] = False diff --git a/src/guidellm/utils/statistics.py b/src/guidellm/utils/statistics.py index f71a2c24..0529cb0c 100644 --- a/src/guidellm/utils/statistics.py +++ b/src/guidellm/utils/statistics.py @@ -255,6 +255,7 @@ def from_values( def from_request_times( requests: list[tuple[float, float]], distribution_type: Literal["concurrency", "rate"], + weights: list[float] | None = None, include_cdf: bool = False, epsilon: float = 1e-6, ) -> DistributionSummary: @@ -273,67 +274,86 @@ def from_request_times( :return: DistributionSummary with timing-based statistical metrics :raises ValueError: If distribution_type is not "concurrency" or "rate" """ + if not weights: + weights = [1.0] * len(requests) + + if len(requests) != len(weights): + raise ValueError( + "The length of requests and weights must be the same.", + ) + + # First convert to timing events based on type + events: list[tuple[float, float]] = [] + if distribution_type == "concurrency": - # convert to delta changes based on when requests were running - events = [(start, 1) for start, _ in requests] + [ - (end, -1) for _, end in requests - ] + # For concurrency, each request adds to concurrency at start + # and subtracts at end + for (start, end), weight in zip(requests, weights, strict=False): + events.append((start, weight)) + events.append((end, -1 * weight)) elif distribution_type == "rate": - # convert to events for when requests finished - global_start = min(start for start, _ in requests) if requests else 0 - events = [(global_start, 1)] + [(end, 1) for _, end in requests] + # For rate, each request is added at the end time only + global_start = min(start for start, _ in requests) if requests else 0.0 + events.append((global_start, 0.0)) + for (_, end), weight in zip(requests, weights, strict=False): + events.append((end, weight)) else: raise ValueError( f"Invalid distribution_type '{distribution_type}'. " "Must be 'concurrency' or 'rate'." 
) - # combine any events that are very close together - flattened_events: list[tuple[float, float]] = [] - for time, val in sorted(events): - last_time, last_val = ( - flattened_events[-1] if flattened_events else (None, None) - ) + # Combine any events within epsilon of each other for stability + sorted_events = sorted(events, key=lambda event: event[0]) + flattened_events: list[tuple[float, float]] = ( + [sorted_events.pop(0)] if sorted_events else [] + ) + last_time = flattened_events[0][0] if flattened_events else 0.0 - if ( - last_time is not None - and last_val is not None - and abs(last_time - time) <= epsilon - ): + for time, val in sorted_events: + if abs(time - last_time) <= epsilon: + last_val = flattened_events[-1][1] flattened_events[-1] = (last_time, last_val + val) else: + last_time = time flattened_events.append((time, val)) - if distribution_type == "concurrency": - # convert to the events over time measuring concurrency changes - events_over_time: list[tuple[float, float]] = [] - active = 0 - for time, delta in flattened_events: - active += delta # type: ignore [assignment] - events_over_time.append((time, active)) - - flattened_events = events_over_time - - # convert to value distribution function + # Convert events to value distribution function distribution: dict[float, float] = defaultdict(float) - for ind in range(len(flattened_events) - 1): - start_time, value = flattened_events[ind] - end_time, _ = flattened_events[ind + 1] - duration = end_time - start_time - - if distribution_type == "concurrency": - # weight the concurrency value by the duration + if distribution_type == "concurrency": + # For concurrency, convert to active concurrency over time + active = 0.0 + for ind in range(len(flattened_events)): + time, change = flattened_events[ind] + active += change + flattened_events[ind] = (time, active) + + # Then convert to distribution by weighting each concurrency + # by duration to next event (last event is 0 concurrency) + for ind in range(len(flattened_events) - 1): + time, value = flattened_events[ind] + next_time = flattened_events[ind + 1][0] + duration = next_time - time distribution[value] += duration - elif distribution_type == "rate": - # weight the rate value by the duration - rate = value / duration + elif distribution_type == "rate": + # For rate, convert to distribution by converting each value + # to a rate (value/duration) weighted by duration from previous + # (first event is 0 rate) + for ind in range(1, len(flattened_events)): + time, value = flattened_events[ind] + prev_time = flattened_events[ind - 1][0] + duration = time - prev_time + rate = value / duration if duration > 0 else 0.0 distribution[rate] += duration - - distribution_list: list[tuple[float, float]] = sorted(distribution.items()) + else: + raise ValueError( + f"Invalid distribution_type '{distribution_type}'. " + "Must be 'concurrency' or 'rate'." 
+ ) return DistributionSummary.from_distribution_function( - distribution=distribution_list, + distribution=sorted(distribution.items()), include_cdf=include_cdf, ) @@ -563,6 +583,7 @@ def from_request_times( request_types: list[Literal["successful", "incomplete", "error"]], requests: list[tuple[float, float]], distribution_type: Literal["concurrency", "rate"], + weights: list[float] | None = None, include_cdf: bool = False, epsilon: float = 1e-6, ) -> StatusDistributionSummary: @@ -603,65 +624,78 @@ def from_request_times( f"Got {len(request_types)} and {len(requests)} instead.", ) - _, successful_requests = ( - zip(*successful, strict=True) + if weights is None: + weights = [1.0] * len(requests) + + if len(requests) != len(weights): + raise ValueError( + "The length of requests and weights must be the same." + f"Got {len(requests)} and {len(weights)} instead.", + ) + + _, successful_requests, successful_weights = ( + zip(*successful, strict=False) if ( successful := list( filter( lambda val: val[0] == "successful", - zip(request_types, requests, strict=True), + zip(request_types, requests, weights, strict=False), ) ) ) - else ([], []) + else ([], [], []) ) - _, incomplete_requests = ( - zip(*incomplete, strict=True) + _, incomplete_requests, incomplete_weights = ( + zip(*incomplete, strict=False) if ( incomplete := list( filter( lambda val: val[0] == "incomplete", - zip(request_types, requests, strict=True), + zip(request_types, requests, weights, strict=False), ) ) ) - else ([], []) + else ([], [], []) ) - _, errored_requests = ( - zip(*errored, strict=True) + _, errored_requests, errored_weights = ( + zip(*errored, strict=False) if ( errored := list( filter( lambda val: val[0] == "error", - zip(request_types, requests, strict=True), + zip(request_types, requests, weights, strict=False), ) ) ) - else ([], []) + else ([], [], []) ) return StatusDistributionSummary( total=DistributionSummary.from_request_times( requests, distribution_type=distribution_type, + weights=weights, include_cdf=include_cdf, epsilon=epsilon, ), successful=DistributionSummary.from_request_times( successful_requests, # type: ignore[arg-type] distribution_type=distribution_type, + weights=successful_weights, # type: ignore[arg-type] include_cdf=include_cdf, epsilon=epsilon, ), incomplete=DistributionSummary.from_request_times( incomplete_requests, # type: ignore[arg-type] distribution_type=distribution_type, + weights=incomplete_weights, # type: ignore[arg-type] include_cdf=include_cdf, epsilon=epsilon, ), errored=DistributionSummary.from_request_times( errored_requests, # type: ignore[arg-type] distribution_type=distribution_type, + weights=errored_weights, # type: ignore[arg-type] include_cdf=include_cdf, epsilon=epsilon, ), diff --git a/src/guidellm/utils/text.py b/src/guidellm/utils/text.py index a659ac6a..37f2e8d3 100644 --- a/src/guidellm/utils/text.py +++ b/src/guidellm/utils/text.py @@ -13,7 +13,6 @@ import gzip import re import textwrap -from importlib.resources import as_file, files # type: ignore[attr-defined] from pathlib import Path from typing import Any @@ -21,7 +20,6 @@ import httpx from loguru import logger -from guidellm import data as package_data from guidellm.settings import settings from guidellm.utils.console import Colors @@ -239,15 +237,6 @@ def load_text(data: str | Path, encoding: str | None = None) -> str: response.raise_for_status() return response.text - # check package data - if isinstance(data, str) and data.startswith("data:"): - resource_path = files(package_data).joinpath(data[5:]) 
- with ( - as_file(resource_path) as resource_file, - gzip.open(resource_file, "rt", encoding=encoding) as file, - ): - return file.read() - # check gzipped files if isinstance(data, str) and data.endswith(".gz"): with gzip.open(data, "rt", encoding=encoding) as file: diff --git a/tests/integration/scheduler/test_scheduler.py b/tests/integration/scheduler/test_scheduler.py index 65bff95f..106f320f 100644 --- a/tests/integration/scheduler/test_scheduler.py +++ b/tests/integration/scheduler/test_scheduler.py @@ -167,7 +167,8 @@ def _request_indices(): _request_indices(), received_updates.keys(), received_updates.values(), - received_responses, strict=False, + received_responses, + strict=False, ): assert req == f"req_{index}" assert resp in (f"response_for_{req}", f"mock_error_for_{req}") diff --git a/tests/unit/backends/test_backend.py b/tests/unit/backends/test_backend.py index d5a4b955..bf3129df 100644 --- a/tests/unit/backends/test_backend.py +++ b/tests/unit/backends/test_backend.py @@ -11,11 +11,11 @@ import pytest from guidellm.backends.backend import Backend, BackendType -from guidellm.backends.objects import ( +from guidellm.scheduler import BackendInterface, ScheduledRequestInfo +from guidellm.schemas.response import ( GenerationRequest, GenerationRequestTimings, ) -from guidellm.scheduler import BackendInterface, ScheduledRequestInfo from guidellm.utils import RegistryMixin from tests.unit.testing_utils import async_timeout diff --git a/tests/unit/backends/test_objects.py b/tests/unit/backends/test_objects.py index bf903733..600592bc 100644 --- a/tests/unit/backends/test_objects.py +++ b/tests/unit/backends/test_objects.py @@ -9,12 +9,12 @@ import pytest from pydantic import ValidationError -from guidellm.backends.objects import ( +from guidellm.scheduler import MeasuredRequestTimings +from guidellm.schemas.response import ( GenerationRequest, GenerationRequestTimings, GenerationResponse, ) -from guidellm.scheduler import MeasuredRequestTimings from guidellm.utils import StandardBaseModel diff --git a/tests/unit/backends/test_openai_backend.py b/tests/unit/backends/test_openai_backend.py index 724075e8..fefd7a26 100644 --- a/tests/unit/backends/test_openai_backend.py +++ b/tests/unit/backends/test_openai_backend.py @@ -13,13 +13,13 @@ from PIL import Image from guidellm.backends.backend import Backend -from guidellm.backends.objects import ( +from guidellm.backends.openai import OpenAIHTTPBackend +from guidellm.schemas import ( GenerationRequest, - GenerationRequestTimings, GenerationResponse, + RequestInfo, + RequestTimings, ) -from guidellm.backends.openai import OpenAIHTTPBackend, UsageStats -from guidellm.scheduler import ScheduledRequestInfo from tests.unit.testing_utils import async_timeout @@ -613,13 +613,13 @@ async def test_resolve_not_implemented_history(self): await backend.process_startup() request = GenerationRequest(content="test") - request_info = ScheduledRequestInfo( + request_info = RequestInfo( request_id="test-id", status="pending", scheduler_node_id=1, scheduler_process_id=1, scheduler_start_time=123.0, - request_timings=GenerationRequestTimings(), + request_timings=RequestTimings(), ) history = [(request, GenerationResponse(request_id="test", request_args={}))] @@ -641,13 +641,13 @@ async def test_resolve_text_completions(self): params={"temperature": 0.7}, constraints={"output_tokens": 100}, ) - request_info = ScheduledRequestInfo( + request_info = RequestInfo( request_id="test-id", status="pending", scheduler_node_id=1, scheduler_process_id=1, 
scheduler_start_time=123.0, - request_timings=GenerationRequestTimings(), + request_timings=RequestTimings(), ) # Mock text_completions method @@ -682,13 +682,13 @@ async def test_resolve_chat_completions(self): request_type="chat_completions", params={"temperature": 0.5}, ) - request_info = ScheduledRequestInfo( + request_info = RequestInfo( request_id="test-id", status="pending", scheduler_node_id=1, scheduler_process_id=1, scheduler_start_time=123.0, - request_timings=GenerationRequestTimings(), + request_timings=RequestTimings(), ) # Mock chat_completions method @@ -1123,13 +1123,13 @@ async def test_resolve_timing_edge_cases(self): request_type="text_completions", constraints={"output_tokens": 50}, ) - request_info = ScheduledRequestInfo( + request_info = RequestInfo( request_id="test-id", status="pending", scheduler_node_id=1, scheduler_process_id=1, scheduler_start_time=123.0, - request_timings=GenerationRequestTimings(), + request_timings=RequestTimings(), ) # Mock text_completions to test timing edge cases diff --git a/tests/unit/dataset/__init__.py b/tests/unit/data/__init__.py similarity index 100% rename from tests/unit/dataset/__init__.py rename to tests/unit/data/__init__.py diff --git a/tests/unit/data/deserializers/__init__.py b/tests/unit/data/deserializers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/data/deserializers/test_synthetic.py b/tests/unit/data/deserializers/test_synthetic.py new file mode 100644 index 00000000..de95227a --- /dev/null +++ b/tests/unit/data/deserializers/test_synthetic.py @@ -0,0 +1,587 @@ +""" +Unit tests for guidellm.data.deserializers.synthetic module. +""" + +import json +import tempfile +from pathlib import Path +from unittest.mock import Mock + +import pytest +import yaml +from datasets import IterableDataset + +from guidellm.data.deserializers.deserializer import DataNotSupportedError +from guidellm.data.deserializers.synthetic import ( + SyntheticTextDatasetConfig, + SyntheticTextDatasetDeserializer, + SyntheticTextGenerator, + SyntheticTextPrefixBucketConfig, +) + + +class TestPrefixBucketConfig: + """Test cases for PrefixBucketConfig class. + + ### WRITTEN BY AI ### + """ + + @pytest.mark.smoke + def test_creation_with_valid_params(self): + """Test creating PrefixBucketConfig with valid parameters. + + ### WRITTEN BY AI ### + """ + config = SyntheticTextPrefixBucketConfig( + bucket_weight=100, prefix_count=1, prefix_tokens=5 + ) + + assert config.bucket_weight == 100 + assert config.prefix_count == 1 + assert config.prefix_tokens == 5 + + @pytest.mark.sanity + def test_creation_with_negative_values(self): + """Test creating PrefixBucketConfig with negative values raises ValueError. + + ### WRITTEN BY AI ### + """ + with pytest.raises(ValueError): + SyntheticTextPrefixBucketConfig( + bucket_weight=-10, prefix_count=1, prefix_tokens=5 + ) + + with pytest.raises(ValueError): + SyntheticTextPrefixBucketConfig( + bucket_weight=100, prefix_count=-1, prefix_tokens=5 + ) + + with pytest.raises(ValueError): + SyntheticTextPrefixBucketConfig( + bucket_weight=100, prefix_count=1, prefix_tokens=-5 + ) + + @pytest.mark.regression + def test_prefix_bucket_zero_weight_error(self): + """Test that zero total weight raises an error. 
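# The rejection behavior exercised in the tests above is consistent with pydantic
# constrained fields; this is an assumed sketch, and the actual bounds and defaults of
# SyntheticTextPrefixBucketConfig may differ.
from pydantic import BaseModel, Field, ValidationError


class PrefixBucketSketch(BaseModel):
    bucket_weight: int = Field(default=100, gt=0)
    prefix_count: int = Field(default=1, gt=0)
    prefix_tokens: int = Field(default=0, ge=0)


try:
    PrefixBucketSketch(bucket_weight=0)
except ValidationError as err:
    print(f"rejected as expected ({err.error_count()} validation error)")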
+ + ### WRITTEN BY AI ### + """ + # Test validation error for creating PrefixBucketConfig with weight=0 + with pytest.raises(ValueError): + SyntheticTextPrefixBucketConfig( + bucket_weight=0, prefix_count=1, prefix_tokens=2 + ) + + @pytest.mark.sanity + def test_prefix_bucket_config_validation(self): + """Test PrefixBucketConfig validation. + + ### WRITTEN BY AI ### + """ + # Test valid config + valid_config = SyntheticTextPrefixBucketConfig( + bucket_weight=50, prefix_count=2, prefix_tokens=3 + ) + assert valid_config.bucket_weight == 50 + assert valid_config.prefix_count == 2 + assert valid_config.prefix_tokens == 3 + + # Test invalid bucket_weight + with pytest.raises(ValueError): + SyntheticTextPrefixBucketConfig( + bucket_weight=0, prefix_count=1, prefix_tokens=2 + ) + + # Test invalid prefix_count + with pytest.raises(ValueError): + SyntheticTextPrefixBucketConfig( + bucket_weight=100, prefix_count=0, prefix_tokens=2 + ) + + # Test invalid prefix_tokens + with pytest.raises(ValueError): + SyntheticTextPrefixBucketConfig( + bucket_weight=100, prefix_count=1, prefix_tokens=-1 + ) + + +class TestSyntheticDatasetConfig: + """Test cases for SyntheticDatasetConfig class. + + ### WRITTEN BY AI ### + """ + + @pytest.mark.smoke + def test_config_creation_with_all_params(self): + """Test creating config with all parameters specified. + + ### WRITTEN BY AI ### + """ + prefix_bucket = SyntheticTextPrefixBucketConfig( + bucket_weight=100, prefix_count=1, prefix_tokens=5 + ) + + config = SyntheticTextDatasetConfig( + prefix_buckets=[prefix_bucket], + prompt_tokens=100, + prompt_tokens_stdev=10, + prompt_tokens_min=50, + prompt_tokens_max=150, + output_tokens=30, + output_tokens_stdev=5, + output_tokens_min=20, + output_tokens_max=40, + source="custom_text.txt", + ) + + assert config.prefix_buckets[0].prefix_tokens == 5 # type: ignore [index] + assert config.prompt_tokens == 100 + assert config.prompt_tokens_stdev == 10 + assert config.prompt_tokens_min == 50 + assert config.prompt_tokens_max == 150 + assert config.output_tokens == 30 + assert config.output_tokens_stdev == 5 + assert config.output_tokens_min == 20 + assert config.output_tokens_max == 40 + assert config.source == "custom_text.txt" + + @pytest.mark.regression + def test_parse_json_string(self): + """Test parsing JSON string configuration. + + ### WRITTEN BY AI ### + """ + json_str = json.dumps( + { + "prompt_tokens": 75, + "output_tokens": 25, + "source": "test.txt", + "prefix_buckets": [ + {"bucket_weight": 100, "prefix_count": 1, "prefix_tokens": 10} + ], + } + ) + + config = SyntheticTextDatasetConfig.model_validate_json(json_str) + + assert config.prompt_tokens == 75 + assert config.output_tokens == 25 + assert config.source == "test.txt" + assert config.prefix_buckets[0].prefix_tokens == 10 # type: ignore [index] + + @pytest.mark.sanity + def test_validation_positive_values(self): + """Test that negative or zero values are rejected. + + ### WRITTEN BY AI ### + """ + with pytest.raises(ValueError): + SyntheticTextDatasetConfig(prompt_tokens=0, output_tokens=20) + + with pytest.raises(ValueError): + SyntheticTextDatasetConfig(prompt_tokens=20, output_tokens=0) + + # Test negative prefix tokens via PrefixBucketConfig validation + with pytest.raises(ValueError): + SyntheticTextPrefixBucketConfig(prefix_tokens=-1) + + @pytest.mark.regression + def test_validation_optional_positive_values(self): + """Test that optional parameters reject negative values. 
+ + ### WRITTEN BY AI ### + """ + with pytest.raises(ValueError): + SyntheticTextDatasetConfig( + prompt_tokens=20, output_tokens=10, prompt_tokens_stdev=-1 + ) + + with pytest.raises(ValueError): + SyntheticTextDatasetConfig( + prompt_tokens=20, output_tokens=10, prompt_tokens_min=-1 + ) + + with pytest.raises(ValueError): + SyntheticTextDatasetConfig( + prompt_tokens=20, output_tokens=10, output_tokens_max=0 + ) + + +class TestSyntheticTextGenerator: + """Test cases for SyntheticTextGenerator class. + + ### WRITTEN BY AI ### + """ + + @pytest.fixture + def mock_tokenizer(self): + """Fixture to provide a mocked tokenizer. + + ### WRITTEN BY AI ### + """ + tokenizer = Mock() + tokenizer.encode.side_effect = lambda text: list(range(len(text.split()))) + tokenizer.decode.side_effect = ( + lambda tokens, skip_special_tokens=False: " ".join( + f"token_{t}" for t in tokens[:5] + ) + ) + return tokenizer + + @pytest.fixture + def simple_config(self): + """Fixture for simple configuration. + + ### WRITTEN BY AI ### + """ + return SyntheticTextDatasetConfig( + prompt_tokens=15, + output_tokens=10, + source="The quick brown fox jumps over the lazy dog.", + ) + + @pytest.fixture + def config_with_prefix(self): + """Fixture for configuration with prefix tokens. + + ### WRITTEN BY AI ### + """ + prefix_bucket = SyntheticTextPrefixBucketConfig( + bucket_weight=100, prefix_count=1, prefix_tokens=3 + ) + + return SyntheticTextDatasetConfig( + prefix_buckets=[prefix_bucket], + prompt_tokens=15, + output_tokens=10, + source="The quick brown fox jumps over the lazy dog.", + ) + + @pytest.mark.smoke + def test_generator_initialization(self, simple_config, mock_tokenizer): + """Test generator initialization. + + ### WRITTEN BY AI ### + """ + generator = SyntheticTextGenerator( + simple_config, mock_tokenizer, random_seed=42 + ) + + assert generator.config == simple_config + assert generator.processor == mock_tokenizer + assert generator.random_seed == 42 + + @pytest.mark.smoke + def test_basic_iteration(self, simple_config, mock_tokenizer): + """Test basic iteration functionality. + + ### WRITTEN BY AI ### + """ + generator = SyntheticTextGenerator( + simple_config, mock_tokenizer, random_seed=42 + ) + + items = [] + for i, item in enumerate(generator): + items.append(item) + if i >= 4: # Only get 5 items + break + + # Verify we get the expected number of items + assert len(items) == 5 + + # Verify each item has the required keys + for item in items: + assert "prefix" in item + assert "prompt" in item + assert "prompt_tokens_count" in item + assert "output_tokens_count" in item + assert isinstance(item["prefix"], str) + assert isinstance(item["prompt"], str) + assert isinstance(item["prompt_tokens_count"], int) + assert isinstance(item["output_tokens_count"], int) + + @pytest.mark.sanity + def test_create_prompt_method(self, simple_config, mock_tokenizer): + """Test _create_prompt method. + + ### WRITTEN BY AI ### + """ + from faker import Faker + + generator = SyntheticTextGenerator( + simple_config, mock_tokenizer, random_seed=42 + ) + faker = Faker() + faker.seed_instance(42) + + # Test normal case + result = generator._create_prompt(5, faker, "unique_prefix ") + assert isinstance(result, str) + # The result should be the decoded tokens (token_0 token_1 etc.) 
due to our mock + assert "token_" in result + + # Test zero tokens + result = generator._create_prompt(0, faker) + assert result == "" + + @pytest.mark.regression + def test_prefix_tokens_integration(self, config_with_prefix, mock_tokenizer): + """Test integration with prefix tokens. + + ### WRITTEN BY AI ### + """ + generator = SyntheticTextGenerator( + config_with_prefix, mock_tokenizer, random_seed=42 + ) + + items = [] + for i, item in enumerate(generator): + items.append(item) + if i >= 2: # Only get 3 items + break + + # Verify prefix is present in items + for item in items: + assert isinstance(item["prefix"], str) + + @pytest.mark.regression + def test_random_seeding_consistency(self, simple_config, mock_tokenizer): + """Test that same seed produces consistent results. + + ### WRITTEN BY AI ### + """ + # Create two generators with same seed + generator1 = SyntheticTextGenerator( + simple_config, mock_tokenizer, random_seed=42 + ) + generator2 = SyntheticTextGenerator( + simple_config, mock_tokenizer, random_seed=42 + ) + + items1 = [] + items2 = [] + for i, (item1, item2) in enumerate(zip(generator1, generator2, strict=False)): + items1.append(item1) + items2.append(item2) + if i >= 2: # Only get 3 items + break + + # With same seed and deterministic mocks, results should be identical + assert len(items1) == len(items2) + for item1, item2 in zip(items1, items2, strict=False): + assert item1["prompt_tokens_count"] == item2["prompt_tokens_count"] + assert item1["output_tokens_count"] == item2["output_tokens_count"] + + +class TestSyntheticDatasetDeserializer: + """Test cases for SyntheticDatasetDeserializer class. + + ### WRITTEN BY AI ### + """ + + @pytest.fixture + def mock_tokenizer(self): + """Fixture to provide a mocked tokenizer. + + ### WRITTEN BY AI ### + """ + tokenizer = Mock() + tokenizer.encode.side_effect = lambda text: list(range(len(text.split()))) + tokenizer.decode.side_effect = ( + lambda tokens, skip_special_tokens=False: " ".join( + f"token_{t}" for t in tokens[:5] + ) + ) + return tokenizer + + @pytest.mark.sanity + def test_load_config_file_yaml(self): + """Test loading YAML config file. + + ### WRITTEN BY AI ### + """ + config_data = { + "prompt_tokens": 60, + "output_tokens": 15, + "source": "yaml_test.txt", + "prefix_buckets": [ + {"bucket_weight": 100, "prefix_count": 1, "prefix_tokens": 3} + ], + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config_data, f) + yaml_path = f.name + + try: + deserializer = SyntheticTextDatasetDeserializer() + config = deserializer._load_config_file(yaml_path) + + assert config.prompt_tokens == 60 + assert config.output_tokens == 15 + assert config.source == "yaml_test.txt" + assert config.prefix_buckets[0].prefix_tokens == 3 # type: ignore [index] + finally: + Path(yaml_path).unlink() + + @pytest.mark.sanity + def test_load_config_file_config_extension(self): + """Test loading .config file. 
+ + ### WRITTEN BY AI ### + """ + config_data = { + "prompt_tokens": 90, + "output_tokens": 35, + "prefix_buckets": [ + {"bucket_weight": 100, "prefix_count": 1, "prefix_tokens": 2} + ], + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".config", delete=False) as f: + yaml.dump(config_data, f) + config_path = f.name + + try: + deserializer = SyntheticTextDatasetDeserializer() + config = deserializer._load_config_file(config_path) + + assert config.prompt_tokens == 90 + assert config.output_tokens == 35 + assert config.prefix_buckets[0].prefix_tokens == 2 # type: ignore [index] + finally: + Path(config_path).unlink() + + @pytest.mark.smoke + def test_load_config_str_json(self): + """Test loading JSON string config. + + ### WRITTEN BY AI ### + """ + json_str = '{"prompt_tokens": 50, "output_tokens": 25}' + deserializer = SyntheticTextDatasetDeserializer() + config = deserializer._load_config_str(json_str) + + assert config.prompt_tokens == 50 + assert config.output_tokens == 25 + + @pytest.mark.smoke + def test_load_config_str_key_value(self): + """Test loading key-value string config. + + ### WRITTEN BY AI ### + """ + kv_str = "prompt_tokens=50,output_tokens=25" + deserializer = SyntheticTextDatasetDeserializer() + config = deserializer._load_config_str(kv_str) + + assert config.prompt_tokens == 50 + assert config.output_tokens == 25 + + @pytest.mark.sanity + def test_load_config_str_invalid_format(self): + """Test loading invalid format raises DataNotSupportedError. + + ### WRITTEN BY AI ### + """ + deserializer = SyntheticTextDatasetDeserializer() + with pytest.raises(DataNotSupportedError, match="Unsupported string data"): + deserializer._load_config_str("invalid_format_string") + + @pytest.mark.regression + def test_load_config_file_non_existent(self): + """Test loading non-existent file returns None. + + ### WRITTEN BY AI ### + """ + deserializer = SyntheticTextDatasetDeserializer() + config = deserializer._load_config_file("/non/existent/path.config") + assert config is None + + @pytest.mark.regression + def test_load_config_str_non_string(self): + """Test loading non-string returns None. + + ### WRITTEN BY AI ### + """ + deserializer = SyntheticTextDatasetDeserializer() + config = deserializer._load_config_str(123) + assert config is None + + @pytest.mark.smoke + def test_call_with_config_object(self, mock_tokenizer): + """Test calling deserializer with SyntheticTextDatasetConfig. + + ### WRITTEN BY AI ### + """ + config = SyntheticTextDatasetConfig(prompt_tokens=50, output_tokens=25) + deserializer = SyntheticTextDatasetDeserializer() + + result = deserializer( + data=config, + data_kwargs={}, + processor_factory=lambda: mock_tokenizer, + random_seed=42, + ) + + assert isinstance(result, IterableDataset) + + @pytest.mark.regression + def test_call_with_unsupported_data(self, mock_tokenizer): + """Test calling deserializer with unsupported data raises error. + + ### WRITTEN BY AI ### + """ + deserializer = SyntheticTextDatasetDeserializer() + + with pytest.raises(DataNotSupportedError, match="Unsupported data"): + deserializer( + data=123, + data_kwargs={}, + processor_factory=lambda: mock_tokenizer, + random_seed=42, + ) + + @pytest.mark.regression + def test_call_with_json_string(self, mock_tokenizer): + """Test calling deserializer with JSON string. 
+ + ### WRITTEN BY AI ### + """ + json_str = '{"prompt_tokens": 50, "output_tokens": 25}' + deserializer = SyntheticTextDatasetDeserializer() + + result = deserializer( + data=json_str, + data_kwargs={}, + processor_factory=lambda: mock_tokenizer, + random_seed=42, + ) + + assert isinstance(result, IterableDataset) + + @pytest.mark.regression + def test_call_with_config_file(self, mock_tokenizer): + """Test calling deserializer with config file. + + ### WRITTEN BY AI ### + """ + config_data = {"prompt_tokens": 65, "output_tokens": 45} + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config_data, f) + config_path = f.name + + try: + deserializer = SyntheticTextDatasetDeserializer() + result = deserializer( + data=config_path, + data_kwargs={}, + processor_factory=lambda: mock_tokenizer, + random_seed=42, + ) + assert isinstance(result, IterableDataset) + finally: + Path(config_path).unlink() diff --git a/tests/unit/dataset/test_synthetic.py b/tests/unit/dataset/test_synthetic.py deleted file mode 100644 index 544634c8..00000000 --- a/tests/unit/dataset/test_synthetic.py +++ /dev/null @@ -1,873 +0,0 @@ -""" -Unit tests for guidellm.dataset.synthetic module. -""" - -import json -import tempfile -from pathlib import Path -from unittest.mock import Mock, patch - -import pytest -import yaml - -from guidellm.dataset.synthetic import ( - SyntheticDatasetConfig, - SyntheticDatasetCreator, - SyntheticTextItemsGenerator, -) - - -class TestSyntheticDatasetConfig: - """Test cases for SyntheticDatasetConfig class. - - ### WRITTEN BY AI ### - """ - - @pytest.mark.smoke - def test_config_creation_with_all_params(self): - """Test creating config with all parameters specified. - - ### WRITTEN BY AI ### - """ - config = SyntheticDatasetConfig( - prefix_tokens=5, - prompt_tokens=100, - prompt_tokens_stdev=10, - prompt_tokens_min=50, - prompt_tokens_max=150, - output_tokens=30, - output_tokens_stdev=5, - output_tokens_min=20, - output_tokens_max=40, - samples=500, - source="custom_text.txt", - ) - - assert config.prefix_tokens == 5 - assert config.prompt_tokens == 100 - assert config.prompt_tokens_stdev == 10 - assert config.prompt_tokens_min == 50 - assert config.prompt_tokens_max == 150 - assert config.output_tokens == 30 - assert config.output_tokens_stdev == 5 - assert config.output_tokens_min == 20 - assert config.output_tokens_max == 40 - assert config.samples == 500 - assert config.source == "custom_text.txt" - - @pytest.mark.regression - def test_parse_json_string(self): - """Test parsing JSON string configuration. - - ### WRITTEN BY AI ### - """ - json_str = json.dumps( - { - "prompt_tokens": 75, - "output_tokens": 25, - "samples": 200, - "source": "test.txt", - "prefix_tokens": 10, - } - ) - - config = SyntheticDatasetConfig.parse_str(json_str) - - assert config.prompt_tokens == 75 - assert config.output_tokens == 25 - assert config.samples == 200 - assert config.source == "test.txt" - assert config.prefix_tokens == 10 - - @pytest.mark.regression - def test_parse_key_value_pairs(self): - """Test parsing key-value pairs configuration. 
- - ### WRITTEN BY AI ### - """ - kv_str = "prompt_tokens=80,output_tokens=30,samples=300,source=data.txt,prefix_tokens=5" # noqa: E501 - - config = SyntheticDatasetConfig.parse_str(kv_str) - - assert config.prompt_tokens == 80 - assert config.output_tokens == 30 - assert config.samples == 300 - assert config.source == "data.txt" - assert config.prefix_tokens == 5 - - @pytest.mark.sanity - def test_parse_yaml_file(self): - """Test parsing YAML file configuration. - - ### WRITTEN BY AI ### - """ - config_data = { - "prompt_tokens": 60, - "output_tokens": 15, - "samples": 100, - "source": "yaml_test.txt", - "prefix_tokens": 3, - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - yaml.dump(config_data, f) - yaml_path = f.name - - try: - config = SyntheticDatasetConfig.parse_str(yaml_path) - - assert config.prompt_tokens == 60 - assert config.output_tokens == 15 - assert config.samples == 100 - assert config.source == "yaml_test.txt" - assert config.prefix_tokens == 3 - finally: - Path(yaml_path).unlink() - - @pytest.mark.sanity - def test_parse_config_file(self): - """Test parsing .config file. - - ### WRITTEN BY AI ### - """ - config_data = { - "prompt_tokens": 90, - "output_tokens": 35, - "samples": 150, - "prefix_tokens": 2, - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".config", delete=False) as f: - yaml.dump(config_data, f) - config_path = f.name - - try: - config = SyntheticDatasetConfig.parse_str(config_path) - - assert config.prompt_tokens == 90 - assert config.output_tokens == 35 - assert config.samples == 150 - assert config.prefix_tokens == 2 - finally: - Path(config_path).unlink() - - @pytest.mark.regression - def test_parse_path_object(self): - """Test parsing with Path object. - - ### WRITTEN BY AI ### - """ - config_data = {"prompt_tokens": 45, "output_tokens": 25} - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - yaml.dump(config_data, f) - yaml_path = Path(f.name) - - try: - config = SyntheticDatasetConfig.parse_str(yaml_path) - assert config.prompt_tokens == 45 - assert config.output_tokens == 25 - finally: - yaml_path.unlink() - - @pytest.mark.sanity - def test_parse_invalid_format(self): - """Test parsing invalid format raises ValueError. - - ### WRITTEN BY AI ### - """ - with pytest.raises(ValueError, match="Unsupported data format"): - SyntheticDatasetConfig.parse_str("invalid_format_string") - - @pytest.mark.sanity - def test_validation_positive_values(self): - """Test that negative or zero values are rejected. - - ### WRITTEN BY AI ### - """ - with pytest.raises(ValueError): - SyntheticDatasetConfig(prompt_tokens=0, output_tokens=20) - - with pytest.raises(ValueError): - SyntheticDatasetConfig(prompt_tokens=20, output_tokens=0) - - with pytest.raises(ValueError): - SyntheticDatasetConfig(prompt_tokens=20, output_tokens=10, samples=0) - - with pytest.raises(ValueError): - SyntheticDatasetConfig(prompt_tokens=20, output_tokens=10, prefix_tokens=-1) - - @pytest.mark.regression - def test_validation_optional_positive_values(self): - """Test that optional parameters reject negative values. 
- - ### WRITTEN BY AI ### - """ - with pytest.raises(ValueError): - SyntheticDatasetConfig( - prompt_tokens=20, output_tokens=10, prompt_tokens_stdev=-1 - ) - - with pytest.raises(ValueError): - SyntheticDatasetConfig( - prompt_tokens=20, output_tokens=10, prompt_tokens_min=-1 - ) - - with pytest.raises(ValueError): - SyntheticDatasetConfig( - prompt_tokens=20, output_tokens=10, output_tokens_max=0 - ) - - @pytest.mark.regression - def test_parse_json_method_directly(self): - """Test parse_json static method directly. - - ### WRITTEN BY AI ### - """ - json_data = {"prompt_tokens": 100, "output_tokens": 50} - json_str = json.dumps(json_data) - - config = SyntheticDatasetConfig.parse_json(json_str) - - assert config.prompt_tokens == 100 - assert config.output_tokens == 50 - - @pytest.mark.regression - def test_parse_key_value_pairs_method_directly(self): - """Test parse_key_value_pairs static method directly. - - ### WRITTEN BY AI ### - """ - kv_str = "prompt_tokens=75,output_tokens=35" - - config = SyntheticDatasetConfig.parse_key_value_pairs(kv_str) - - assert config.prompt_tokens == 75 - assert config.output_tokens == 35 - - @pytest.mark.regression - def test_parse_config_file_method_directly(self): - """Test parse_config_file static method directly. - - ### WRITTEN BY AI ### - """ - config_data = {"prompt_tokens": 65, "output_tokens": 45} - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - yaml.dump(config_data, f) - config_path = f.name - - try: - config = SyntheticDatasetConfig.parse_config_file(config_path) - assert config.prompt_tokens == 65 - assert config.output_tokens == 45 - finally: - Path(config_path).unlink() - - -class TestSyntheticTextItemsGenerator: - """Test cases for SyntheticTextItemsGenerator class. - - ### WRITTEN BY AI ### - """ - - @pytest.fixture - def mock_tokenizer(self): - """Fixture to provide a mocked tokenizer. - - ### WRITTEN BY AI ### - """ - tokenizer = Mock() - tokenizer.get_vocab.return_value = {f"token_{i}": i for i in range(1000)} - tokenizer.encode.side_effect = lambda text: [1, 2, 3] * (len(text) // 10 + 1) - tokenizer.decode.side_effect = ( - lambda tokens, skip_special_tokens=False: " ".join( - f"token_{t}" for t in tokens[:5] - ) - ) - return tokenizer - - @pytest.fixture - def simple_config(self): - """Fixture for simple configuration. - - ### WRITTEN BY AI ### - """ - return SyntheticDatasetConfig( - prompt_tokens=15, - output_tokens=10, - samples=5, - source="The quick brown fox jumps over the lazy dog.", - ) - - @pytest.fixture - def config_with_prefix(self): - """Fixture for configuration with prefix tokens. - - ### WRITTEN BY AI ### - """ - return SyntheticDatasetConfig( - prefix_tokens=3, - prompt_tokens=15, - output_tokens=10, - samples=5, - source="The quick brown fox jumps over the lazy dog.", - ) - - @pytest.fixture - def complex_config(self): - """Fixture for complex configuration with variance. - - ### WRITTEN BY AI ### - """ - return SyntheticDatasetConfig( - prompt_tokens=20, - prompt_tokens_stdev=5, - prompt_tokens_min=10, - prompt_tokens_max=30, - output_tokens=15, - output_tokens_stdev=3, - output_tokens_min=10, - output_tokens_max=20, - samples=10, - source="The quick brown fox jumps over the lazy dog.", - ) - - @pytest.mark.smoke - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - def test_generator_initialization( - self, mock_text_creator, simple_config, mock_tokenizer - ): - """Test generator initialization. 
- - ### WRITTEN BY AI ### - """ - generator = SyntheticTextItemsGenerator( - simple_config, mock_tokenizer, random_seed=42 - ) - - assert generator.config == simple_config - assert generator.processor == mock_tokenizer - assert generator.random_seed == 42 - mock_text_creator.assert_called_once_with(data=simple_config.source) - - @pytest.mark.smoke - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - @patch("guidellm.dataset.synthetic.IntegerRangeSampler") - def test_basic_iteration( - self, mock_sampler, mock_text_creator, simple_config, mock_tokenizer - ): - """Test basic iteration functionality. - - ### WRITTEN BY AI ### - """ - # Setup mocks - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word1", "word2", "word3"] * 100 - mock_text_creator_instance.create_text.return_value = "sample text" - mock_text_creator.return_value = mock_text_creator_instance - - # Mock IntegerRangeSampler to return iterators - def mock_sampler_side_effect(*args, **kwargs): - mock_instance = Mock() - mock_instance.__iter__ = Mock(return_value=iter([15, 15, 15, 15, 15])) - return mock_instance - - mock_sampler.side_effect = mock_sampler_side_effect - - generator = SyntheticTextItemsGenerator( - simple_config, mock_tokenizer, random_seed=42 - ) - - items = list(generator) - - # Verify we get the expected number of items - assert len(items) == simple_config.samples - - # Verify each item has the required keys - for item in items: - assert "prompt" in item - assert "prompt_tokens_count" in item - assert "output_tokens_count" in item - assert isinstance(item["prompt"], str) - assert isinstance(item["prompt_tokens_count"], int) - assert isinstance(item["output_tokens_count"], int) - - @pytest.mark.sanity - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - def test_create_prompt_method( - self, mock_text_creator, simple_config, mock_tokenizer - ): - """Test _create_prompt method. - - ### WRITTEN BY AI ### - """ - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 100 - mock_text_creator_instance.create_text.return_value = "test text" - mock_text_creator.return_value = mock_text_creator_instance - - mock_tokenizer.encode.return_value = [1, 2, 3] - - generator = SyntheticTextItemsGenerator( - simple_config, mock_tokenizer, random_seed=42 - ) - - # Test normal case - result = generator._create_prompt(5, 0, 42) - assert result == [42, 1, 2, 3] - - # Test zero tokens - result = generator._create_prompt(0, 0, 42) - assert result == [] - - # Test without unique prefix - result = generator._create_prompt(3, 0) - assert result == [1, 2, 3] - - @pytest.mark.regression - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - def test_create_prompt_binary_search( - self, mock_text_creator, simple_config, mock_tokenizer - ): - """Test binary search logic in _create_prompt. 
- - ### WRITTEN BY AI ### - """ - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 1000 - mock_text_creator_instance.create_text.side_effect = lambda start, length: ( - "text " * max(1, length // 4) - ).strip() - mock_text_creator.return_value = mock_text_creator_instance - - # Mock tokenizer to return different lengths based on input - def mock_encode(text): - return [1] * len(text.split()) - - mock_tokenizer.encode.side_effect = mock_encode - - generator = SyntheticTextItemsGenerator( - simple_config, mock_tokenizer, random_seed=42 - ) - - # Test that binary search finds appropriate length - result = generator._create_prompt(5, 0, 42) - assert len(result) >= 4 # Should include prefix + some tokens - - @pytest.mark.sanity - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - @patch("guidellm.dataset.synthetic.IntegerRangeSampler") - def test_prefix_tokens_integration( - self, mock_sampler, mock_text_creator, config_with_prefix, mock_tokenizer - ): - """Test integration with prefix tokens. - - ### WRITTEN BY AI ### - """ - # Setup mocks - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 100 - mock_text_creator_instance.create_text.return_value = "sample text" - mock_text_creator.return_value = mock_text_creator_instance - - mock_sampler_instance = Mock() - mock_sampler_instance.__iter__ = Mock(return_value=iter([15, 15, 15, 15, 15])) - mock_sampler.return_value = mock_sampler_instance - - generator = SyntheticTextItemsGenerator( - config_with_prefix, mock_tokenizer, random_seed=42 - ) - - items = list(generator) - - # Verify prompt_tokens_count includes prefix - for item in items: - assert item["prompt_tokens_count"] == config_with_prefix.prefix_tokens + 15 - - @pytest.mark.regression - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - @patch("guidellm.dataset.synthetic.IntegerRangeSampler") - def test_random_seeding_consistency( - self, mock_sampler, mock_text_creator, simple_config, mock_tokenizer - ): - """Test that same seed produces consistent results. 
- - ### WRITTEN BY AI ### - """ - # Setup mocks - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 100 - mock_text_creator_instance.create_text.return_value = "sample text" - mock_text_creator.return_value = mock_text_creator_instance - - # Create consistent mock sampler behavior - call_count = 0 - - def mock_sampler_side_effect(*args, **kwargs): - nonlocal call_count - mock_instance = Mock() - # Return same sequence for both prompt and output tokens - if call_count % 2 == 0: # prompt tokens - mock_instance.__iter__ = Mock(return_value=iter([15, 16, 17, 18, 19])) - else: # output tokens - mock_instance.__iter__ = Mock(return_value=iter([10, 11, 12, 13, 14])) - call_count += 1 - return mock_instance - - mock_sampler.side_effect = mock_sampler_side_effect - - # Create two generators with same seed - generator1 = SyntheticTextItemsGenerator( - simple_config, mock_tokenizer, random_seed=42 - ) - generator2 = SyntheticTextItemsGenerator( - simple_config, mock_tokenizer, random_seed=42 - ) - - items1 = list(generator1) - items2 = list(generator2) - - # Results should be identical with same seed - assert len(items1) == len(items2) - for item1, item2 in zip(items1, items2, strict=False): - assert item1["prompt"] == item2["prompt"] - assert item1["prompt_tokens_count"] == item2["prompt_tokens_count"] - assert item1["output_tokens_count"] == item2["output_tokens_count"] - - @pytest.mark.regression - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - @patch("guidellm.dataset.synthetic.IntegerRangeSampler") - def test_variance_configuration( - self, mock_sampler, mock_text_creator, complex_config, mock_tokenizer - ): - """Test that variance configuration is properly used. - - ### WRITTEN BY AI ### - """ - # Setup mocks - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 100 - mock_text_creator_instance.create_text.return_value = "sample text" - mock_text_creator.return_value = mock_text_creator_instance - - # Fix tokenizer mock to handle the create_text return properly - mock_tokenizer.encode.side_effect = ( - lambda text: [1, 2, 3] if isinstance(text, str) else [1, 2, 3] - ) - - # Setup mock sampler to track calls - def mock_sampler_side_effect(*args, **kwargs): - mock_instance = Mock() - mock_instance.__iter__ = Mock(return_value=iter([20, 18, 22, 19, 21] * 2)) - return mock_instance - - mock_sampler.side_effect = mock_sampler_side_effect - - generator = SyntheticTextItemsGenerator( - complex_config, mock_tokenizer, random_seed=42 - ) - - # Initialize the generator to trigger sampler creation - generator_iter = iter(generator) - next(generator_iter) - - # Verify that IntegerRangeSampler is called with correct parameters - assert mock_sampler.call_count == 2 - - # Check prompt tokens sampler call - prompt_call = mock_sampler.call_args_list[0] - assert prompt_call[1]["average"] == complex_config.prompt_tokens - assert prompt_call[1]["variance"] == complex_config.prompt_tokens_stdev - assert prompt_call[1]["min_value"] == complex_config.prompt_tokens_min - assert prompt_call[1]["max_value"] == complex_config.prompt_tokens_max - assert prompt_call[1]["random_seed"] == 42 - - # Check output tokens sampler call - output_call = mock_sampler.call_args_list[1] - assert output_call[1]["average"] == complex_config.output_tokens - assert output_call[1]["variance"] == complex_config.output_tokens_stdev - assert output_call[1]["min_value"] == complex_config.output_tokens_min - assert output_call[1]["max_value"] == 
complex_config.output_tokens_max - assert output_call[1]["random_seed"] == 43 # 42 + 1 - - @pytest.mark.regression - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - def test_unique_prefix_generation( - self, mock_text_creator, simple_config, mock_tokenizer - ): - """Test that unique prefixes are generated for each request. - - ### WRITTEN BY AI ### - """ - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 100 - mock_text_creator_instance.create_text.return_value = "sample text" - mock_text_creator.return_value = mock_text_creator_instance - - # Mock the cycle to return predictable values - with patch("guidellm.dataset.synthetic.cycle") as mock_cycle: - mock_cycle.return_value = iter([100, 101, 102, 103, 104]) - - generator = SyntheticTextItemsGenerator( - simple_config, mock_tokenizer, random_seed=42 - ) - - # Access the iterator to trigger the cycle creation - generator_iter = iter(generator) - next(generator_iter) - - # Verify cycle was called with vocab values - mock_cycle.assert_called_once() - - -class TestSyntheticDatasetCreator: - """Test cases for SyntheticDatasetCreator class. - - ### WRITTEN BY AI ### - """ - - @pytest.mark.sanity - def test_is_supported_path_config_file(self): - """Test is_supported with config file paths. - - ### WRITTEN BY AI ### - """ - with tempfile.NamedTemporaryFile(suffix=".config", delete=False) as f: - config_path = Path(f.name) - - try: - assert SyntheticDatasetCreator.is_supported(config_path, None) - finally: - config_path.unlink() - - @pytest.mark.sanity - def test_is_supported_path_yaml_file(self): - """Test is_supported with YAML file paths. - - ### WRITTEN BY AI ### - """ - with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f: - yaml_path = Path(f.name) - - try: - assert SyntheticDatasetCreator.is_supported(yaml_path, None) - finally: - yaml_path.unlink() - - @pytest.mark.smoke - def test_is_supported_json_string(self): - """Test is_supported with JSON string. - - ### WRITTEN BY AI ### - """ - json_str = '{"prompt_tokens": 50, "output_tokens": 25}' - assert SyntheticDatasetCreator.is_supported(json_str, None) - - @pytest.mark.smoke - def test_is_supported_key_value_string(self): - """Test is_supported with key-value string. - - ### WRITTEN BY AI ### - """ - kv_str = "prompt_tokens=50,output_tokens=25" - assert SyntheticDatasetCreator.is_supported(kv_str, None) - - @pytest.mark.sanity - def test_is_supported_config_filename_string(self): - """Test is_supported with config filename string. - - ### WRITTEN BY AI ### - """ - assert SyntheticDatasetCreator.is_supported("config.yaml", None) - assert SyntheticDatasetCreator.is_supported("settings.config", None) - - @pytest.mark.sanity - def test_is_not_supported_regular_string(self): - """Test is_supported returns False for regular strings. - - ### WRITTEN BY AI ### - """ - assert not SyntheticDatasetCreator.is_supported("regular string", None) - assert not SyntheticDatasetCreator.is_supported("single=pair", None) - - @pytest.mark.regression - def test_is_not_supported_non_existent_path(self): - """Test is_supported returns False for non-existent paths. - - ### WRITTEN BY AI ### - """ - non_existent_path = Path("/non/existent/path.config") - assert not SyntheticDatasetCreator.is_supported(non_existent_path, None) - - @pytest.mark.regression - def test_is_not_supported_other_types(self): - """Test is_supported returns False for other data types. 
- - ### WRITTEN BY AI ### - """ - assert not SyntheticDatasetCreator.is_supported(123, None) - assert not SyntheticDatasetCreator.is_supported(["list"], None) - assert not SyntheticDatasetCreator.is_supported({"dict": "value"}, None) - - @pytest.mark.smoke - @patch("guidellm.dataset.synthetic.check_load_processor") - @patch("guidellm.dataset.synthetic.SyntheticTextItemsGenerator") - @patch("guidellm.dataset.synthetic.Dataset") - def test_handle_create_basic( - self, mock_dataset, mock_generator, mock_check_processor - ): - """Test handle_create basic functionality. - - ### WRITTEN BY AI ### - """ - # Setup mocks - mock_processor = Mock() - mock_check_processor.return_value = mock_processor - - mock_generator_instance = Mock() - mock_generator_instance.__iter__ = Mock( - return_value=iter( - [ - { - "prompt": "test", - "prompt_tokens_count": 10, - "output_tokens_count": 5, - } - ] - ) - ) - mock_generator.return_value = mock_generator_instance - - mock_dataset_instance = Mock() - mock_dataset.from_list.return_value = mock_dataset_instance - - # Test - data = '{"prompt_tokens": 50, "output_tokens": 25}' - result = SyntheticDatasetCreator.handle_create( - data=data, - data_args=None, - processor="gpt2", - processor_args=None, - random_seed=42, - ) - - # Verify - mock_check_processor.assert_called_once_with( - "gpt2", - None, - error_msg="Processor/tokenizer required for synthetic dataset generation.", - ) - mock_generator.assert_called_once() - mock_dataset.from_list.assert_called_once() - assert result == mock_dataset_instance - - @pytest.mark.sanity - @patch("guidellm.dataset.synthetic.check_load_processor") - def test_handle_create_processor_required(self, mock_check_processor): - """Test handle_create requires processor. - - ### WRITTEN BY AI ### - """ - mock_check_processor.side_effect = ValueError("Processor required") - - data = '{"prompt_tokens": 50, "output_tokens": 25}' - - with pytest.raises(ValueError, match="Processor required"): - SyntheticDatasetCreator.handle_create( - data=data, - data_args=None, - processor=None, - processor_args=None, - random_seed=42, - ) - - @pytest.mark.regression - @patch("guidellm.dataset.synthetic.check_load_processor") - @patch("guidellm.dataset.synthetic.SyntheticTextItemsGenerator") - @patch("guidellm.dataset.synthetic.Dataset") - def test_handle_create_with_data_args( - self, mock_dataset, mock_generator, mock_check_processor - ): - """Test handle_create with data_args. - - ### WRITTEN BY AI ### - """ - # Setup mocks - mock_processor = Mock() - mock_check_processor.return_value = mock_processor - - mock_generator_instance = Mock() - mock_generator_instance.__iter__ = Mock(return_value=iter([])) - mock_generator.return_value = mock_generator_instance - - mock_dataset_instance = Mock() - mock_dataset.from_list.return_value = mock_dataset_instance - - # Test with data_args - data = '{"prompt_tokens": 50, "output_tokens": 25}' - data_args = {"features": "custom_features"} - - SyntheticDatasetCreator.handle_create( - data=data, - data_args=data_args, - processor="gpt2", - processor_args=None, - random_seed=42, - ) - - # Verify data_args are passed to Dataset.from_list - mock_dataset.from_list.assert_called_once_with([], **data_args) - - @pytest.mark.sanity - def test_extract_args_column_mappings_empty(self): - """Test extract_args_column_mappings with empty data_args. 
- - ### WRITTEN BY AI ### - """ - result = SyntheticDatasetCreator.extract_args_column_mappings(None) - - expected = { - "prompt_column": "prompt", - "prompt_tokens_count_column": "prompt_tokens_count", - "output_tokens_count_column": "output_tokens_count", - } - assert result == expected - - @pytest.mark.regression - def test_extract_args_column_mappings_with_parent_mappings(self): - """Test extract_args_column_mappings rejects column mappings. - - ### WRITTEN BY AI ### - """ - with ( - patch.object( - SyntheticDatasetCreator.__bases__[0], - "extract_args_column_mappings", - return_value={"prompt_column": "custom_prompt"}, - ), - pytest.raises(ValueError, match="Column mappings are not supported"), - ): - SyntheticDatasetCreator.extract_args_column_mappings({"some": "args"}) - - @pytest.mark.regression - def test_extract_args_column_mappings_no_parent_mappings(self): - """Test extract_args_column_mappings with no parent mappings. - - ### WRITTEN BY AI ### - """ - with patch.object( - SyntheticDatasetCreator.__bases__[0], - "extract_args_column_mappings", - return_value={}, - ): - result = SyntheticDatasetCreator.extract_args_column_mappings( - {"some": "args"} - ) - - expected = { - "prompt_column": "prompt", - "prompt_tokens_count_column": "prompt_tokens_count", - "output_tokens_count_column": "output_tokens_count", - } - assert result == expected diff --git a/tests/unit/mock_benchmark.py b/tests/unit/mock_benchmark.py index d7bfe7c9..9201d621 100644 --- a/tests/unit/mock_benchmark.py +++ b/tests/unit/mock_benchmark.py @@ -7,8 +7,8 @@ GenerativeMetrics, GenerativeRequestStats, ) -from guidellm.benchmark.objects import BenchmarkerDict, SchedulerDict from guidellm.benchmark.profile import SynchronousProfile +from guidellm.benchmark.schemas import BenchmarkerDict, SchedulerDict from guidellm.scheduler import ScheduledRequestInfo, SchedulerState, SynchronousStrategy from guidellm.utils import ( DistributionSummary, diff --git a/tests/unit/utils/test_encoding.py b/tests/unit/utils/test_encoding.py index 5664bcb0..cfdf14e2 100644 --- a/tests/unit/utils/test_encoding.py +++ b/tests/unit/utils/test_encoding.py @@ -6,11 +6,11 @@ import pytest from pydantic import BaseModel, Field -from guidellm.backends.objects import ( +from guidellm.scheduler.schemas import RequestSchedulerTimings, ScheduledRequestInfo +from guidellm.schemas.response import ( GenerationRequest, GenerationResponse, ) -from guidellm.scheduler.objects import RequestSchedulerTimings, ScheduledRequestInfo from guidellm.utils.encoding import Encoder, MessageEncoding, Serializer diff --git a/tests/unit/utils/test_functions.py b/tests/unit/utils/test_functions.py index 3e542ca8..96b7b920 100644 --- a/tests/unit/utils/test_functions.py +++ b/tests/unit/utils/test_functions.py @@ -190,7 +190,8 @@ def force_us_eastern_timezone(monkeypatch): ## WRITTEN BY AI ## """ monkeypatch.setenv("TZ", "America/New_York") - time.tzset() # Propagates the change to the underlying C library + time.tzset() # Propagates the change to the underlying C library + class TestSafeFormatTimestamp: """Test suite for safe_format_timestamp function."""