main.py
from fastapi import FastAPI, HTTPException
import uvicorn

from inference_types import (
    GenerateRequest,
    GenerateResponse,
    BatchGenerateRequest,
    BatchGenerateResponse,
)
from inference import batch_inference

app = FastAPI()


@app.post("/generate")
async def generate(generate_request: GenerateRequest) -> GenerateResponse:
    """Handle a single-prompt request by wrapping it in a batch of one."""
    try:
        print("Request received:", generate_request)
        batch_request = BatchGenerateRequest(
            batch_messages=[generate_request.messages],
            max_output_length=generate_request.max_output_length,
            temperature=generate_request.temperature,
            top_p=generate_request.top_p,
            top_k=generate_request.top_k,
            repetition_penalty=generate_request.repetition_penalty,
            streaming=generate_request.streaming,
        )
        batch_response: BatchGenerateResponse = batch_inference(batch_request)
        # Unwrap the first (and only) entry of the batch for the single response.
        response = GenerateResponse(
            response=batch_response.batch_responses[0][0],
            generation_tokens=batch_response.batch_generation_tokens[0],
            generation_time=batch_response.generation_time,
            tokens_per_second=batch_response.max_throughput_tokens_per_second,
        )
        print("Response:", response)
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/batch_generate")
async def batch_generate(
    batch_generate_request: BatchGenerateRequest,
) -> BatchGenerateResponse:
    """Run inference over a batch of prompts in a single call."""
    try:
        print("Batch request received:", batch_generate_request)
        batch_response = batch_inference(batch_generate_request)
        print("Batch response:", batch_response)
        return batch_response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    from mpi4py import MPI

    # Only MPI rank 0 runs the HTTP server; other ranks take part in inference only.
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank == 0:
        print("Inference server ready!")
        uvicorn.run(app, host="0.0.0.0", port=8000)
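
# Example client call (a minimal sketch, not part of the server itself).
# It assumes the server is reachable on localhost:8000 and that
# GenerateRequest accepts chat-style messages of the form
# {"role": ..., "content": ...}; adjust the payload to the actual schema
# defined in inference_types.py.
#
#   import requests
#
#   payload = {
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "max_output_length": 128,
#       "temperature": 0.7,
#       "top_p": 0.9,
#       "top_k": 50,
#       "repetition_penalty": 1.1,
#       "streaming": False,
#   }
#   print(requests.post("http://localhost:8000/generate", json=payload).json())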