Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions fastchat/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import requests
import uvicorn


from fastchat.constants import (
CONTROLLER_HEART_BEAT_EXPIRATION,
WORKER_API_TIMEOUT,
Expand Down Expand Up @@ -72,7 +73,7 @@ def __init__(self, dispatch_method: str):
)
self.heart_beat_thread.start()

def register_worker(
async def register_worker(
self,
worker_name: str,
check_heart_beat: bool,
Expand All @@ -85,7 +86,7 @@ def register_worker(
logger.info(f"Register an existing worker: {worker_name}")

if not worker_status:
worker_status = self.get_worker_status(worker_name)
worker_status = await self.get_worker_status(worker_name)
if not worker_status:
return False

Expand All @@ -101,9 +102,10 @@ def register_worker(
logger.info(f"Register done: {worker_name}, {worker_status}")
return True

def get_worker_status(self, worker_name: str):
async def get_worker_status(self, worker_name: str):
loop = asyncio.get_event_loop()
try:
r = requests.post(worker_name + "/worker_get_status", timeout=5)
r = await loop.run_in_executor(None, lambda: requests.post(worker_name + "/worker_get_status", timeout=1.0))
except requests.exceptions.RequestException as e:
logger.error(f"Get status fails: {worker_name}, {e}")
return None
Expand All @@ -117,14 +119,22 @@ def get_worker_status(self, worker_name: str):
def remove_worker(self, worker_name: str):
del self.worker_info[worker_name]

def refresh_all_workers(self):
async def refresh_all_workers(self):
old_info = dict(self.worker_info)
self.worker_info = {}

tasks = []
for w_name, w_info in old_info.items():
if not self.register_worker(
tasks.append(self.register_worker(
w_name, w_info.check_heart_beat, None, w_info.multimodal
):
))

results = await asyncio.gather(*tasks, return_exceptions=True)
for (w_name, w_info), result in zip(old_info.items(), results):
if isinstance(result, Exception):
logger.error(f"Error registering worker {w_name}: {result}")
continue
if not result:
logger.info(f"Remove stale worker: {w_name}")

def list_models(self):
Expand Down Expand Up @@ -153,7 +163,7 @@ def list_language_models(self):

return list(model_names)

def get_worker_address(self, model_name: str):
async def get_worker_address(self, model_name: str):
if self.dispatch_method == DispatchMethod.LOTTERY:
worker_names = []
worker_speeds = []
Expand All @@ -176,7 +186,7 @@ def get_worker_address(self, model_name: str):
pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
worker_name = worker_names[pt]

if self.get_worker_status(worker_name):
if await self.get_worker_status(worker_name):
break
else:
self.remove_worker(worker_name)
Expand Down Expand Up @@ -244,17 +254,20 @@ def handle_worker_timeout(self, worker_address):

# Let the controller act as a worker to achieve hierarchical
# management. This can be used to connect isolated sub networks.
def worker_api_get_status(self):
async def worker_api_get_status(self):
model_names = set()
speed = 0
queue_length = 0

for w_name in self.worker_info:
worker_status = self.get_worker_status(w_name)
if worker_status is not None:
model_names.update(worker_status["model_names"])
speed += worker_status["speed"]
queue_length += worker_status["queue_length"]
tasks = [self.get_worker_status(w_name) for w_name in self.worker_info]
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception) or result is None:
continue
worker_status = result
model_names.update(worker_status["model_names"])
speed += worker_status["speed"]
queue_length += worker_status["queue_length"]

model_names = sorted(list(model_names))
return {
Expand All @@ -263,8 +276,8 @@ def worker_api_get_status(self):
"queue_length": queue_length,
}

def worker_api_generate_stream(self, params):
worker_addr = self.get_worker_address(params["model"])
async def worker_api_generate_stream(self, params):
worker_addr = await self.get_worker_address(params["model"])
if not worker_addr:
yield self.handle_no_worker(params)

Expand All @@ -288,7 +301,7 @@ def worker_api_generate_stream(self, params):
@app.post("/register_worker")
async def register_worker(request: Request):
data = await request.json()
controller.register_worker(
await controller.register_worker(
data["worker_name"],
data["check_heart_beat"],
data.get("worker_status", None),
Expand All @@ -298,7 +311,7 @@ async def register_worker(request: Request):

@app.post("/refresh_all_workers")
async def refresh_all_workers():
models = controller.refresh_all_workers()
await controller.refresh_all_workers()


@app.post("/list_models")
Expand All @@ -322,7 +335,7 @@ async def list_language_models():
@app.post("/get_worker_address")
async def get_worker_address(request: Request):
data = await request.json()
addr = controller.get_worker_address(data["model"])
addr = await controller.get_worker_address(data["model"])
return {"address": addr}


Expand All @@ -336,13 +349,13 @@ async def receive_heart_beat(request: Request):
@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
params = await request.json()
generator = controller.worker_api_generate_stream(params)
generator = await controller.worker_api_generate_stream(params)
return StreamingResponse(generator)


@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
return controller.worker_api_get_status()
return await controller.worker_api_get_status()


@app.get("/test_connection")
Expand Down