18 changes: 17 additions & 1 deletion lightllm/server/api_http.py
@@ -33,7 +33,7 @@
import uuid
from PIL import Image
import multiprocessing as mp
from typing import AsyncGenerator, Union
from typing import Any, AsyncGenerator, Union
from typing import Callable
from lightllm.server import TokenLoad
from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
@@ -130,6 +130,22 @@ def get_model_name():
return {"model_name": g_objs.args.model_name}


@app.get("/get_server_info")
@app.post("/get_server_info")
def get_server_info():
# convert StartArgs into a plain dict
from dataclasses import asdict

server_info: dict[str, Any] = asdict(g_objs.args)
return server_info


@app.get("/get_weight_version")
@app.post("/get_weight_version")
def get_weight_version():
return {"weight_version": g_objs.args.weight_version}


@app.get("/healthz", summary="Check server health")
@app.get("/health", summary="Check server health")
@app.head("/health", summary="Check server health")
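The two new routes are ordinary GET/POST endpoints, so they can be exercised with any HTTP client. A minimal sketch, assuming a lightllm server is already listening on localhost:8000:

```python
# Query the endpoints added in api_http.py above.
# Assumes a running lightllm server on localhost:8000.
import requests

server_info = requests.get("http://localhost:8000/get_server_info").json()
print(server_info["run_mode"], server_info["model_name"])

weight_version = requests.get("http://localhost:8000/get_weight_version").json()
print(weight_version)  # e.g. {"weight_version": "default"}
```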
26 changes: 22 additions & 4 deletions lightllm/server/api_server.py
@@ -1,15 +1,33 @@
import torch
from .api_cli import make_argument_parser
from lightllm.server.core.objs.start_args_type import StartArgs
from lightllm.utils.log_utils import init_logger

if __name__ == "__main__":
torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess
parser = make_argument_parser()
args = parser.parse_args()
logger = init_logger(__name__)


def launch_server(args: StartArgs):
from .api_start import pd_master_start, normal_or_p_d_start, config_server_start

try:
# the 'fork' start method would break the subprocess setup, so force 'spawn'
torch.multiprocessing.set_start_method("spawn")
except RuntimeError as e:
logger.warning(f"Failed to set start method: {e}")
except Exception as e:
logger.error(f"Failed to set start method: {e}")
raise

if args.run_mode == "pd_master":
pd_master_start(args)
elif args.run_mode == "config_server":
config_server_start(args)
else:
normal_or_p_d_start(args)


if __name__ == "__main__":
parser = make_argument_parser()
args = parser.parse_args()

launch_server(StartArgs(**vars(args)))
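With the __main__ block reduced to argument parsing, launch_server is also callable from other Python code. A sketch under the assumption that the CLI flags map one-to-one onto StartArgs fields (the model path below is a placeholder, not a real checkpoint):

```python
# Hypothetical programmatic launch, mirroring the new __main__ block above.
from lightllm.server.api_cli import make_argument_parser
from lightllm.server.api_server import launch_server
from lightllm.server.core.objs.start_args_type import StartArgs

parser = make_argument_parser()
args = parser.parse_args(["--model_dir", "/models/demo", "--port", "8000"])
launch_server(StartArgs(**vars(args)))
```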
33 changes: 26 additions & 7 deletions lightllm/server/api_start.py
@@ -16,6 +16,7 @@
from lightllm.utils.process_check import is_process_active
from lightllm.utils.multinode_utils import send_and_receive_node_ip
from lightllm.utils.shm_size_check import check_recommended_shm_size
from lightllm.server.core.objs.start_args_type import StartArgs

logger = init_logger(__name__)

@@ -51,20 +52,38 @@ def signal_handler(sig, frame):
process_manager.terminate_all_processes()
logger.info("All processes have been terminated gracefully.")
sys.exit(0)
elif sig == signal.SIGHUP:
logger.info("Received SIGHUP (terminal closed), shutting down gracefully...")
if http_server_process and http_server_process.poll() is None:
http_server_process.send_signal(signal.SIGTERM)

start_time = time.time()
while (time.time() - start_time) < 60:
if not is_process_active(http_server_process.pid):
break
time.sleep(1)

if not is_process_active(http_server_process.pid):
logger.info("HTTP server has exited gracefully")
else:
logger.warning("HTTP server did not exit in time, killing it...")
kill_recursive(http_server_process)

process_manager.terminate_all_processes()
logger.info("All processes have been terminated gracefully due to terminal closure.")
sys.exit(0)

signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGHUP, signal_handler)

logger.info(f"start process pid {os.getpid()}")
logger.info(f"http server pid {http_server_process.pid}")
return


def normal_or_p_d_start(args):
from lightllm.server.core.objs.start_args_type import StartArgs

args: StartArgs = args

def normal_or_p_d_start(args: StartArgs):
set_unique_server_name(args)

if not args.disable_shm_warning:
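The SIGHUP branch above is a wait-then-kill pattern: request termination with SIGTERM, poll for up to 60 seconds, then escalate. A standalone sketch of the same idea using only the standard library (terminate_gracefully is an illustrative name, not part of lightllm):

```python
import signal
import subprocess
import time

def terminate_gracefully(proc: subprocess.Popen, timeout: float = 60.0) -> None:
    # Ask the child to exit, escalating to SIGKILL if it does not comply.
    if proc.poll() is not None:
        return  # already exited
    proc.send_signal(signal.SIGTERM)
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if proc.poll() is not None:
            return  # exited gracefully within the timeout
        time.sleep(1)
    proc.kill()  # unresponsive; force-kill
```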
Expand Down Expand Up @@ -368,7 +387,7 @@ def normal_or_p_d_start(args):
return


def pd_master_start(args):
def pd_master_start(args: StartArgs):
set_unique_server_name(args)
if args.run_mode != "pd_master":
return
@@ -431,7 +450,7 @@ def pd_master_start(args):
http_server_process.wait()


def config_server_start(args):
def config_server_start(args: StartArgs):
set_unique_server_name(args)
if args.run_mode != "config_server":
return
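To exercise the new SIGHUP path end to end, the signal can be sent to the launcher process. A rough POSIX-only test sketch (the model path is a placeholder, and the fixed sleep stands in for a real readiness check such as polling /healthz):

```python
import signal
import subprocess
import sys
import time

# Launch the server the same way the __main__ block in this diff is invoked.
server = subprocess.Popen(
    [sys.executable, "-m", "lightllm.server.api_server", "--model_dir", "/models/demo"]
)
time.sleep(30)  # crude stand-in for waiting until the server is up
server.send_signal(signal.SIGHUP)  # triggers the handler added above
server.wait(timeout=120)
```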
56 changes: 38 additions & 18 deletions lightllm/server/core/objs/start_args_type.py
@@ -1,49 +1,54 @@
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

# exists only to give better editor/type hints
# server start-up arguments


@dataclass
class StartArgs:
run_mode: str = field(
default="normal",
metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]},
metadata={
"choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"]
},
)
host: str = field(default="127.0.0.1")
port: int = field(default=8000)
httpserver_workers: int = field(default=1)
zmq_mode: str = field(
default="ipc:///tmp/",
metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"},
)
pd_master_ip: str = field(default="127.0.0.1")
pd_master_ip: str = field(default="0.0.0.0")
pd_master_port: int = field(default=1212)
config_server_host: str = field(default=None)
config_server_port: int = field(default=None)
pd_decode_rpyc_port: int = field(default=42000)
select_p_d_node_strategy: str = field(default=None)
select_p_d_node_strategy: str = field(
default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]}
)
model_name: str = field(default="default_model_name")
model_dir: Optional[str] = field(default=None)
tokenizer_mode: str = field(default="slow")
tokenizer_mode: str = field(default="fast")
load_way: str = field(default="HF")
max_total_token_num: Optional[int] = field(default=None)
mem_fraction: float = field(default=0.9)
batch_max_tokens: Optional[int] = field(default=None)
eos_id: List[int] = field(default_factory=list)
eos_id: Optional[List[int]] = field(default=None)
tool_call_parser: Optional[str] = field(
default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]}
default=None, metadata={"choices": ["qwen25", "llama3", "mistral", "deepseekv3", "qwen"]}
)
chat_template: Optional[str] = field(default=None)
running_max_req_size: int = field(default=1000)
tp: int = field(default=1)
dp: int = field(default=1)
nnodes: int = field(default=1)
node_rank: int = field(default=0)
max_req_total_len: int = field(default=2048 + 1024)
max_req_total_len: int = field(default=16384)
nccl_host: str = field(default="127.0.0.1")
nccl_port: int = field(default=28765)
use_config_server_to_init_nccl: bool = field(default=False)
mode: List[str] = field(default_factory=list)
mode: List[str] = field(default_factory=lambda: [])
trust_remote_code: bool = field(default=False)
disable_log_stats: bool = field(default=False)
log_stats_interval: int = field(default=10)
@@ -52,11 +57,11 @@ class StartArgs:
router_max_wait_tokens: int = field(default=1)
disable_aggressive_schedule: bool = field(default=False)
disable_dynamic_prompt_cache: bool = field(default=False)
chunked_prefill_size: int = field(default=8192)
chunked_prefill_size: int = field(default=4096)
disable_chunked_prefill: bool = field(default=False)
diverse_mode: bool = field(default=False)
token_healing_mode: bool = field(default=False)
output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]})
output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]})
first_token_constraint_mode: bool = field(default=False)
enable_multimodal: bool = field(default=False)
enable_multimodal_audio: bool = field(default=False)
@@ -75,11 +80,11 @@ class StartArgs:
health_monitor: bool = field(default=False)
metric_gateway: Optional[str] = field(default=None)
job_name: str = field(default="lightllm")
grouping_key: List[str] = field(default_factory=list)
grouping_key: List[str] = field(default_factory=lambda: [])
push_interval: int = field(default=10)
visual_infer_batch_size: int = field(default=1)
visual_send_batch_size: int = field(default=1)
visual_gpu_ids: List[int] = field(default_factory=lambda: [0])
visual_gpu_ids: Optional[List[int]] = field(default=None)
visual_tp: int = field(default=1)
visual_dp: int = field(default=1)
visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
@@ -88,10 +93,10 @@ class StartArgs:
graph_max_batch_size: int = field(default=256)
graph_split_batch_size: int = field(default=32)
graph_grow_step_size: int = field(default=16)
graph_max_len_in_batch: int = field(default=8192)
quant_type: Optional[str] = field(default=None)
graph_max_len_in_batch: int = field(default=0)
quant_type: Optional[str] = field(default="none")
quant_cfg: Optional[str] = field(default=None)
vit_quant_type: Optional[str] = field(default=None)
vit_quant_type: Optional[str] = field(default="none")
vit_quant_cfg: Optional[str] = field(default=None)
enable_flashinfer_prefill: bool = field(default=False)
enable_flashinfer_decode: bool = field(default=False)
@@ -101,7 +106,9 @@
)
ep_redundancy_expert_config_path: Optional[str] = field(default=None)
auto_update_redundancy_expert: bool = field(default=False)
mtp_mode: Optional[str] = field(default=None)
mtp_mode: Optional[str] = field(
default=None, metadata={"choices": ["deepseekv3_vanilla", "deepseekv3_eagle", None]}
)
mtp_draft_model_dir: Optional[str] = field(default=None)
mtp_step: int = field(default=0)
kv_quant_calibration_config_path: Optional[str] = field(default=None)
@@ -110,7 +117,7 @@
pd_node_id: int = field(default=-1)
enable_cpu_cache: bool = field(default=False)
cpu_cache_storage_size: float = field(default=2)
cpu_cache_token_page_size: int = field(default=64)
cpu_cache_token_page_size: int = field(default=256)
enable_disk_cache: bool = field(default=False)
disk_cache_storage_size: float = field(default=10)
disk_cache_dir: Optional[str] = field(default=None)
@@ -130,3 +137,16 @@ class StartArgs:

# kernel setting
enable_fa3: bool = field(default=False)

httpserver_workers: int = field(default=1)
Contributor review (critical): the field httpserver_workers is defined twice in the StartArgs dataclass (first defined at line 17). The later definition silently overrides the earlier one, which is a serious mistake; please remove the duplicate definition here.

disable_shm_warning: bool = field(default=False)
dp_balancer: str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]})
enable_custom_allgather: bool = field(default=False)
enable_fused_shared_experts: bool = field(default=False)
enable_mps: bool = field(default=False)
multinode_router_gloo_port: int = field(default=20001)
schedule_time_interval: float = field(default=0.03)
use_dynamic_prompt_cache: bool = field(default=False)
disable_custom_allreduce: bool = field(default=False)

weight_version: str = field(default="default")
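Because StartArgs is a plain dataclass, the /get_server_info endpoint earlier in this diff is just asdict over it. A small sketch of that round trip, assuming the duplicate httpserver_workers field flagged above has been removed (field values are illustrative):

```python
from dataclasses import asdict
from lightllm.server.core.objs.start_args_type import StartArgs

args = StartArgs(model_dir="/models/demo", run_mode="normal", port=8000)
info = asdict(args)
assert info["weight_version"] == "default"
assert info["tokenizer_mode"] == "fast"
```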
2 changes: 1 addition & 1 deletion lightllm/utils/device_utils.py
@@ -90,7 +90,7 @@ def get_current_device_name():
gpu_name = gpu_name.replace(" ", "_")
return gpu_name
else:
return None
raise RuntimeError("No GPU available")


@lru_cache(maxsize=None)
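The device lookup now fails loudly instead of returning None, so callers that used to test for None need an explicit handler. A hedged caller-side sketch (the "cpu" fallback is an assumption, not existing project behavior):

```python
from lightllm.utils.device_utils import get_current_device_name

try:
    device_name = get_current_device_name()
except RuntimeError:
    device_name = "cpu"  # assumed fallback for GPU-less hosts
```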