diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 8bda50fb7..13b4cfce5 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -33,7 +33,7 @@
 import uuid
 from PIL import Image
 import multiprocessing as mp
-from typing import AsyncGenerator, Union
+from typing import Any, AsyncGenerator, Union
 from typing import Callable
 from lightllm.server import TokenLoad
 from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
@@ -130,6 +130,22 @@ def get_model_name():
     return {"model_name": g_objs.args.model_name}
 
 
+@app.get("/get_server_info")
+@app.post("/get_server_info")
+def get_server_info():
+    # Convert StartArgs into a dict
+    from dataclasses import asdict
+
+    server_info: dict[str, Any] = asdict(g_objs.args)
+    return server_info
+
+
+@app.get("/get_weight_version")
+@app.post("/get_weight_version")
+def get_weight_version():
+    return {"weight_version": g_objs.args.weight_version}
+
+
 @app.get("/healthz", summary="Check server health")
 @app.get("/health", summary="Check server health")
 @app.head("/health", summary="Check server health")
diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index b4447d808..b0a1189d3 100755
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -1,15 +1,33 @@
 import torch
 from .api_cli import make_argument_parser
+from lightllm.server.core.objs.start_args_type import StartArgs
+from lightllm.utils.log_utils import init_logger
 
-if __name__ == "__main__":
-    torch.multiprocessing.set_start_method("spawn")  # this code will not be ok for settings to fork to subprocess
-    parser = make_argument_parser()
-    args = parser.parse_args()
+logger = init_logger(__name__)
+
+
+def launch_server(args: StartArgs):
     from .api_start import pd_master_start, normal_or_p_d_start, config_server_start
 
+    try:
+        # this code will not be ok for settings to fork to subprocess
+        torch.multiprocessing.set_start_method("spawn")
+    except RuntimeError as e:
+        logger.warning(f"Failed to set start method: {e}")
+    except Exception as e:
+        logger.error(f"Failed to set start method: {e}")
+        raise e
+
     if args.run_mode == "pd_master":
         pd_master_start(args)
     elif args.run_mode == "config_server":
         config_server_start(args)
     else:
         normal_or_p_d_start(args)
+
+
+if __name__ == "__main__":
+    parser = make_argument_parser()
+    args = parser.parse_args()
+
+    launch_server(StartArgs(**vars(args)))
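Note: a minimal sketch of exercising the two new endpoints once a server is up; the base URL simply reflects the StartArgs defaults (host 127.0.0.1, port 8000), and any HTTP client works:

    import requests

    BASE = "http://127.0.0.1:8000"  # assumed default host/port

    # /get_server_info serves the full StartArgs dataclass as a JSON dict
    info = requests.get(f"{BASE}/get_server_info").json()
    print(info["model_name"], info["run_mode"])

    # /get_weight_version serves {"weight_version": ...}
    print(requests.get(f"{BASE}/get_weight_version").json()["weight_version"])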
gracefully") + else: + logger.warning("HTTP server did not exit in time, killing it...") + kill_recursive(http_server_process) + + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully due to terminal closure.") + sys.exit(0) signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGHUP, signal_handler) logger.info(f"start process pid {os.getpid()}") logger.info(f"http server pid {http_server_process.pid}") return -def normal_or_p_d_start(args): - from lightllm.server.core.objs.start_args_type import StartArgs - - args: StartArgs = args - +def normal_or_p_d_start(args: StartArgs): set_unique_server_name(args) if not args.disable_shm_warning: @@ -368,7 +387,7 @@ def normal_or_p_d_start(args): return -def pd_master_start(args): +def pd_master_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "pd_master": return @@ -431,7 +450,7 @@ def pd_master_start(args): http_server_process.wait() -def config_server_start(args): +def config_server_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "config_server": return diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 8a2727794..a7c26e827 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -1,37 +1,42 @@ from dataclasses import dataclass, field from typing import List, Optional, Tuple -# 只是为了更好的编程提示 +# 服务启动参数 @dataclass class StartArgs: run_mode: str = field( default="normal", - metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + metadata={ + "choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"] + }, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) + httpserver_workers: int = field(default=1) zmq_mode: str = field( default="ipc:///tmp/", metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"}, ) - pd_master_ip: str = field(default="127.0.0.1") + pd_master_ip: str = field(default="0.0.0.0") pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) pd_decode_rpyc_port: int = field(default=42000) - select_p_d_node_strategy: str = field(default=None) + select_p_d_node_strategy: str = field( + default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]} + ) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) - tokenizer_mode: str = field(default="slow") + tokenizer_mode: str = field(default="fast") load_way: str = field(default="HF") max_total_token_num: Optional[int] = field(default=None) mem_fraction: float = field(default=0.9) batch_max_tokens: Optional[int] = field(default=None) - eos_id: List[int] = field(default_factory=list) + eos_id: Optional[List[int]] = field(default=None) tool_call_parser: Optional[str] = field( - default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]} + default=None, metadata={"choices": ["qwen25", "llama3", "mistral", "deepseekv3", "qwen"]} ) chat_template: Optional[str] = field(default=None) running_max_req_size: int = field(default=1000) @@ -39,11 +44,11 @@ class StartArgs: dp: int = field(default=1) nnodes: int = field(default=1) node_rank: int = field(default=0) - max_req_total_len: int = 
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 8a2727794..a7c26e827 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -1,37 +1,42 @@
 from dataclasses import dataclass, field
 from typing import List, Optional, Tuple
 
-# Just for better typing hints
+# Server start arguments
 @dataclass
 class StartArgs:
     run_mode: str = field(
         default="normal",
-        metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]},
+        metadata={
+            "choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"]
+        },
     )
     host: str = field(default="127.0.0.1")
     port: int = field(default=8000)
+    httpserver_workers: int = field(default=1)
     zmq_mode: str = field(
         default="ipc:///tmp/",
         metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"},
     )
-    pd_master_ip: str = field(default="127.0.0.1")
+    pd_master_ip: str = field(default="0.0.0.0")
     pd_master_port: int = field(default=1212)
     config_server_host: str = field(default=None)
     config_server_port: int = field(default=None)
     pd_decode_rpyc_port: int = field(default=42000)
-    select_p_d_node_strategy: str = field(default=None)
+    select_p_d_node_strategy: str = field(
+        default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]}
+    )
     model_name: str = field(default="default_model_name")
     model_dir: Optional[str] = field(default=None)
-    tokenizer_mode: str = field(default="slow")
+    tokenizer_mode: str = field(default="fast")
     load_way: str = field(default="HF")
     max_total_token_num: Optional[int] = field(default=None)
     mem_fraction: float = field(default=0.9)
     batch_max_tokens: Optional[int] = field(default=None)
-    eos_id: List[int] = field(default_factory=list)
+    eos_id: Optional[List[int]] = field(default=None)
     tool_call_parser: Optional[str] = field(
-        default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]}
+        default=None, metadata={"choices": ["qwen25", "llama3", "mistral", "deepseekv3", "qwen"]}
     )
     chat_template: Optional[str] = field(default=None)
     running_max_req_size: int = field(default=1000)
@@ -39,11 +44,11 @@ class StartArgs:
     dp: int = field(default=1)
     nnodes: int = field(default=1)
     node_rank: int = field(default=0)
-    max_req_total_len: int = field(default=2048 + 1024)
+    max_req_total_len: int = field(default=16384)
     nccl_host: str = field(default="127.0.0.1")
     nccl_port: int = field(default=28765)
     use_config_server_to_init_nccl: bool = field(default=False)
-    mode: List[str] = field(default_factory=list)
+    mode: List[str] = field(default_factory=lambda: [])
     trust_remote_code: bool = field(default=False)
     disable_log_stats: bool = field(default=False)
     log_stats_interval: int = field(default=10)
@@ -52,11 +57,11 @@ class StartArgs:
     router_max_wait_tokens: int = field(default=1)
     disable_aggressive_schedule: bool = field(default=False)
     disable_dynamic_prompt_cache: bool = field(default=False)
-    chunked_prefill_size: int = field(default=8192)
+    chunked_prefill_size: int = field(default=4096)
     disable_chunked_prefill: bool = field(default=False)
     diverse_mode: bool = field(default=False)
     token_healing_mode: bool = field(default=False)
-    output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]})
+    output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]})
     first_token_constraint_mode: bool = field(default=False)
     enable_multimodal: bool = field(default=False)
     enable_multimodal_audio: bool = field(default=False)
@@ -75,11 +80,11 @@ class StartArgs:
     health_monitor: bool = field(default=False)
     metric_gateway: Optional[str] = field(default=None)
     job_name: str = field(default="lightllm")
-    grouping_key: List[str] = field(default_factory=list)
+    grouping_key: List[str] = field(default_factory=lambda: [])
     push_interval: int = field(default=10)
     visual_infer_batch_size: int = field(default=1)
     visual_send_batch_size: int = field(default=1)
-    visual_gpu_ids: List[int] = field(default_factory=lambda: [0])
+    visual_gpu_ids: Optional[List[int]] = field(default=None)
     visual_tp: int = field(default=1)
     visual_dp: int = field(default=1)
     visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
@@ -88,10 +93,10 @@ class StartArgs:
     graph_max_batch_size: int = field(default=256)
     graph_split_batch_size: int = field(default=32)
     graph_grow_step_size: int = field(default=16)
-    graph_max_len_in_batch: int = field(default=8192)
-    quant_type: Optional[str] = field(default=None)
+    graph_max_len_in_batch: int = field(default=0)
+    quant_type: Optional[str] = field(default="none")
     quant_cfg: Optional[str] = field(default=None)
-    vit_quant_type: Optional[str] = field(default=None)
+    vit_quant_type: Optional[str] = field(default="none")
     vit_quant_cfg: Optional[str] = field(default=None)
     enable_flashinfer_prefill: bool = field(default=False)
     enable_flashinfer_decode: bool = field(default=False)
@@ -101,7 +106,9 @@ class StartArgs:
     )
     ep_redundancy_expert_config_path: Optional[str] = field(default=None)
     auto_update_redundancy_expert: bool = field(default=False)
-    mtp_mode: Optional[str] = field(default=None)
+    mtp_mode: Optional[str] = field(
+        default=None, metadata={"choices": ["deepseekv3_vanilla", "deepseekv3_eagle", None]}
+    )
     mtp_draft_model_dir: Optional[str] = field(default=None)
     mtp_step: int = field(default=0)
     kv_quant_calibration_config_path: Optional[str] = field(default=None)
@@ -110,7 +117,7 @@ class StartArgs:
     pd_node_id: int = field(default=-1)
     enable_cpu_cache: bool = field(default=False)
     cpu_cache_storage_size: float = field(default=2)
-    cpu_cache_token_page_size: int = field(default=64)
+    cpu_cache_token_page_size: int = field(default=256)
     enable_disk_cache: bool = field(default=False)
     disk_cache_storage_size: float = field(default=10)
     disk_cache_dir: Optional[str] = field(default=None)
@@ -130,3 +137,15 @@ class StartArgs:
 
     # kernel setting
     enable_fa3: bool = field(default=False)
+
+    disable_shm_warning: bool = field(default=False)
+    dp_balancer: str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]})
+    enable_custom_allgather: bool = field(default=False)
+    enable_fused_shared_experts: bool = field(default=False)
+    enable_mps: bool = field(default=False)
+    multinode_router_gloo_port: int = field(default=20001)
+    schedule_time_interval: float = field(default=0.03)
+    use_dynamic_prompt_cache: bool = field(default=False)
+    disable_custom_allreduce: bool = field(default=False)
+
+    weight_version: str = field(default="default")
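Note: a minimal sketch of how StartArgs round-trips through dataclasses.asdict, which is what /get_server_info returns; the overrides chosen here are arbitrary examples:

    from dataclasses import asdict

    from lightllm.server.core.objs.start_args_type import StartArgs

    args = StartArgs(model_dir="/path/to/model", max_req_total_len=16384)
    info = asdict(args)  # plain JSON-serializable dict
    assert info["weight_version"] == "default"
    assert info["tokenizer_mode"] == "fast"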
diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py
index 4c32f2ab1..7273eb06b 100644
--- a/lightllm/utils/device_utils.py
+++ b/lightllm/utils/device_utils.py
@@ -90,7 +90,7 @@ def get_current_device_name():
         gpu_name = gpu_name.replace(" ", "_")
         return gpu_name
     else:
-        return None
+        raise RuntimeError("No GPU available")
 
 
 @lru_cache(maxsize=None)
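Note: get_current_device_name() now raises instead of returning None, so callers that checked for None need a try/except; a sketch with a hypothetical CPU fallback:

    from lightllm.utils.device_utils import get_current_device_name

    try:
        gpu_name = get_current_device_name()
    except RuntimeError:
        gpu_name = "cpu"  # hypothetical fallback for GPU-less environments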