Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions test/test_api/test_gsmk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py
import argparse
import ast
import json
import os
import re
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import numpy as np
import requests
from tqdm import tqdm

INVALID = -9999999


def read_jsonl(filename: str):
    """Yield one parsed JSON object per line of a JSONL file.

    Lines starting with ``#`` are treated as comments and skipped.
    """
    with open(filename) as fin:
        for raw_line in fin:
            if raw_line.startswith("#"):
                continue
            yield json.loads(raw_line)


def dump_state_text(filename: str, states: list, mode: str = "w"):
    """Dump program state in a text file.

    Each state is written as a ``==== <index> ====`` header followed by the
    state rendered as text.

    Args:
        filename: Path of the output text file.
        states: Sequence of states (any type) to serialize.
        mode: File open mode; "w" to overwrite, "a" to append.
    """
    with open(filename, mode) as fout:
        for i, s in enumerate(states):
            # f-string interpolation already calls str() on non-str values,
            # so the original isinstance(s, str) branch was redundant.
            fout.write(f"==== {i} ====\n{s}\n")


def download_and_cache_file(url: str, filename: Optional[str] = None):
    """Download *url* to a local cache file and return the file path.

    If *filename* is None the file is cached in the system temporary
    directory under the URL's basename.  An existing cache file is reused
    without re-downloading.

    Args:
        url: Source URL to download.
        filename: Optional explicit destination path.

    Returns:
        The path of the cached file.

    Raises:
        requests.HTTPError: if the download request fails.
    """
    if filename is None:
        # Use the platform temp dir rather than a hardcoded "/tmp" so this
        # also works on non-Unix systems (e.g. Windows).
        filename = os.path.join(tempfile.gettempdir(), url.split("/")[-1])

    # Check if the cache file already exists
    if os.path.exists(filename):
        return filename

    # Original message printed a literal "(unknown)" here; report the
    # actual destination path instead.
    print(f"Downloading from {url} to {filename}")

    # Stream the response to show the progress bar
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Check for request errors

    # Total size of the file in bytes
    total_size = int(response.headers.get("content-length", 0))
    chunk_size = 1024  # Download in chunks of 1KB

    # Use tqdm to display the progress bar
    with open(filename, "wb") as file, tqdm(
        desc="Downloading",
        total=total_size,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for chunk in response.iter_content(chunk_size=chunk_size):
            size = file.write(chunk)
            bar.update(size)

    return filename


def call_generate_lightllm(
    prompt: str,
    temperature: float,
    max_tokens: int,
    stop: Optional[list] = None,
    url: Optional[str] = None,
) -> str:
    """Call LightLLM API for text generation.

    Args:
        prompt: Full prompt text sent to the server.
        temperature: Sampling temperature (0 for greedy decoding).
        max_tokens: Maximum number of new tokens to generate.
        stop: Optional list of stop sequences.
        url: Generation endpoint URL (required).

    Returns:
        The first generated text candidate.

    Raises:
        ValueError: if *url* is missing or the response shape is unexpected.
        RuntimeError: if the HTTP request does not return status 200.
    """
    # Raise instead of assert: asserts are stripped under ``python -O``.
    if url is None:
        raise ValueError("The 'url' parameter must be provided.")

    data = {
        "inputs": prompt,
        "parameters": {
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "stop_sequences": stop,
        },
    }
    res = requests.post(url, json=data)
    if res.status_code != 200:
        raise RuntimeError(
            f"API request failed with status code {res.status_code}: {res.text}"
        )

    response_json = res.json()
    if "generated_text" not in response_json:
        raise ValueError(f"Invalid API response format. Expected 'generated_text' key, got: {response_json.keys()}")
    if not isinstance(response_json["generated_text"], list) or len(response_json["generated_text"]) == 0:
        raise ValueError(
            "Invalid API response format. 'generated_text' should be a non-empty list, "
            f"got: {response_json['generated_text']}"
        )

    pred = response_json["generated_text"][0]
    return pred


def get_one_example(lines, i, include_answer):
    """Format dataset entry *i* as a question prompt, optionally appending its answer."""
    entry = lines[i]
    prompt = f"Question: {entry['question']}\nAnswer:"
    if include_answer:
        prompt += f" {entry['answer']}"
    return prompt
Comment on lines +100 to +103
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using f-strings is generally more readable and can be more performant than repeated string concatenation with +.

Suggested change
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
if include_answer:
ret += " " + lines[i]["answer"]
return ret
ret = f"Question: {lines[i]['question']}\nAnswer:"
if include_answer:
ret += f" {lines[i]['answer']}"
return ret



def get_few_shot_examples(lines, k):
    """Concatenate the first *k* entries (with answers) as few-shot context."""
    return "".join(
        get_one_example(lines, idx, True) + "\n\n" for idx in range(k)
    )
Comment on lines +107 to +110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Building a string in a loop using += can be inefficient for a large number of iterations. A more Pythonic and performant approach is to use a generator expression with str.join().

Suggested change
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
return "".join(get_one_example(lines, i, True) + "\n\n" for i in range(k))



def get_answer_value(answer_str: str):
    """Extract the final numeric answer from a GSM8K answer string.

    Commas are stripped (so "1,234" parses as 1234) and the last run of
    digits is returned as an int.  Returns INVALID when no number is found.
    """
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        # int() accepts leading zeros ("007" -> 7); ast.literal_eval raised
        # SyntaxError on such tokens and wrongly reported INVALID.
        return int(numbers[-1])
    except ValueError:
        return INVALID


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser()
    option_specs = [
        ("--parallel", int, 64),
        ("--host", str, "http://127.0.0.1"),
        ("--port", int, 8000),
        ("--num-shots", int, 5),
        ("--num-questions", int, 200),
        ("--result-file", str, "result.jsonl"),
        ("--data-path", str, "test.jsonl"),
    ]
    for flag, arg_type, default in option_specs:
        parser.add_argument(flag, type=arg_type, default=default)
    return parser.parse_args()


def main(args):
    """Run the GSM8K few-shot benchmark against a LightLLM server.

    Downloads (or reuses) the GSM8K test split, builds few-shot prompts,
    queries the server for each question, and reports accuracy/latency,
    appending a JSON result record to ``args.result_file``.
    """
    # LightLLM API URL
    url = f"{args.host}:{args.port}/generate"

    # Read data.  Pass --data-path through so users can point at a local
    # copy or choose where the downloaded dataset is cached (previously the
    # argument was defined but silently ignored).
    url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
    filename = download_and_cache_file(url_data, args.data_path)
    lines = list(read_jsonl(filename))

    # Construct prompts
    num_questions = args.num_questions
    num_shots = args.num_shots
    few_shot_examples = get_few_shot_examples(lines, num_shots)

    # Ensure we have enough samples and avoid data leakage:
    # test questions start after the few-shot examples.
    max_available = len(lines) - num_shots
    if num_questions > max_available:
        print(
            f"Warning: Requested {num_questions} questions, but only {max_available} available "
            f"after reserving {num_shots} for few-shot. Using {max_available} questions."
        )
        num_questions = max_available

    questions = []
    labels = []
    for i in range(num_shots, num_shots + num_questions):
        questions.append(get_one_example(lines, i, False))
        labels.append(get_answer_value(lines[i]["answer"]))
    # Explicit raise instead of assert: dataset sanity must hold even
    # when running under ``python -O``.
    if any(label == INVALID for label in labels):
        raise ValueError("Some ground-truth answers could not be parsed from the dataset")

    states = [None] * len(labels)

    # Run requests using thread pool
    def get_one_answer(i):
        # Writing into a preallocated slot keeps results in question order
        # even when parallel requests complete out of order.
        answer = call_generate_lightllm(
            prompt=few_shot_examples + questions[i],
            temperature=0,
            max_tokens=256,
            stop=["Question", "Assistant:", "<|separator|>"],
            url=url,
        )
        states[i] = answer

    tic = time.perf_counter()
    if args.parallel == 1:
        for i in tqdm(range(len(questions))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            # Wrap in list() to drain the iterator so tqdm shows progress.
            list(
                tqdm(
                    executor.map(get_one_answer, range(len(questions))),
                    total=len(questions),
                )
            )

    latency = time.perf_counter() - tic

    preds = [get_answer_value(state) for state in states]

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)

    # Print results
    print(f"Accuracy: {acc:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Latency: {latency:.3f} s")

    # Dump results
    dump_state_text("tmp_output_lightllm.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "gsm8k",
            "backend": "lightllm",
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            # Record the number actually run (may have been clamped above),
            # not the raw --num-questions request; the duplicate
            # "num_questions" entry previously stored in "other" is dropped.
            "num_requests": num_questions,
            "other": {
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    # Script entry point: parse CLI options and run the benchmark.
    main(parse_args())