8 changes: 4 additions & 4 deletions scripts/jiuge.py
@@ -616,7 +616,7 @@ def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
                 batch_id += 1
 
             batch_inputs = JiugeBatchedTask(tasks[:batch_id])
-            logits = torch.zeros(
+            log_probs = torch.zeros(
                 (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
             )
             self.jiuge_model.forward_batch(
@@ -627,12 +627,12 @@ def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
                 batch_inputs.nreq,
                 batch_inputs.req_pos,
                 batch_inputs.kv_caches,
-                logits.data_ptr(),
+                log_probs.data_ptr(),
             )
 
-            logits = logits.float()
+            # forward_batch now returns log_softmax results, no need for additional calculation
+            log_probs = log_probs.float()
             token_ids = torch.tensor(true_tokens, dtype=torch.int64)  # [ntok,]
-            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)  # (ntok, vocab)
             token_logprobs = log_probs[
                 torch.arange(batch_inputs.ntok), token_ids
             ]  # (ntok,)
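Since forward_batch now fills the output buffer with log-softmax values, the perplexity path only needs to gather each true token's log-probability. A minimal standalone sketch of that math (the helper name and signature are illustrative, not part of this diff):

import torch

def perplexity_from_log_probs(log_probs, true_tokens):
    # log_probs: (ntok, vocab) rows of log-softmax values, as forward_batch now returns them
    token_ids = torch.tensor(true_tokens, dtype=torch.int64)                # (ntok,)
    token_logprobs = log_probs[torch.arange(token_ids.numel()), token_ids]  # (ntok,)
    # perplexity = exp of the negative mean per-token log-probability
    return torch.exp(-token_logprobs.float().mean()).item()
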
166 changes: 158 additions & 8 deletions scripts/launch_server.py
@@ -3,6 +3,7 @@
 from libinfinicore_infer import DeviceType
 from infer_task import InferTask
 from kvcache_pool import KVCachePool
+import torch
 
 import argparse
 import queue
@@ -176,17 +177,27 @@ def worker_loop(app):
 
 
 def build_task(id_, request_data, request: Request):
-    messages = request_data.get("messages", [])
-    input_content = request.app.state.model.tokenizer.apply_chat_template(
-        conversation=messages,
-        add_generation_prompt=True,
-        tokenize=False,
-    )
-    tokens = request.app.state.model.tokenizer.encode(input_content)
+    # Handle both chat and completion formats
+    if "messages" in request_data:
+        # Chat format
+        messages = request_data.get("messages", [])
+        input_content = request.app.state.model.tokenizer.apply_chat_template(
+            conversation=messages,
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        tokens = request.app.state.model.tokenizer.encode(input_content)
+        max_tokens = request_data.get("max_tokens", request.app.state.model.max_context_len())
+    else:
+        # Completion format
+        prompt = request_data.get("prompt", "")
+        tokens = request.app.state.model.tokenizer.encode(prompt)
+        max_tokens = request_data.get("max_tokens", 0)
 
     return AsyncInferTask(
         id_,
         tokens,
-        request_data.get("max_tokens", request.app.state.model.max_context_len()),
+        max_tokens,
         request_data.get("temperature", 1.0),
         request_data.get("top_k", 1),
         request_data.get("top_p", 1.0),
@@ -294,6 +305,145 @@ async def chat_completions(request: Request):
     return JSONResponse(content=response)
 
 
+
+
+
+async def completion(id_, request_data, request: Request):
+    infer_task = None  # Initialize to None to avoid UnboundLocalError
+    try:
+        # Check if max_tokens > 0 is requested
+        max_tokens = request_data.get("max_tokens", 0)
+        if max_tokens > 0:
+            return JSONResponse(
+                content={"error": "max_tokens > 0 is not supported yet. Please use max_tokens=0 for logprobs calculation."},
+                status_code=400
+            )
+
+        infer_task = build_task(id_, request_data, request)
+        await request.app.state.kv_cache_pool.acquire(infer_task)
+
+        output = []
+        logprobs = []
+
+        # Handle echo and logprobs calculation
+        echo = request_data.get("echo", False)
+        if echo:
+            # Add input tokens to output
+            input_tokens = infer_task.tokens
+            for token in input_tokens:
+                content = (
+                    request.app.state.model.tokenizer._tokenizer.id_to_token(token)
+                    .replace("▁", " ")
+                    .replace("<0x0A>", "\n")
+                )
+                output.append(content)
+
+            # Calculate logprobs for input tokens
+            from jiuge import JiugeBatchedTask
+            batch_inputs = JiugeBatchedTask([infer_task])
+            log_probs = torch.zeros(
+                (batch_inputs.ntok, request.app.state.model.meta.dvoc),
+                dtype=request.app.state.model.meta.torch_dtype_logits
+            )
+            request.app.state.model.jiuge_model.forward_batch(
+                request.app.state.model.model_instance,
+                batch_inputs.tokens,
+                batch_inputs.ntok,
+                batch_inputs.req_lens,
+                batch_inputs.nreq,
+                batch_inputs.req_pos,
+                batch_inputs.kv_caches,
+                log_probs.data_ptr(),
+            )
+
+            log_probs = log_probs.float()
+
+            # Calculate correct logprobs for input tokens
+            token_logprobs = []
+            for i in range(len(infer_task.tokens) - 1):  # Only up to second-to-last token
+                next_token = infer_task.tokens[i+1]  # Next token to predict
+                logprob = log_probs[i, next_token].item()  # Use position i logits to predict position i+1 token
+                token_logprobs.append(logprob)
+
+            # First token has no context, so logprob is None
+            logprobs = [None] + token_logprobs
+        else:
+            # echo=false: don't calculate logprobs since user can't see input text
+            logprobs = []
+
+        # For max_tokens=0, we need to manually release the KV cache since we don't go through worker
+        await request.app.state.kv_cache_pool.release(infer_task)
+        print(f"[DEBUG] {id_} Released KV cache for max_tokens=0")
+
+        output_text = "".join(output).strip()
+
+        # Prepare tokens list for logprobs
+        tokens_list = []
+        text_offset_list = []
+        current_offset = 0
+
+        # Build tokens list and text offsets
+        for i, content in enumerate(output):
+            tokens_list.append(content)
+            text_offset_list.append(current_offset)
+            current_offset += len(content)
+
+        # Build response according to DeepSeek API completion format
+        response = {
+            "id": id_,
+            "object": "text_completion",
+            "created": int(time.time()),
+            "model": "jiuge",
+            "choices": [
+                {
+                    "text": output_text,
+                    "index": 0,
+                    "logprobs": {
+                        "token_logprobs": logprobs,
+                        "tokens": tokens_list,
+                        "text_offset": text_offset_list,
+                        "top_logprobs": []
+                    },
+                    "finish_reason": "stop"
+                }
+            ],
+            "usage": {
+                "prompt_tokens": len(infer_task.tokens),
+                "prompt_cache_hit_tokens": 0,
+                "prompt_cache_miss_tokens": len(infer_task.tokens),
+                "completion_tokens": 0,
+                "total_tokens": len(infer_task.tokens),
+                "completion_tokens_details": {
+                    "reasoning_tokens": 0
+                }
+            }
+        }
+        return response
+
+    except Exception as e:
+        print(f"[Error] ID: {id_} Exception: {e}")
+        return JSONResponse(content={"error": str(e)}, status_code=500)
+    finally:
+        if infer_task and infer_task.finish_reason is None:
+            infer_task.finish_reason = "cancel"
+
+
+@App.post("/completions")
+async def completions(request: Request):
+    data = await request.json()
+
+    if not data.get("prompt"):
+        return JSONResponse(content={"error": "No prompt provided"}, status_code=400)
+
+    id_ = f"cmpl-{uuid.uuid4().hex}"
+    response = await completion(id_, data, request)
+
+    # Check if response is already a JSONResponse (error case)
+    if isinstance(response, JSONResponse):
+        return response
+    else:
+        return JSONResponse(content=response)
+
 if __name__ == "__main__":
     uvicorn.run(App, host="0.0.0.0", port=8000)
 
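For reference, one way a client could exercise the new route once the server from the __main__ block is running locally (the prompt text is arbitrary; port 8000 matches the uvicorn.run call above):

import requests

# max_tokens must be 0 for now (anything larger is rejected with a 400),
# and echo=True makes the server return the prompt tokens with their logprobs.
resp = requests.post(
    "http://localhost:8000/completions",
    json={"prompt": "The capital of France is Paris.", "max_tokens": 0, "echo": True},
)
choice = resp.json()["choices"][0]
for token, logprob in zip(choice["logprobs"]["tokens"], choice["logprobs"]["token_logprobs"]):
    print(repr(token), logprob)  # the first entry is None: no preceding context
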
5 changes: 3 additions & 2 deletions scripts/test_ppl.py
@@ -33,8 +33,9 @@
 
 # encode, chunk and decode
 tokens = tokenizer.encode(text, add_special_tokens=False)
-for i in range(0, len(tokens), CHUNK_SIZE):
-    chunk_tokens = tokens[i : min(i + CHUNK_SIZE, len(tokens))]
+# Use the same chunking logic as jiuge_ppl.py: only process complete chunks
+for i in range(0, len(tokens) - CHUNK_SIZE + 1, CHUNK_SIZE):
+    chunk_tokens = tokens[i : i + CHUNK_SIZE]
     chunk_text = tokenizer.decode(chunk_tokens)
 
     resp = requests.post(
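As a quick illustration of the new loop bounds (toy values, not from the script): with 10 tokens and CHUNK_SIZE = 4, the old loop also emitted a trailing partial chunk, while the new loop keeps only complete chunks, matching jiuge_ppl.py.

tokens = list(range(10))
CHUNK_SIZE = 4

old_chunks = [tokens[i : min(i + CHUNK_SIZE, len(tokens))] for i in range(0, len(tokens), CHUNK_SIZE)]
new_chunks = [tokens[i : i + CHUNK_SIZE] for i in range(0, len(tokens) - CHUNK_SIZE + 1, CHUNK_SIZE)]

print(old_chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]  -> includes a partial chunk
print(new_chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]          -> complete chunks only
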
2 changes: 2 additions & 0 deletions src/cache_manager/opcache_manager.hpp
@@ -158,6 +158,7 @@ class CacheManager {
     DECLARE_OP_CACHE(RoPE)
     DECLARE_OP_CACHE(Rearrange)
     DECLARE_OP_CACHE(CausalSoftmax)
+    DECLARE_OP_CACHE(LogSoftmax)
     DECLARE_OP_CACHE(Topkrouter)
     DECLARE_OP_CACHE(SwiGLU)
     DECLARE_OP_CACHE(RandomSample)
@@ -170,6 +171,7 @@
           RoPE_cache(capacity, DESTROY_FUNC(RoPE)),
           Rearrange_cache(capacity, DESTROY_FUNC(Rearrange)),
           CausalSoftmax_cache(capacity, DESTROY_FUNC(CausalSoftmax)),
+          LogSoftmax_cache(capacity, DESTROY_FUNC(LogSoftmax)),
           Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)),
           SwiGLU_cache(capacity, DESTROY_FUNC(SwiGLU)),
           RandomSample_cache(capacity, DESTROY_FUNC(RandomSample)),
20 changes: 20 additions & 0 deletions src/models/inference_context.cpp
@@ -143,6 +143,26 @@ void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y,
                                      y->data(), x->data(), stream));
 }
 
+void InferenceContext::logSoftmax(std::shared_ptr<Tensor> y,
+                                  std::shared_ptr<Tensor> x) {
+    size_t key = CacheManager::createDescriptorKey(y, x);
+
+    infiniopLogSoftmaxDescriptor_t desc;
+    if (!cache_manager->getLogSoftmaxDescriptor(key, desc)) {
+        RUN_INFINI(infiniopCreateLogSoftmaxDescriptor(
+            op_handle, &desc, y->desc(), x->desc()));
+        cache_manager->putLogSoftmaxDescriptor(key, desc);
+    }
+
+    size_t workspace_size = 0;
+    RUN_INFINI(infiniopGetLogSoftmaxWorkspaceSize(desc, &workspace_size));
+    ensure_workspace(workspace_size);
+    void *workspace = workspace_storage->memory();
+
+    RUN_INFINI(infiniopLogSoftmax(desc, workspace, workspace_size,
+                                  y->data(), x->data(), stream));
+}
+
 void InferenceContext::topkrouter(std::shared_ptr<Tensor> values,  // F32
                                   std::shared_ptr<Tensor> indices, // I32
                                   std::shared_ptr<Tensor> x,
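For intuition, a small numerical reference (not part of this diff) for what the new LogSoftmax op is expected to compute, assuming a numerically stable log-softmax over the last (vocabulary) dimension:

import torch

def log_softmax_reference(x: torch.Tensor) -> torch.Tensor:
    # Numerically stable row-wise log-softmax: subtract the row max before exponentiating.
    m = x.max(dim=-1, keepdim=True).values
    z = x - m
    return z - z.exp().sum(dim=-1, keepdim=True).log()

x = torch.randn(2, 5)
assert torch.allclose(log_softmax_reference(x), torch.log_softmax(x, dim=-1), atol=1e-6)
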
6 changes: 6 additions & 0 deletions src/models/inference_context.hpp
@@ -37,6 +37,8 @@ struct InferenceContext {
              infiniopRoPEAlgo_t algo);
    void causalSoftmax(std::shared_ptr<Tensor> y,
                       std::shared_ptr<Tensor> x);
+   void logSoftmax(std::shared_ptr<Tensor> y,
+                   std::shared_ptr<Tensor> x);
 
    void topkrouter(std::shared_ptr<Tensor> values,  // F32
                    std::shared_ptr<Tensor> indices, // I32
@@ -111,6 +113,10 @@ inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x)
     getInferenceContext().causalSoftmax(y, x);
 }
 
+inline void logSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
+    getInferenceContext().logSoftmax(y, x);
+}
+
 inline void topkrouter(std::shared_ptr<Tensor> values,  // F32
                        std::shared_ptr<Tensor> indices, // I32
                        std::shared_ptr<Tensor> x,
6 changes: 5 additions & 1 deletion src/models/jiuge/jiuge.cpp
@@ -262,8 +262,12 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
         rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon);
         auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool);
         linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
+
+        auto log_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool);
+        logSoftmax(log_logits_buf, last_logits_buf);
+
         RUN_INFINI(infinirtStreamSynchronize(stream));
-        RUN_INFINI(infinirtMemcpy(last_logits, last_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H));
+        RUN_INFINI(infinirtMemcpy(last_logits, log_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H));
     }
     if (output != nullptr) {
         size_t token_offset = 0;