diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 7f0e29d14f16..0cf1e85d4e8e 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -157,11 +157,9 @@ def test_models_distributed( and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4" + and enable_prompt_embeds ): # noqa - if enable_prompt_embeds: - pytest.skip("enable_prompt_embeds does not work with ray compiled dag.") - monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") - monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") + pytest.skip("enable_prompt_embeds does not work with ray compiled dag.") if attention_backend: monkeypatch_context.setenv( diff --git a/tests/distributed/test_multi_node_assignment.py b/tests/distributed/test_multi_node_assignment.py index a660bd1420d0..5d3f524f4d2f 100644 --- a/tests/distributed/test_multi_node_assignment.py +++ b/tests/distributed/test_multi_node_assignment.py @@ -18,8 +18,8 @@ from vllm import initialize_ray_cluster from vllm.config import ParallelConfig -from vllm.executor.ray_utils import _wait_until_pg_removed from vllm.utils.network_utils import get_ip +from vllm.v1.executor.ray_utils import _wait_until_pg_removed VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 24f62cff299a..8ee0d12df640 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -305,10 +305,8 @@ def _compare_tp( common_args.extend(["--max-num-seqs", f"{max_num_seqs}"]) if distributed_backend == "ray": - # For V1, test Ray Compiled Graph for all the tests + # Test Ray Compiled Graph for all the tests pp_env = { - "VLLM_USE_RAY_COMPILED_DAG": "1", - "VLLM_USE_RAY_SPMD_WORKER": "1", "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1", } # Temporary. 
Currently when zeromq + SPMD is used, it does not properly diff --git a/tests/model_executor/model_loader/tensorizer_loader/conftest.py b/tests/model_executor/model_loader/tensorizer_loader/conftest.py index 31f2fa0b8de2..826ecec71e6c 100644 --- a/tests/model_executor/model_loader/tensorizer_loader/conftest.py +++ b/tests/model_executor/model_loader/tensorizer_loader/conftest.py @@ -9,7 +9,7 @@ from vllm.model_executor.model_loader import tensorizer as tensorizer_mod from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port -from vllm.v1.executor.abstract import UniProcExecutor +from vllm.v1.executor import UniProcExecutor from vllm.v1.worker.worker_base import WorkerWrapperBase MODEL_REF = "facebook/opt-125m" diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 341a1f335780..becedb59f644 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -15,7 +15,8 @@ from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore -from vllm.v1.executor.abstract import Executor, UniProcExecutor +from vllm.v1.executor.abstract import Executor +from vllm.v1.executor.uniproc_executor import UniProcExecutor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import ModelRunnerOutput diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index 211abb463e2b..c9256cd91a4e 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -17,8 +17,6 @@ # add to this list if absolutely necessary and after careful security review. 
ALLOWED_FILES = { # pickle - "vllm/v1/serial_utils.py", - "vllm/v1/executor/multiproc_executor.py", "vllm/multimodal/hasher.py", "vllm/transformers_utils/config.py", "vllm/model_executor/models/registry.py", @@ -38,11 +36,13 @@ "benchmarks/cutlass_benchmarks/w8a8_benchmarks.py", "benchmarks/cutlass_benchmarks/sparse_benchmarks.py", # cloudpickle - "vllm/executor/mp_distributed_executor.py", - "vllm/executor/ray_distributed_executor.py", + "vllm/v1/executor/multiproc_executor.py", + "vllm/v1/executor/ray_executor.py", "vllm/entrypoints/llm.py", "vllm/utils/__init__.py", "tests/utils.py", + # pickle and cloudpickle + "vllm/v1/serial_utils.py", } PICKLE_RE = re.compile( diff --git a/vllm/__init__.py b/vllm/__init__.py index b9c868de6886..19b2cdc673c4 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -21,7 +21,7 @@ "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine", "LLMEngine": ".engine.llm_engine:LLMEngine", "LLM": ".entrypoints.llm:LLM", - "initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster", + "initialize_ray_cluster": ".v1.executor.ray_utils:initialize_ray_cluster", "PromptType": ".inputs:PromptType", "TextPrompt": ".inputs:TextPrompt", "TokensPrompt": ".inputs:TokensPrompt", @@ -45,7 +45,6 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM - from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import ( @@ -62,6 +61,7 @@ ) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams + from vllm.v1.executor.ray_utils import initialize_ray_cluster from ._bc_linter import bc_linter_include, bc_linter_skip else: diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index b79bc6983b54..e8847354bb09 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -25,11 +25,11 @@ from ray.runtime_env import RuntimeEnv from ray.util.placement_group import PlacementGroup - from vllm.executor.executor_base import ExecutorBase + from vllm.v1.executor import Executor else: RuntimeEnv = Any PlacementGroup = Any - ExecutorBase = Any + Executor = Any logger = init_logger(__name__) @@ -189,7 +189,7 @@ class ParallelConfig: """ray distributed model workers placement group.""" distributed_executor_backend: ( - str | DistributedExecutorBackend | type[ExecutorBase] | None + str | DistributedExecutorBackend | type[Executor] | None ) = None """Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If the product @@ -511,7 +511,7 @@ def __post_init__(self) -> None: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. - from vllm.executor import ray_utils + from vllm.v1.executor import ray_utils backend: DistributedExecutorBackend = "mp" ray_found = ray_utils.ray_is_available() @@ -553,6 +553,12 @@ def __post_init__(self) -> None: if self.distributed_executor_backend is None and self.world_size == 1: self.distributed_executor_backend = "uni" + if self.max_parallel_loading_workers is not None: + logger.warning( + "max_parallel_loading_workers is currently " + "not supported and will be ignored." 
+ ) + @property def use_ray(self) -> bool: return self.distributed_executor_backend == "ray" or ( @@ -563,7 +569,7 @@ def use_ray(self) -> bool: @model_validator(mode="after") def _verify_args(self) -> Self: # Lazy import to avoid circular import - from vllm.executor.executor_base import ExecutorBase + from vllm.v1.executor import Executor # Enable batch invariance settings if requested if vllm_is_batch_invariant(): @@ -574,17 +580,17 @@ def _verify_args(self) -> Self: and not isinstance(self.distributed_executor_backend, str) and not ( isinstance(self.distributed_executor_backend, type) - and issubclass(self.distributed_executor_backend, ExecutorBase) + and issubclass(self.distributed_executor_backend, Executor) ) ): raise ValueError( "Unrecognized distributed executor backend " f"{self.distributed_executor_backend}. Supported " "values are 'ray', 'mp' 'uni', 'external_launcher', " - " custom ExecutorBase subclass or its import path." + " custom Executor subclass or its import path." ) if self.use_ray: - from vllm.executor import ray_utils + from vllm.v1.executor import ray_utils ray_utils.assert_ray_available() diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index d5eb07730923..402c29eb641f 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -107,12 +107,6 @@ class SchedulerConfig: NOTE: This is not currently configurable. It will be overridden by max_num_batched_tokens in case max multimodal embedding size is larger.""" - send_delta_data: bool = False - """Private API. If used, scheduler sends delta data to - workers instead of an entire data. It should be enabled only - when SPMD worker architecture is enabled. I.e., - VLLM_USE_RAY_SPMD_WORKER=1""" - policy: SchedulerPolicy = "fcfs" """The scheduling policy to use:\n - "fcfs" means first come first served, i.e. requests are handled in order diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index f20cdfab340f..a7724a86cc6a 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -31,7 +31,7 @@ ) if USE_RAY: - from vllm.executor import ray_utils + from vllm.v1.executor import ray_utils class TpuCommunicator(DeviceCommunicatorBase): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b50fbe130b1f..a06ec92b51c8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -88,12 +88,12 @@ from vllm.v1.sample.logits_processor import LogitsProcessor if TYPE_CHECKING: - from vllm.executor.executor_base import ExecutorBase from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.model_loader import LoadFormats from vllm.usage.usage_lib import UsageContext + from vllm.v1.executor import Executor else: - ExecutorBase = Any + Executor = Any QuantizationMethods = Any LoadFormats = Any UsageContext = Any @@ -369,7 +369,7 @@ class EngineArgs: # is intended for expert use only. The API may change without # notice. 
distributed_executor_backend: ( - str | DistributedExecutorBackend | type[ExecutorBase] | None + str | DistributedExecutorBackend | type[Executor] | None ) = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size @@ -1549,7 +1549,6 @@ def create_engine_config( disable_chunked_mm_input=self.disable_chunked_mm_input, is_multimodal_model=model_config.is_multimodal_model, is_encoder_decoder=model_config.is_encoder_decoder, - send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, scheduler_cls=self.scheduler_cls, max_num_partial_prefills=self.max_num_partial_prefills, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index d2d77fce411a..b96e0e7c860f 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -26,7 +26,7 @@ from vllm.utils.network_utils import get_tcp_uri from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines -from vllm.v1.executor.abstract import Executor +from vllm.v1.executor import Executor from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure diff --git a/vllm/envs.py b/vllm/envs.py index c8263de0dd9c..018b8c1c43c7 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -56,8 +56,6 @@ VLLM_XLA_CHECK_RECOMPILATION: bool = False VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True - VLLM_USE_RAY_SPMD_WORKER: bool = False - VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True @@ -623,22 +621,6 @@ def get_vllm_port() -> int | None: "VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), # (CPU backend only) whether to use SGL kernels, optimized for small batch. "VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), - # If the env var is set, then all workers will execute as separate - # processes from the engine, and we use the same mechanism to trigger - # execution on all workers. - # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. - "VLLM_USE_RAY_SPMD_WORKER": lambda: bool( - int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0")) - ), - # If the env var is set, it uses the Ray's Compiled Graph - # (previously known as ADAG) API which optimizes the - # control plane overhead. - # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. - # Note that this variable is set to 1 in V1 by default - # when ray distributed executor is used. - "VLLM_USE_RAY_COMPILED_DAG": lambda: bool( - int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0")) - ), # If the env var is set, Ray Compiled Graph uses the specified # channel type to communicate between workers belonging to # different pipeline-parallel stages. @@ -646,20 +628,17 @@ def get_vllm_port() -> int | None: # - "auto": use the default channel type # - "nccl": use NCCL for communication # - "shm": use shared memory and gRPC for communication - # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set. 
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": env_with_choices( "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto", ["auto", "nccl", "shm"] ), # If the env var is set, it enables GPU communication overlap - # (experimental feature) in Ray's Compiled Graph. This flag is ignored if - # VLLM_USE_RAY_COMPILED_DAG is not set. + # (experimental feature) in Ray's Compiled Graph. "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": lambda: bool( int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0")) ), # If the env var is set, it uses a Ray Communicator wrapping # vLLM's pipeline parallelism communicator to interact with Ray's # Compiled Graph. Otherwise, it uses Ray's NCCL communicator. - # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set. "VLLM_USE_RAY_WRAPPED_PP_COMM": lambda: bool( int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1")) ), diff --git a/vllm/executor/__init__.py b/vllm/executor/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py deleted file mode 100644 index 9de2249f6c05..000000000000 --- a/vllm/executor/executor_base.py +++ /dev/null @@ -1,393 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import time -from abc import ABC, abstractmethod -from collections.abc import Awaitable, Callable -from functools import cached_property -from typing import Any - -from typing_extensions import TypeVar - -import vllm.platforms -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest -from vllm.tasks import SupportedTask -from vllm.utils.async_utils import make_async -from vllm.v1.outputs import SamplerOutput -from vllm.v1.worker.worker_base import WorkerBase - -logger = init_logger(__name__) - -_R = TypeVar("_R", default=Any) - - -class ExecutorBase(ABC): - """Base class for all executors. - - An executor is responsible for executing the model on one device, - or it can be a distributed executor - that can execute the model on multiple devices. - """ - - uses_ray: bool # whether the executor uses Ray for orchestration. - supports_pp: bool = False # whether the executor supports PP - - def __init__( - self, - vllm_config: VllmConfig, - ) -> None: - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.observability_config = vllm_config.observability_config - self._init_executor() - self.is_sleeping = False - self.sleeping_tags: set[str] = set() - self.kv_output_aggregator: KVOutputAggregator | None = None - - @abstractmethod - def _init_executor(self) -> None: - raise NotImplementedError - - @abstractmethod - def collective_rpc( - self, - method: str | Callable[[WorkerBase], _R], - timeout: float | None = None, - args: tuple = (), - kwargs: dict[str, Any] | None = None, - ) -> list[_R]: - """ - Execute an RPC call on all workers. 
- - Args: - method: Name of the worker method to execute, or a callable that - is serialized and sent to all workers to execute. - - If the method is a callable, it should accept an additional - `self` argument, in addition to the arguments passed in `args` - and `kwargs`. The `self` argument will be the worker object. - timeout: Maximum time in seconds to wait for execution. Raises a - [`TimeoutError`][] on timeout. `None` means wait indefinitely. - args: Positional arguments to pass to the worker method. - kwargs: Keyword arguments to pass to the worker method. - - Returns: - A list containing the results from each worker. - - Note: - It is recommended to use this API to only pass control messages, - and set up data-plane communication to pass data. - """ - raise NotImplementedError - - def determine_num_available_blocks(self) -> tuple[int, int]: - """Determine the number of available blocks for the GPU KV cache and - swappable CPU KV cache. - - Normally, this should simply delegate to the underlying Worker. Some - ExecutorBase may require modification of the result, e.g. to ensure the - selected cache sizes are compatible with all workers. - - Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where - `num_gpu_blocks` are blocks that are "active" on the device and can be - appended to. - `num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be - appended to. - """ - results = self.collective_rpc("determine_num_available_blocks") - a = min([r[0] for r in results]) - b = min([r[1] for r in results]) - return a, b - - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: - """Initialize the KV cache by invoking the underlying worker.""" - # NOTE: This is logged in the executor because there can be >1 workers. - logger.info( - "# %s blocks: %d, # CPU blocks: %d", - vllm.platforms.current_platform.device_name, - num_gpu_blocks, - num_cpu_blocks, - ) - max_concurrency = ( - num_gpu_blocks - * self.cache_config.block_size - / self.model_config.max_model_len - ) - logger.info( - "Maximum concurrency for %s tokens per request: %.2fx", - self.model_config.max_model_len, - max_concurrency, - ) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) - - @cached_property # Avoid unnecessary RPC calls - def supported_tasks(self) -> tuple[SupportedTask, ...]: - output = self.collective_rpc("get_supported_tasks") - return output[0] - - def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> list[SamplerOutput]: - output = self.collective_rpc("execute_model", args=(execute_model_req,)) - assert output[0] is not None - return output[0] - - def stop_remote_worker_execution_loop(self) -> None: - """Releases parallel workers from model loop.""" - return - - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return all(self.collective_rpc("add_lora", args=(lora_request,))) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return all(self.collective_rpc("remove_lora", args=(lora_id,))) - - def pin_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." 
- return all(self.collective_rpc("pin_lora", args=(lora_id,))) - - def list_loras(self) -> set[int]: - sets = self.collective_rpc("list_loras") - for s in sets: - assert s == sets[0], "All workers should have the same LORAs." - return sets[0] - - def reset_mm_cache(self) -> None: - """Reset the multi-modal cache in each worker.""" - self.collective_rpc("reset_mm_cache") - - def start_profile(self) -> None: - self.collective_rpc("start_profile") - - def stop_profile(self) -> None: - self.collective_rpc("stop_profile") - - def sleep(self, level: int = 1): - if self.is_sleeping: - logger.warning("Executor is already sleeping.") - return - time_before_sleep = time.perf_counter() - self.collective_rpc("sleep", kwargs=dict(level=level)) - time_after_sleep = time.perf_counter() - self.sleeping_tags = {"weights", "kv_cache"} - self.is_sleeping = True - logger.info( - "It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep - ) - - def wake_up(self, tags: list[str] | None = None): - if not self.is_sleeping: - logger.warning("Executor is not sleeping.") - return - if tags: - for tag in tags: - if tag not in self.sleeping_tags: - logger.warning( - "Tag %s is not in sleeping tags %s", tag, self.sleeping_tags - ) - return - time_before_wakeup = time.perf_counter() - self.collective_rpc("wake_up", kwargs=dict(tags=tags)) - time_after_wakeup = time.perf_counter() - logger.info( - "It took %.6f seconds to wake up tags %s.", - time_after_wakeup - time_before_wakeup, - tags if tags is not None else self.sleeping_tags, - ) - if tags: - for tag in tags: - self.sleeping_tags.remove(tag) - else: - self.sleeping_tags.clear() - if not self.sleeping_tags: - self.is_sleeping = False - - def save_sharded_state( - self, - path: str, - pattern: str | None = None, - max_size: int | None = None, - ) -> None: - self.collective_rpc( - "save_sharded_state", - kwargs=dict(path=path, pattern=pattern, max_size=max_size), - ) - - @abstractmethod - def check_health(self) -> None: - """Checks if the executor is healthy. If not, it should raise an - exception.""" - raise NotImplementedError - - def shutdown(self) -> None: - """Shutdown the executor.""" - self.collective_rpc("shutdown") - - async def execute_model_async( - self, execute_model_req: ExecuteModelRequest - ) -> list[SamplerOutput]: - """Executes one model step on the given sequences.""" - output = await make_async(self.execute_model)(execute_model_req) - return output - - async def stop_remote_worker_execution_loop_async(self) -> None: - """Releases parallel workers from model loop.""" - return - - async def check_health_async(self) -> None: - """Checks if the executor is healthy. If not, it should raise an - exception.""" - self.check_health() - - def init_kv_output_aggregator(self, finished_count: int | None) -> None: - """Init KVOutputAggregator""" - self.kv_output_aggregator = KVOutputAggregator( - finished_count or self.parallel_config.world_size - ) - - -class DistributedExecutorBase(ExecutorBase): - """Abstract superclass of distributed executor implementations.""" - - def __init__(self, *args, **kwargs): - # This is non-None when the execute model loop is running - # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. 
- self.parallel_worker_tasks: Any | Awaitable[Any] | None = None - - super().__init__(*args, **kwargs) - - def execute_model( - self, - execute_model_req: ExecuteModelRequest, - ) -> list[SamplerOutput]: - # TODO: unify into collective_rpc - if self.parallel_worker_tasks is None: - self.parallel_worker_tasks = self._run_workers( - "start_worker_execution_loop", - async_run_tensor_parallel_workers_only=True, - ) - - # Only the driver worker returns the sampling results. - driver_outputs = self._driver_execute_model(execute_model_req) - assert driver_outputs is not None - return driver_outputs - - def stop_remote_worker_execution_loop(self) -> None: - if self.parallel_worker_tasks is None: - return - - self._driver_execute_model(execute_model_req=None) - parallel_worker_tasks = self.parallel_worker_tasks - self.parallel_worker_tasks = None - # Ensure that workers exit model loop cleanly - # (this will raise otherwise) - self._wait_for_tasks_completion(parallel_worker_tasks) - - @abstractmethod - def _driver_execute_model( - self, execute_model_req: ExecuteModelRequest | None - ) -> list[SamplerOutput] | None: - """Run execute_model in the driver worker. - - Passing None will cause the driver to stop the model execution loop - running in each of the remote workers. In this case, this method - returns None. Otherwise, this method returns the model output. - """ - raise NotImplementedError - - def collective_rpc( - self, - method: str | Callable, - timeout: float | None = None, - args: tuple = (), - kwargs: dict[str, Any] | None = None, - ) -> list[Any]: - return self._run_workers(method, *args, **(kwargs or {})) - - @abstractmethod - def _run_workers( - self, - method: str | Callable, - *args, - async_run_tensor_parallel_workers_only: bool = False, - max_concurrent_workers: int | None = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers. - - Args: - async_run_tensor_parallel_workers_only: If True the method will be - run only in the remote TP workers, not the driver worker. - It will also be run asynchronously and return a list of futures - rather than blocking on the results. - - # TODO: simplify and merge with collective_rpc - """ - raise NotImplementedError - - @abstractmethod - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - raise NotImplementedError - - async def execute_model_async( - self, execute_model_req: ExecuteModelRequest - ) -> list[SamplerOutput]: - if self.parallel_worker_tasks is None: - # Start model execution loop running in the parallel workers - self.parallel_worker_tasks = asyncio.create_task( - self._start_worker_execution_loop() - ) - - # Only the driver worker returns the sampling results. - return await self._driver_execute_model_async(execute_model_req) - - async def stop_remote_worker_execution_loop_async(self) -> None: - if self.parallel_worker_tasks is None: - return - - await self._driver_execute_model_async() - parallel_worker_tasks = self.parallel_worker_tasks - self.parallel_worker_tasks = None - # Ensure that workers exit model loop cleanly - # (this will raise otherwise) - await parallel_worker_tasks - - @abstractmethod - async def _driver_execute_model_async( - self, - execute_model_req: ExecuteModelRequest | None = None, - ) -> list[SamplerOutput]: - """Execute the model asynchronously in the driver worker. 
- - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - raise NotImplementedError - - @abstractmethod - async def _start_worker_execution_loop(self): - """Run execution loop on all workers. It guarantees all workers run - the loop or None of them is running the loop. Loop can be stopped by - `stop_remote_worker_execution_loop`. - The API is idempotent (guarantee only 1 loop run at any moment).""" - raise NotImplementedError diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py deleted file mode 100644 index ac16f06b160e..000000000000 --- a/vllm/executor/msgspec_utils.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from array import array -from typing import Any - -from vllm.multimodal.inputs import MultiModalKwargs -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE - - -def encode_hook(obj: Any) -> Any: - """Custom msgspec enc hook that supports array types and MultiModalKwargs. - - See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder - """ - if isinstance(obj, array): - assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, ( - f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " - f"Given array has a type code of {obj.typecode}." - ) - return obj.tobytes() - if isinstance(obj, MultiModalKwargs): - return dict(obj) - - -def decode_hook(type: type, obj: Any) -> Any: - """Custom msgspec dec hook that supports array types and MultiModalKwargs. - - See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder - """ - if type is array: - deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) - deserialized.frombytes(obj) - return deserialized - if type is MultiModalKwargs: - return MultiModalKwargs(obj) diff --git a/vllm/sequence.py b/vllm/sequence.py index afa4e20e4502..6bcc94ad5c62 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,7 +5,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any -import msgspec import torch if TYPE_CHECKING: @@ -92,12 +91,3 @@ def __eq__(self, other: object): def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" - - -class ExecuteModelRequest( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, -): # type: ignore[call-arg] - # Placeholder. Remove. 
- pass diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 623e17b05a6e..7802cece6075 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -943,7 +943,7 @@ def _reduce_config(config: VllmConfig): cloudpickle.register_pickle_by_value(transformers_modules) # ray vendors its own version of cloudpickle - from vllm.executor.ray_utils import ray + from vllm.v1.executor.ray_utils import ray if ray: ray.cloudpickle.register_pickle_by_value(transformers_modules) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index e17cd7beb05c..62faf590b23f 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -39,7 +39,7 @@ from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor -from vllm.v1.executor.abstract import Executor +from vllm.v1.executor import Executor from vllm.v1.metrics.loggers import ( StatLoggerFactory, StatLoggerManager, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index a2a71ddbc30a..00d3821bc42b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -60,7 +60,7 @@ EngineZmqAddresses, get_device_indices, ) -from vllm.v1.executor.abstract import Executor +from vllm.v1.executor import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput @@ -322,7 +322,6 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: with self.log_error_detail(scheduler_output): model_output = self.model_executor.execute_model(scheduler_output) - assert isinstance(model_output, ModelRunnerOutput) engine_core_outputs = self.scheduler.update_from_output( scheduler_output, model_output ) @@ -364,7 +363,7 @@ def step_with_batch_queue( if self.scheduler.has_requests(): scheduler_output = self.scheduler.schedule() future = self.model_executor.execute_model(scheduler_output, non_block=True) - batch_queue.appendleft((future, scheduler_output)) # type: ignore[arg-type] + batch_queue.appendleft((future, scheduler_output)) model_executed = scheduler_output.total_num_scheduled_tokens > 0 if ( @@ -463,14 +462,6 @@ def collective_rpc( ) -> list[_R]: return self.model_executor.collective_rpc(method, timeout, args, kwargs) - def save_tensorized_model( - self, - tensorizer_config, - ) -> None: - self.model_executor.save_tensorized_model( - tensorizer_config=tensorizer_config, - ) - def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]: """Preprocess the request. 
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9e9945411782..7b554ca991b9 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -46,7 +46,7 @@ CoreEngineProcManager, launch_core_engines, ) -from vllm.v1.executor.abstract import Executor +from vllm.v1.executor import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr logger = init_logger(__name__) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 538fb6a04bd7..9d69ed93ed37 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -32,7 +32,7 @@ from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor -from vllm.v1.executor.abstract import Executor +from vllm.v1.executor import Executor from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager from vllm.v1.metrics.reader import Metric, get_metrics_snapshot from vllm.v1.metrics.stats import IterationStats diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index c7bfe2763c07..ea017df3d052 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -23,7 +23,7 @@ from vllm.utils import get_mp_context from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx from vllm.v1.engine.coordinator import DPCoordinator -from vllm.v1.executor.abstract import Executor +from vllm.v1.executor import Executor from vllm.v1.utils import get_engine_client_zmq_addr, shutdown if TYPE_CHECKING: diff --git a/vllm/v1/executor/__init__.py b/vllm/v1/executor/__init__.py index e69de29bb2d1..30d52c73791e 100644 --- a/vllm/v1/executor/__init__.py +++ b/vllm/v1/executor/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .abstract import Executor +from .uniproc_executor import UniProcExecutor + +__all__ = ["Executor", "UniProcExecutor"] diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 2a7e052f1329..609a681dc3fe 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,31 +1,40 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import time +from abc import ABC, abstractmethod from collections.abc import Callable from concurrent.futures import Future -from typing import Any - -import torch -import torch.distributed as dist +from functools import cached_property +from typing import Literal, TypeVar, overload from vllm.config import VllmConfig -from vllm.executor.executor_base import ExecutorBase -from vllm.executor.uniproc_executor import ( # noqa - ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0, -) -from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0 # noqa +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.tasks import SupportedTask from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput +from vllm.v1.worker.worker_base import WorkerBase + +logger = init_logger(__name__) + +_R = TypeVar("_R") FailureCallback = 
Callable[[], None] -class Executor(ExecutorBase): +class Executor(ABC): + """Abstract base class for vLLM executors. + + An executor is responsible for executing the model on one device, + or it can be a distributed executor that can execute the model on multiple devices. + """ - Abstract class for v1 executors, mainly define some methods for v1. - For methods shared by v0 and v1, define them in ExecutorBase""" + + uses_ray: bool = False # whether the executor uses Ray for orchestration. + supports_pp: bool = False # whether the executor supports PP @staticmethod def get_class(vllm_config: VllmConfig) -> type["Executor"]: @@ -34,16 +43,14 @@ def get_class(vllm_config: VllmConfig) -> type["Executor"]: distributed_executor_backend = parallel_config.distributed_executor_backend # distributed_executor_backend must be set in VllmConfig.__post_init__ if isinstance(distributed_executor_backend, type): - if not issubclass(distributed_executor_backend, ExecutorBase): + if not issubclass(distributed_executor_backend, Executor): raise TypeError( "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {distributed_executor_backend}." + f"Executor. Got {distributed_executor_backend}." ) executor_class = distributed_executor_backend elif distributed_executor_backend == "ray": - from vllm.v1.executor.ray_distributed_executor import ( # noqa - RayDistributedExecutor, - ) + from vllm.v1.executor.ray_executor import RayDistributedExecutor executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": @@ -51,6 +58,8 @@ def get_class(vllm_config: VllmConfig) -> type["Executor"]: executor_class = MultiprocExecutor elif distributed_executor_backend == "uni": + from vllm.v1.executor.uniproc_executor import UniProcExecutor + executor_class = UniProcExecutor elif distributed_executor_backend == "external_launcher": # TODO: make v1 scheduling deterministic @@ -58,10 +67,10 @@ def get_class(vllm_config: VllmConfig) -> type["Executor"]: executor_class = ExecutorWithExternalLauncher elif isinstance(distributed_executor_backend, str): executor_class = resolve_obj_by_qualname(distributed_executor_backend) - if not issubclass(executor_class, ExecutorBase): + if not issubclass(executor_class, Executor): raise TypeError( "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {executor_class}." + f"Executor. Got {executor_class}."
) else: raise ValueError( @@ -69,6 +78,29 @@ def get_class(vllm_config: VllmConfig) -> type["Executor"]: ) return executor_class + def __init__( + self, + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.observability_config = vllm_config.observability_config + self._init_executor() + self.is_sleeping = False + self.sleeping_tags: set[str] = set() + self.kv_output_aggregator: KVOutputAggregator | None = None + + @abstractmethod + def _init_executor(self) -> None: + raise NotImplementedError + def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None: """ Initialize the KV caches and begin the model execution loop of the @@ -77,7 +109,7 @@ def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None: self.collective_rpc("initialize_from_config", args=(kv_cache_configs,)) self.collective_rpc("compile_or_warm_up_model") - def register_failure_callback(self, callback: FailureCallback): + def register_failure_callback(self, callback: FailureCallback): # noqa: B027 """ Register a function to be called if the executor enters a permanent failed state. @@ -90,22 +122,78 @@ def determine_available_memory(self) -> list[int]: # in bytes def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: return self.collective_rpc("get_kv_cache_spec") + @overload + def collective_rpc( + self, + method: str | Callable[[WorkerBase], _R], + timeout: float | None = None, + args: tuple = (), + kwargs: dict | None = None, + non_block: Literal[False] = False, + ) -> list[_R]: + """ + Execute an RPC call on all workers. + + Args: + method: Name of the worker method to execute, or a callable that + is serialized and sent to all workers to execute. + + If the method is a callable, it should accept an additional + `self` argument, in addition to the arguments passed in `args` + and `kwargs`. The `self` argument will be the worker object. + timeout: Maximum time in seconds to wait for execution. Raises a + [`TimeoutError`][] on timeout. `None` means wait indefinitely. + args: Positional arguments to pass to the worker method. + kwargs: Keyword arguments to pass to the worker method. + non_block: If `True`, returns a list of Futures instead of waiting + for the results. + + Returns: + A list containing the results from each worker. + + Note: + It is recommended to use this API to only pass control messages, + and set up data-plane communication to pass data. 
+ """ + pass + + @overload def collective_rpc( self, - method: str | Callable, + method: str | Callable[[WorkerBase], _R], timeout: float | None = None, args: tuple = (), kwargs: dict | None = None, - non_block: bool = False, - ) -> list[Any]: + non_block: Literal[True] = True, + ) -> list[Future[_R]]: + pass + + @abstractmethod + def collective_rpc( + self, method, timeout=None, args=(), kwargs=None, non_block: bool = False + ): raise NotImplementedError + @overload def execute_model( self, scheduler_output: SchedulerOutput, - non_block: bool = False, + non_block: Literal[False] = False, + ) -> ModelRunnerOutput: + pass + + @overload + def execute_model( + self, + scheduler_output: SchedulerOutput, + non_block: Literal[True] = True, + ) -> Future[ModelRunnerOutput]: + pass + + def execute_model( + self, scheduler_output: SchedulerOutput, non_block: bool = False ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: - output = self.collective_rpc( + output = self.collective_rpc( # type: ignore[call-overload] "execute_model", args=(scheduler_output,), non_block=non_block ) return output[0] @@ -114,7 +202,7 @@ def execute_dummy_batch(self) -> None: self.collective_rpc("execute_dummy_batch") def take_draft_token_ids(self) -> DraftTokenIds | None: - output = self.collective_rpc("take_draft_token_ids") + output: list[DraftTokenIds] = self.collective_rpc("take_draft_token_ids") return output[0] @property @@ -124,19 +212,120 @@ def max_concurrent_batches(self) -> int: def profile(self, is_start: bool = True): self.collective_rpc("profile", args=(is_start,)) + def save_sharded_state( + self, + path: str, + pattern: str | None = None, + max_size: int | None = None, + ) -> None: + self.collective_rpc( + "save_sharded_state", + kwargs=dict(path=path, pattern=pattern, max_size=max_size), + ) + + @abstractmethod + def check_health(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + raise NotImplementedError -class UniProcExecutor(UniProcExecutorV0, Executor): - pass + def shutdown(self) -> None: + """Shutdown the executor.""" + self.collective_rpc("shutdown") + def init_kv_output_aggregator(self, finished_count: int | None) -> None: + """Init KVOutputAggregator""" + self.kv_output_aggregator = KVOutputAggregator( + finished_count or self.parallel_config.world_size + ) -class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): - def determine_available_memory(self) -> list[int]: # in bytes - # same as determine_num_available_blocks in v0, - # we need to get the min across all ranks. - memory = super().determine_available_memory() - from vllm.distributed.parallel_state import get_world_group - - cpu_group = get_world_group().cpu_group - memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) - dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) - return [memory_tensor.item()] + @cached_property # Avoid unnecessary RPC calls + def supported_tasks(self) -> tuple[SupportedTask, ...]: + output: list[tuple[SupportedTask, ...]] + output = self.collective_rpc("get_supported_tasks") + return output[0] + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("add_lora", args=(lora_request,))) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." 
+ return all(self.collective_rpc("remove_lora", args=(lora_id,))) + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("pin_lora", args=(lora_id,))) + + def list_loras(self) -> set[int]: + sets: list[set[int]] = self.collective_rpc("list_loras") + for s in sets: + assert s == sets[0], "All workers should have the same LORAs." + return sets[0] + + def reset_mm_cache(self) -> None: + """Reset the multi-modal cache in each worker.""" + self.collective_rpc("reset_mm_cache") + + def start_profile(self) -> None: + self.collective_rpc("start_profile") + + def stop_profile(self) -> None: + self.collective_rpc("stop_profile") + + def sleep(self, level: int = 1): + if self.is_sleeping: + logger.warning("Executor is already sleeping.") + return + time_before_sleep = time.perf_counter() + self.collective_rpc("sleep", kwargs=dict(level=level)) + time_after_sleep = time.perf_counter() + self.sleeping_tags = {"weights", "kv_cache"} + self.is_sleeping = True + logger.info( + "It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep + ) + + def wake_up(self, tags: list[str] | None = None): + if not self.is_sleeping: + logger.warning("Executor is not sleeping.") + return + if tags: + for tag in tags: + if tag not in self.sleeping_tags: + logger.warning( + "Tag %s is not in sleeping tags %s", tag, self.sleeping_tags + ) + return + time_before_wakeup = time.perf_counter() + self.collective_rpc("wake_up", kwargs=dict(tags=tags)) + time_after_wakeup = time.perf_counter() + logger.info( + "It took %.6f seconds to wake up tags %s.", + time_after_wakeup - time_before_wakeup, + tags if tags is not None else self.sleeping_tags, + ) + if tags: + for tag in tags: + self.sleeping_tags.remove(tag) + else: + self.sleeping_tags.clear() + if not self.sleeping_tags: + self.is_sleeping = False + + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: + raise NotImplementedError + + +from vllm.v1.executor.uniproc_executor import ( # noqa: E402 + ExecutorWithExternalLauncher as _ExecutorWithExternalLauncher, +) +from vllm.v1.executor.uniproc_executor import ( # noqa: E402 + UniProcExecutor as _UniProcExecutor, +) + +# For backwards compatibility. 
+UniProcExecutor = _UniProcExecutor +ExecutorWithExternalLauncher = _ExecutorWithExternalLauncher diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e9b35c969b2d..8eb45d85fff1 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -179,7 +179,7 @@ def register_failure_callback(self, callback: FailureCallback): else: self.failure_callback = callback - def execute_model( + def execute_model( # type: ignore[override] self, scheduler_output: SchedulerOutput, non_block: bool = False, @@ -204,6 +204,7 @@ def execute_model( ) # aggregate all workers output to a single output + assert self.kv_output_aggregator is not None if non_block: return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank) return self.kv_output_aggregator.aggregate(outputs, self.output_rank) diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index 586df591bfd8..9a56c093ad69 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -1,111 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from concurrent.futures import Future - -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator -from vllm.executor.ray_distributed_executor import ( # noqa - RayDistributedExecutor as RayDistributedExecutorV0, +from vllm.v1.executor.ray_executor import ( + RayDistributedExecutor as _RayDistributedExecutor, ) -from vllm.logger import init_logger -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType -from vllm.v1.executor.abstract import Executor -from vllm.v1.outputs import ModelRunnerOutput - -logger = init_logger(__name__) - - -class FutureWrapper(Future): - """A wrapper around Ray output reference to meet the interface - of .execute_model(): The top level (core busy loop) expects .result() api - to block and return a single output. - - If aggregator is provided, the outputs from all workers are aggregated upon - the result() call. If not only the first worker's output is returned. - """ - - def __init__(self, refs, aggregator: KVOutputAggregator | None = None): - super().__init__() - self.refs = refs - self.aggregator = aggregator - - def result(self, timeout=None): - if timeout is not None: - raise NotImplementedError("timeout is not supported") - - if self.aggregator is None: - return self.refs[0].get() - - outputs = [ref.get() for ref in self.refs] - return self.aggregator.aggregate(outputs, output_rank=0) - - -class RayDistributedExecutor(RayDistributedExecutorV0, Executor): - """Ray distributed executor using Ray Compiled Graphs.""" - - supports_pp: bool = True - - def _init_executor(self) -> None: - super()._init_executor() - - # KV connector setup - self.has_connector = self.vllm_config.kv_transfer_config is not None - - @property - def max_concurrent_batches(self) -> int: - """Ray distributed executor supports pipeline parallelism, - meaning that it allows PP size batches to be executed concurrently. - """ - if self.scheduler_config.async_scheduling: - return 2 - return self.parallel_config.pipeline_parallel_size - - def execute_model( - self, - scheduler_output: SchedulerOutput, - non_block: bool = False, - ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: - """Execute the model on the Ray workers. 
- - Args: - scheduler_output: The scheduler output to execute. - non_block: If True, the method will return a Future. - - Returns: - The model runner output. - """ - # Build the compiled DAG for the first time. - if self.forward_dag is None: # type: ignore - self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) - - refs = self.forward_dag.execute(scheduler_output) # type: ignore - - if not self.has_connector: - # Get output only from a single worker (output_rank) - # When PP is not used, we block here until the result is available. - if not non_block: - return refs[0].get() - - # When PP is used, we return a FutureWrapper immediately so that - # the scheduler can yield to the next batch. - return FutureWrapper(refs) - - # Get output from all workers when connector is present - if not non_block: - # Block and get results from all workers - outputs = [ref.get() for ref in refs] - return self.kv_output_aggregator.aggregate(outputs) - - # Return a future that will aggregate outputs from all workers - return FutureWrapper(refs, self.kv_output_aggregator) - def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest - ) -> None: - self._run_workers("reinitialize_distributed", reconfig_request) - if ( - reconfig_request.new_data_parallel_rank - == ReconfigureRankType.SHUTDOWN_CURRENT_RANK - ): - self.shutdown() +# For backwards compatibility. +RayDistributedExecutor = _RayDistributedExecutor diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_executor.py similarity index 59% rename from vllm/executor/ray_distributed_executor.py rename to vllm/v1/executor/ray_executor.py index 8e8901807f69..a4823acc8764 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -1,31 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio import os from collections import defaultdict from collections.abc import Callable +from concurrent.futures import Future from dataclasses import dataclass from typing import TYPE_CHECKING, Any import cloudpickle -import msgspec import vllm.envs as envs -from vllm.executor.executor_base import DistributedExecutorBase -from vllm.executor.msgspec_utils import encode_hook -from vllm.executor.ray_utils import RayWorkerWrapper, initialize_ray_cluster, ray from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy -from vllm.sequence import ExecuteModelRequest -from vllm.utils.async_utils import make_async from vllm.utils.network_utils import ( get_distributed_init_method, get_ip, get_open_port, ) -from vllm.v1.outputs import SamplerOutput +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.executor.abstract import Executor +from vllm.v1.executor.ray_utils import ( + FutureWrapper, + RayWorkerWrapper, + initialize_ray_cluster, + ray, +) +from vllm.v1.outputs import ModelRunnerOutput if ray is not None: from ray.actor import ActorHandle @@ -53,7 +56,7 @@ class RayWorkerMetaData: ip: str = "" -class RayDistributedExecutor(DistributedExecutorBase): +class RayDistributedExecutor(Executor): """Ray-based distributed executor""" # These env vars are worker-specific, therefore are NOT copied @@ -69,37 +72,14 @@ class RayDistributedExecutor(DistributedExecutorBase): ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"} uses_ray: bool = True + 
supports_pp: bool = True def _init_executor(self) -> None: self.forward_dag: ray.dag.CompiledDAG | None = None - if envs.VLLM_USE_V1: - # V1 uses SPMD worker and compiled DAG - os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1" - os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1" - - # For TPU or XPU, avoid compiling NVIDIA's NCCL - if current_platform.is_tpu() or current_platform.is_xpu(): - os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" - - # If the env var is set, it uses the Ray's compiled DAG API - # which optimizes the control plane overhead. - # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. - # Currently, this requires USE_RAY_SPMD_WORKER=True. - self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG - # If the env var is set, then we do not distinguish between the - # "driver worker" vs other workers. Also, the rank 0 worker will - # be executed in a remote Ray worker. Currently this requires - # USE_RAY_COMPILED_DAG=True. - self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER - if self.use_ray_compiled_dag: - assert self.use_ray_spmd_worker, ( - "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_RAY_SPMD_WORKER=1" - ) - if self.use_ray_spmd_worker: - # TODO: Support SPMD worker for non-DAG Ray executor. - assert self.use_ray_compiled_dag, ( - "VLLM_USE_RAY_SPMD_WORKER=1 requires VLLM_USE_RAY_COMPILED_DAG=1" - ) + + # For TPU or XPU, avoid compiling NVIDIA's NCCL + if current_platform.is_tpu() or current_platform.is_xpu(): + os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" assert self.uses_ray initialize_ray_cluster(self.parallel_config) @@ -113,13 +93,17 @@ def _init_executor(self) -> None: # Create the parallel GPU workers. self._init_workers_ray(placement_group) - self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) - self.output_decoder = msgspec.msgpack.Decoder(list[SamplerOutput] | None) - self.use_v1 = envs.VLLM_USE_V1 + # KV connector setup + self.has_connector = self.vllm_config.kv_transfer_config is not None - self.pp_locks: list[asyncio.Lock] | None = None - if not self.use_ray_compiled_dag: - self.driver_exec_method = make_async(self.driver_worker.execute_method) + @property + def max_concurrent_batches(self) -> int: + """Ray distributed executor supports pipeline parallelism, + meaning that it allows PP size batches to be executed concurrently. + """ + if self.scheduler_config.async_scheduling: + return 2 + return self.parallel_config.pipeline_parallel_size def shutdown(self) -> None: if logger: @@ -176,8 +160,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar ray_remote_kwargs ) - logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) - # Create the workers. bundle_indices: list[int] if envs.VLLM_RAY_BUNDLE_INDICES: @@ -241,30 +223,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar for each, ip in zip(worker_metadata, worker_ips): each.ip = ip - if not self.use_ray_spmd_worker: - for i, each in enumerate(worker_metadata): - # find and remove the dummy worker from the list - worker = each.worker - worker_ip = each.ip - if self.driver_dummy_worker is None and worker_ip == driver_ip: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. 
-                    self.driver_dummy_worker = worker
-                    self.driver_worker = RayWorkerWrapper(
-                        vllm_config=self.vllm_config, rpc_rank=0
-                    )
-                    worker_metadata.pop(i)
-                    break
-
         logger.debug("workers: %s", worker_metadata)
         logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
-        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
-            raise ValueError(
-                "Ray does not allocate any GPUs on the driver node."
-                f"Driver IP: {driver_ip}, worker IPs: {worker_ips}."
-                "Consider adjusting the Ray placement group or running "
-                "the driver on a GPU node."
-            )

         ip_counts: dict[str, int] = {}
         for ip in worker_ips:
@@ -281,7 +241,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
             should be placed first.
             """
             ip = item.ip
-            return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
+            return 0 if ip == driver_ip else 1, ip_counts[ip], ip

         # After sorting, the workers on the same node will be
         # close to each other, and the workers on the driver
@@ -289,14 +249,13 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
         sorted_worker_metadata = sorted(
             worker_metadata, key=sort_by_driver_then_worker_ip
         )
-        start_rank = 0 if self.use_ray_spmd_worker else 1
         for i, item in enumerate(sorted_worker_metadata):
-            item.adjusted_rank = i + start_rank
+            item.adjusted_rank = i
         self.workers = [item.worker for item in sorted_worker_metadata]
         rerank_mapping = {
            item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
         }
-        self._run_workers("adjust_rank", rerank_mapping)
+        self.collective_rpc("adjust_rank", args=(rerank_mapping,))

         # Get the set of GPU IDs used on each node.
         worker_node_and_gpu_ids = []
@@ -365,8 +324,8 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
         self._env_vars_for_all_workers = all_args_to_update_environment_variables

-        self._run_workers(
-            "update_environment_variables", self._get_env_vars_to_be_updated()
+        self.collective_rpc(
+            "update_environment_variables", args=(self._get_env_vars_to_be_updated(),)
         )

         if len(node_gpus) == 1:
@@ -396,138 +355,95 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
             )
             all_kwargs.append(kwargs)
-        self._run_workers("init_worker", all_kwargs)
+        self.collective_rpc("init_worker", args=(all_kwargs,))
+
+        self.collective_rpc("init_device")
+        self.collective_rpc("load_model")
+
+        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
+            self.pp_tp_workers.append([])
+            for tp_rank in range(self.parallel_config.tensor_parallel_size):
+                # PP=2, TP=4
+                # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
+                rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
+                assert len(self.pp_tp_workers[pp_rank]) == tp_rank
+                assert pp_rank < len(self.pp_tp_workers)
+                self.pp_tp_workers[pp_rank].append(self.workers[rank])
+
+    def reinitialize_distributed(
+        self, reconfig_request: ReconfigureDistributedRequest
+    ) -> None:
+        self.collective_rpc("reinitialize_distributed", args=(reconfig_request,))
+        if (
+            reconfig_request.new_data_parallel_rank
+            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
+        ):
+            self.shutdown()
+
+    def execute_model(  # type: ignore[override]
+        self, scheduler_output: SchedulerOutput, non_block: bool = False
+    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
+        """Execute the model on the Ray workers.
-        self._run_workers("init_device")
-        self._run_workers(
-            "load_model",
-            max_concurrent_workers=self.parallel_config.max_parallel_loading_workers,
-        )
+
+        Args:
+            scheduler_output: The scheduler output to execute.
+            non_block: If True, the method will return a Future.
-        if self.use_ray_spmd_worker:
-            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
-                self.pp_tp_workers.append([])
-                for tp_rank in range(self.parallel_config.tensor_parallel_size):
-                    # PP=2, TP=4
-                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
-                    rank = (
-                        pp_rank * self.parallel_config.tensor_parallel_size
-                    ) + tp_rank
-                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
-                    assert pp_rank < len(self.pp_tp_workers)
-                    self.pp_tp_workers[pp_rank].append(self.workers[rank])
-
-        # This is the list of workers that are rank 0 of each TP group EXCEPT
-        # global rank 0. These are the workers that will broadcast to the
-        # rest of the workers.
-        self.tp_driver_workers: list[RayWorkerWrapper] = []
-        # This is the list of workers that are not drivers and not the first
-        # worker in a TP group. These are the workers that will be
-        # broadcasted to.
-        self.non_driver_workers: list[RayWorkerWrapper] = []
-
-        # Enforce rank order for correct rank to return final output.
-        for index, worker in enumerate(self.workers):
-            # The driver worker is rank 0 and not in self.workers.
-            rank = index + 1
-            if rank % self.parallel_config.tensor_parallel_size == 0:
-                self.tp_driver_workers.append(worker)
-            else:
-                self.non_driver_workers.append(worker)
+
+        Returns:
+            The model runner output.
+        """
+        # Build the compiled DAG for the first time.
+        if self.forward_dag is None:  # type: ignore
+            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
-    def _driver_execute_model(
-        self, execute_model_req: ExecuteModelRequest | None
-    ) -> list[SamplerOutput] | None:
-        """Run execute_model in the driver worker.
+        refs = self.forward_dag.execute(scheduler_output)  # type: ignore
-        Passing None will cause the driver to stop the model execution
-        loop running in each of the remote workers.
-        """
-        assert not self.use_ray_spmd_worker, (
-            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
-        )
-        return self.driver_worker.execute_method("execute_model", execute_model_req)
+        if not self.has_connector:
+            # Get output only from a single worker (output_rank)
+            # When PP is not used, we block here until the result is available.
+            if not non_block:
+                return refs[0].get()
-    def execute_model(
-        self, execute_model_req: ExecuteModelRequest
-    ) -> list[SamplerOutput]:
-        if not self.use_ray_spmd_worker:
-            return super().execute_model(execute_model_req)
+            # When PP is used, we return a FutureWrapper immediately so that
+            # the scheduler can yield to the next batch.
+            return FutureWrapper(refs)
-        if self.forward_dag is None:
-            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
+        # Get output from all workers when connector is present
+        assert self.kv_output_aggregator is not None
+        if not non_block:
+            # Block and get results from all workers
+            outputs = [ref.get() for ref in refs]
+            return self.kv_output_aggregator.aggregate(outputs)
-        if self.use_v1:
-            serialized_data = execute_model_req
-        else:
-            serialized_data = self.input_encoder.encode(execute_model_req)
-        outputs = ray.get(self.forward_dag.execute(serialized_data))
-        output = outputs[0] if self.use_v1 else self.output_decoder.decode(outputs[0])
-        return output
+        # Return a future that will aggregate outputs from all workers
+        return FutureWrapper(refs, self.kv_output_aggregator)
-    def _run_workers(
+    def collective_rpc(
         self,
         method: str | Callable,
-        *args,
-        async_run_tensor_parallel_workers_only: bool = False,
-        max_concurrent_workers: int | None = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers. Can be used in the following
-        ways:
-
-        Args:
-        - async_run_tensor_parallel_workers_only: If True the method will be
-          run only in the remote TP workers, not the driver worker.
-          It will also be run asynchronously and return a list of futures
-          rather than blocking on the results.
-        - args/kwargs: All workers share the same args/kwargs
-        """
+        timeout: float | None = None,
+        args: tuple = (),
+        kwargs: dict[str, Any] | None = None,
+        non_block: bool = False,
+    ) -> list[Any]:
+        """Runs the given method on all workers."""
         sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
         del method
-        if self.use_ray_spmd_worker:
-            assert not async_run_tensor_parallel_workers_only, (
-                "async_run_tensor_parallel_workers_only is not supported for spmd mode."
-            )
-
-        if max_concurrent_workers:
-            raise NotImplementedError("max_concurrent_workers is not supported yet.")
-        # Start the ray workers first.
-        ray_workers = self.workers
-        if async_run_tensor_parallel_workers_only:
-            ray_workers = self.non_driver_workers
+        if kwargs is None:
+            kwargs = {}

         ray_worker_outputs = [
             worker.execute_method.remote(  # type: ignore[attr-defined]
                 sent_method, *args, **kwargs
             )
-            for worker in ray_workers
+            for worker in self.workers
         ]

-        if async_run_tensor_parallel_workers_only:
-            # Just return futures
-            return ray_worker_outputs
-
-        driver_worker_output = []
-        # In SPMD mode, the driver worker is the same as any other worker,
-        # so we only explicitly execute on the driver worker if using a
-        # non-SPMD worker class.
-        if not self.use_ray_spmd_worker:
-            # Start the driver worker after all the ray workers.
-            driver_worker_output = [
-                self.driver_worker.execute_method(sent_method, *args, **kwargs)
-            ]
-        # Get the results of the ray workers.
-        if self.workers:
-            ray_worker_outputs = ray.get(ray_worker_outputs)
-
-        return driver_worker_output + ray_worker_outputs
+        if non_block:
+            return [FutureWrapper((output,)) for output in ray_worker_outputs]
-    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
-        """Wait for futures returned from _run_workers() with
-        async_run_remote_workers_only to complete."""
-        ray.get(parallel_worker_tasks)
+        return ray.get(ray_worker_outputs, timeout=timeout)

     def _check_ray_cgraph_installation(self):
         import importlib.metadata
@@ -595,13 +511,6 @@ def _compiled_ray_dag(self, enable_asyncio: bool):
         with InputNode() as input_data:
             # Example DAG: PP=2, TP=4
             #
-            # For V0:
-            # ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput # noqa: E501
-            # ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput # noqa: E501
-            # ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput # noqa: E501
-            # ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput # noqa: E501
-            #
-            # For V1:
             # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput # noqa: E501
             # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput # noqa: E501
             # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput # noqa: E501
@@ -613,20 +522,10 @@ def _compiled_ray_dag(self, enable_asyncio: bool):
             for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                 # Each PP worker takes in the output of the previous PP worker,
                 # and the TP group executes in SPMD fashion.
-                if self.use_v1:
-                    outputs = [
-                        worker.execute_model_ray.bind(  # type: ignore[attr-defined]
-                            outputs[i]
-                        )
-                        for i, worker in enumerate(tp_group)
-                    ]
-                else:
-                    outputs = [
-                        worker.execute_model_spmd.bind(  # type: ignore[attr-defined]
-                            outputs[i]
-                        )
-                        for i, worker in enumerate(tp_group)
-                    ]
+                outputs = [
+                    worker.execute_model_ray.bind(outputs[i])  # type: ignore[attr-defined]
+                    for i, worker in enumerate(tp_group)
+                ]

                 last_pp_rank = len(self.pp_tp_workers) - 1
                 if (
@@ -674,82 +573,7 @@ def _compiled_ray_dag(self, enable_asyncio: bool):
     def __del__(self):
         self.shutdown()

-    async def execute_model_async(
-        self, execute_model_req: ExecuteModelRequest
-    ) -> list[SamplerOutput]:
-        if not self.use_ray_spmd_worker:
-            return await super().execute_model_async(execute_model_req)
-
-        if self.forward_dag is None:
-            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
-
-        serialized_data = self.input_encoder.encode(execute_model_req)
-        dag_future = await self.forward_dag.execute_async(serialized_data)
-        output = await dag_future[0]
-        return self.output_decoder.decode(output)
-
-    async def _driver_execute_model_async(
-        self, execute_model_req: ExecuteModelRequest | None = None
-    ) -> list[SamplerOutput]:
-        assert not self.use_ray_spmd_worker, (
-            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
-        )
-        if not self.tp_driver_workers:
-            return await self.driver_exec_method("execute_model", execute_model_req)
-        if self.pp_locks is None:
-            # This locks each pipeline parallel stage so multiple virtual
-            # engines can't execute on the same stage at the same time
-            # We create the locks here to avoid creating them in the constructor
-            # which uses a different asyncio loop.
-            self.pp_locks = [
-                asyncio.Lock()
-                for _ in range(self.parallel_config.pipeline_parallel_size)
-            ]
-
-        tasks = [
-            asyncio.create_task(
-                _run_task_with_lock(
-                    self.driver_exec_method,
-                    self.pp_locks[0],
-                    "execute_model",
-                    execute_model_req,
-                )
-            )
-        ]
-        for pp_rank, driver_worker in enumerate(self.tp_driver_workers, start=1):
-            tasks.append(
-                asyncio.create_task(
-                    _run_task_with_lock(
-                        driver_worker.execute_method.remote,  # type: ignore[attr-defined]
-                        self.pp_locks[pp_rank],
-                        "execute_model",
-                        execute_model_req,
-                    )
-                )
-            )
-
-        results = await asyncio.gather(*tasks)
-
-        # Only the last PP stage has the final results.
-        return results[-1]
-
-    async def _start_worker_execution_loop(self):
-        assert not self.use_ray_spmd_worker, (
-            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1"
-        )
-        coros = [
-            worker.execute_method.remote("start_worker_execution_loop")  # type: ignore[attr-defined]
-            for worker in self.non_driver_workers
-        ]
-        return await asyncio.gather(*coros)
-
     def check_health(self) -> None:
         # Assume that the Ray workers are healthy.
         # TODO: check the health of the Ray workers
         return
-
-
-async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
-    """Utility function to run async task in a lock"""
-    async with lock:
-        return await task(*args, **kwargs)
diff --git a/vllm/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py
similarity index 88%
rename from vllm/executor/ray_utils.py
rename to vllm/v1/executor/ray_utils.py
index b4a29da46171..518f1582faeb 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/v1/executor/ray_utils.py
@@ -4,17 +4,16 @@
 import os
 import time
 from collections import defaultdict
+from concurrent.futures import Future
 from typing import TYPE_CHECKING, Union

-import msgspec
-
 import vllm.platforms
 from vllm.config import ParallelConfig
 from vllm.distributed import get_pp_group
-from vllm.executor.msgspec_utils import decode_hook, encode_hook
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.sequence import ExecuteModelRequest, IntermediateTensors
+from vllm.sequence import IntermediateTensors
 from vllm.utils.network_utils import get_ip
 from vllm.v1.outputs import AsyncModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerWrapperBase
@@ -51,11 +50,6 @@ def __init__(self, *args, **kwargs) -> None:
             # that thread.
             self.compiled_dag_cuda_device_set = False

-            self.input_decoder = msgspec.msgpack.Decoder(
-                ExecuteModelRequest, dec_hook=decode_hook
-            )
-            self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
-
         def get_node_ip(self) -> str:
             return get_ip()
@@ -70,47 +64,6 @@ def get_node_and_gpu_ids(self) -> tuple[str, list[int]]:
             gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key]
             return node_id, gpu_ids

-        def execute_model_spmd(
-            self,
-            req_or_tuple: bytes | tuple[bytes, IntermediateTensors | None],
-        ) -> bytes:
-            """Execute model in SPMD fashion: used only when SPMD worker and
-            compiled DAG are both enabled.
-
-            Args:
-                req_or_tuple: A request or a tuple containing the
-                    request and intermediate tensors. Intermediate tensors are
-                    None unless if it is provided because it is > 0 pipeline
-                    stage. The request is serialized by msgspec.
- """ - if isinstance(req_or_tuple, bytes): - serialized_req, intermediate_tensors = req_or_tuple, None - else: - serialized_req, intermediate_tensors = req_or_tuple - - execute_model_req = self.input_decoder.decode(serialized_req) - - assert self.worker is not None, "Worker is not initialized" - - # TODO(swang): This is needed right now because Ray Compiled Graph - # executes on a background thread, so we need to reset torch's - # current device. - if not self.compiled_dag_cuda_device_set: - assert self.worker.device is not None - current_platform.set_device(self.worker.device) - self.compiled_dag_cuda_device_set = True - - output = self.worker._execute_model_spmd( # type: ignore[attr-defined] - execute_model_req, intermediate_tensors - ) - # Pipeline model request and output to the next pipeline stage. - if isinstance(output, IntermediateTensors): - output = serialized_req, output - else: - output = self.output_encoder.encode(output) - - return output - def setup_device_if_necessary(self): # TODO(swang): This is needed right now because Ray CG executes # on a background thread, so we need to reset torch's current @@ -174,6 +127,31 @@ def override_env_vars(self, vars: dict[str, str]): RayWorkerWrapper = None # type: ignore +class FutureWrapper(Future): + """A wrapper around Ray output reference to meet the interface + of .execute_model(): The top level (core busy loop) expects .result() api + to block and return a single output. + + If aggregator is provided, the outputs from all workers are aggregated upon + the result() call. If not only the first worker's output is returned. + """ + + def __init__(self, refs, aggregator: KVOutputAggregator | None = None): + super().__init__() + self.refs = refs + self.aggregator = aggregator + + def result(self, timeout=None): + if timeout is not None: + raise NotImplementedError("timeout is not supported") + + if self.aggregator is None: + return self.refs[0].get() + + outputs = [ref.get() for ref in self.refs] + return self.aggregator.aggregate(outputs, output_rank=0) + + def ray_is_available() -> bool: """Returns True if Ray is available.""" return ray is not None diff --git a/vllm/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py similarity index 80% rename from vllm/executor/uniproc_executor.py rename to vllm/v1/executor/uniproc_executor.py index 6a1838d3df74..0d072172fdf3 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -11,20 +11,18 @@ import torch.distributed as dist import vllm.envs as envs -from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.utils import run_method from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import AsyncModelRunnerOutput from vllm.v1.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) -class UniProcExecutor(ExecutorBase): - uses_ray: bool = False - +class UniProcExecutor(Executor): def _init_executor(self) -> None: """Initialize the worker and load the model.""" self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) @@ -44,9 +42,9 @@ def _init_executor(self) -> None: max_workers=1, thread_name_prefix="WorkerAsyncOutput" ) - self.collective_rpc("init_worker", args=([kwargs],)) - self.collective_rpc("init_device") - self.collective_rpc("load_model") + 
+        self.driver_worker.init_worker(all_kwargs=[kwargs])
+        self.driver_worker.init_device()
+        self.driver_worker.load_model()

     def _distributed_args(self) -> tuple[str, int, int]:
         """Return (distributed_init_method, rank, local_rank)."""
@@ -101,16 +99,12 @@ def reinitialize_distributed(
             == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
         ):
             self.shutdown()
-        return

     def shutdown(self) -> None:
         if worker := self.driver_worker:
             worker.shutdown()


-UniProcExecutorAsync = UniProcExecutor
-
-
 class ExecutorWithExternalLauncher(UniProcExecutor):
     """An executor that uses external launchers to launch engines,
     specially designed for torchrun-compatible launchers, for
@@ -128,8 +122,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
     and they don't need to synchronize the states with each other.
     """

-    uses_ray: bool = False
-
     def _init_executor(self) -> None:
         """Initialize the worker and load the model."""
         if envs.VLLM_USE_V1:
@@ -152,22 +144,12 @@ def _distributed_args(self) -> tuple[str, int, int]:
         local_rank = int(os.environ["LOCAL_RANK"])
         return distributed_init_method, rank, local_rank

-    def determine_num_available_blocks(self) -> tuple[int, int]:
-        """
-        Determine the number of available KV blocks.
-        Add an additional all_reduce to get the min across all ranks.
-        Note that even if we have the same `gpu_memory_utilization` and
-        `swap_space`, the available memory in every rank might still
-        differ because NCCL can take different amounts of memory in
-        different ranks. Therefore, it is necessary to test if all ranks
-        agree on the same KV cache configuration.
-        """
-        a, b = super().determine_num_available_blocks()
+    def determine_available_memory(self) -> list[int]:  # in bytes
+        # we need to get the min across all ranks.
+        memory = super().determine_available_memory()
         from vllm.distributed.parallel_state import get_world_group

         cpu_group = get_world_group().cpu_group
-        a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64)
-        b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64)
-        dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
-        dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
-        return a_tensor.item(), b_tensor.item()
+        memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
+        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        return [memory_tensor.item()]
diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py
index 9319918b84be..7032f3ef68b4 100644
--- a/vllm/v1/worker/worker_base.py
+++ b/vllm/v1/worker/worker_base.py
@@ -128,28 +128,6 @@ def load_model(self) -> None:
     def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
         raise NotImplementedError

-    def start_worker_execution_loop(self) -> None:
-        """Execute model loop in parallel worker.
-
-        You can stop the loop by executing a driver worker with an empty output.
-        See `stop_remote_worker_execution_loop` for more details.
-        """
-        raise NotImplementedError("Dead V0 code")
-
-    def determine_num_available_blocks(self) -> tuple[int, int]:
-        """Determine the number of available blocks for the GPU KV cache and
-        swappable CPU KV cache.
-
-        The implementation may run profiling or other heuristics to determine
-        the size of caches.
-
-        Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
-        are blocks that are "active" on the device and can be appended to.
-        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
-        appended to.
- """ - raise NotImplementedError - def get_cache_block_size_bytes(self) -> int: """Return the size of a single cache block, in bytes. Used in speculative decoding.