Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
EXTRAS="${PARALLAX_EXTRAS:-}"
PYTHON_VERSION="${PARALLAX_PYTHON_VERSION:-3.12}"
VENV_DIR="$SCRIPT_DIR/.venv"
VLLM_REF="${VLLM_REF:-main}"
VLLM_REF="${VLLM_REF:-v0.22.0}"

show_help() {
cat <<'EOF'
Expand All @@ -30,7 +30,7 @@ Options:
Environment:
PARALLAX_EXTRAS Same as --extras.
PARALLAX_PYTHON_VERSION Same as --python.
VLLM_REF vLLM git branch/tag to clone. Defaults to main.
VLLM_REF vLLM git branch/tag to clone. Defaults to v0.22.0.
EOF
}

Expand Down Expand Up @@ -218,15 +218,28 @@ build_vllm_rust_frontend() {
local rust_dir
local parallax_scripts_dir
local target_path
local target_version_path
local existing_version
local toolchain

parallax_scripts_dir="$(resolve_venv_bin_dir)"
target_path="$parallax_scripts_dir/vllm-rs"
target_version_path="$target_path.version"

if [[ -f "$target_path" ]]; then
chmod +x "$target_path"
echo "vllm-rs already exists at $target_path, skipping Rust build."
return
existing_version=""
if [[ -f "$target_version_path" ]]; then
existing_version="$(<"$target_version_path")"
fi
if [[ "$existing_version" != "$VLLM_REF" ]]; then
echo "Existing vllm-rs version (${existing_version:-unknown}) does not match $VLLM_REF, rebuilding."
rm -f "$target_path" "$target_version_path"
else
chmod +x "$target_path"
printf '%s\n' "$VLLM_REF" > "$target_version_path"
echo "vllm-rs already exists at $target_path, skipping Rust build."
return
fi
fi

CLONE_PARENT="$(mktemp -d "${TMPDIR:-/tmp}/parallax-vllm-rs.XXXXXX")"
Expand Down Expand Up @@ -264,6 +277,7 @@ build_vllm_rust_frontend() {
mkdir -p "$(dirname "$target_path")"
cp "$rust_dir/target/release/vllm-rs" "$target_path"
chmod +x "$target_path"
printf '%s\n' "$VLLM_REF" > "$target_version_path"
echo "Installed vllm-rs to $target_path"
cleanup_clone
trap - EXIT
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ mac = [

gpu = [
"sglang[all]==0.5.12",
"kernels<0.15",
"accelerate",
"mlx-lm==0.31.3",
"mlx[cpu]==0.31.2",
Expand Down
37 changes: 35 additions & 2 deletions scripts/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""

import argparse
import statistics
import time

import mlx.core as mx
Expand Down Expand Up @@ -80,6 +81,21 @@ def build_prompt(messages, tokenizer):
return full_prompt, prompt_tokens


def percentile(values, percentile_value):
    """Return the requested percentile of ``values`` using linear interpolation.

    The fractional rank ``(len(values) - 1) * percentile_value / 100`` is
    computed over the sorted data, and the result is interpolated linearly
    between the two neighbouring elements.

    Args:
        values: Iterable-sized collection of numbers. It is not modified; a
            sorted copy is used internally.
        percentile_value: Percentile to compute, in the closed range [0, 100].

    Returns:
        0 when ``values`` is empty, the sole element when it holds exactly
        one item, otherwise the interpolated percentile.

    Raises:
        ValueError: If ``percentile_value`` is outside [0, 100] (or is NaN).
    """
    # Out-of-range percentiles would otherwise extrapolate past the data
    # (negative weight) or index past the end of the list, so reject them
    # up front with a clear error.
    if not 0 <= percentile_value <= 100:
        raise ValueError(
            f"percentile_value must be between 0 and 100, got {percentile_value!r}"
        )

    if not values:
        return 0

    sorted_values = sorted(values)
    if len(sorted_values) == 1:
        return sorted_values[0]

    rank = (len(sorted_values) - 1) * percentile_value / 100
    lower = int(rank)
    # Clamp so that percentile_value == 100 does not step past the last item.
    upper = min(lower + 1, len(sorted_values) - 1)
    weight = rank - lower
    return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight


def main():
parser = argparse.ArgumentParser(description="Simple offline inference script")
parser.add_argument(
Expand Down Expand Up @@ -226,7 +242,7 @@ def main():
print_rank(f"Token 1 (Prefill) time: {prefill_time * 1000:.2f} ms")

# 5. Decode Loop
total_decode_time = 0
decode_step_times = []
for i in range(args.max_tokens - 1):
if is_finished:
break
Expand Down Expand Up @@ -260,20 +276,37 @@ def main():
request.commit_new_token(token_id)

decode_step_time = time.perf_counter() - decode_step_start
total_decode_time += decode_step_time
decode_step_times.append(decode_step_time)
print_rank(f"Token {i + 2} time: {decode_step_time * 1000:.2f} ms")

print_rank("\nGenerated Content:")
print_rank(tokenizer.decode(request.output_ids))

# Summary Statistics
total_decode_time = sum(decode_step_times)
prompt_tps = request.prompt_len / prefill_time
generation_tps = len(request.output_ids) / total_decode_time if total_decode_time > 0 else 0
peak_mem = mx.get_peak_memory() / 1024**3

print_rank("-" * 20)
print_rank(f"Prompt: {request.prompt_len} tokens, {prompt_tps:.3f} tokens-per-sec")
print_rank(f"Generation: {len(request.output_ids)} tokens, {generation_tps:.3f} tokens-per-sec")
if decode_step_times:
decode_step_times_ms = [step_time * 1000 for step_time in decode_step_times]
decode_step_time_mean = sum(decode_step_times_ms) / len(decode_step_times_ms)
print_rank(
"Decode step time (ms): "
f"min={min(decode_step_times_ms):.2f}, "
f"median={percentile(decode_step_times_ms, 50):.2f}, "
f"max={max(decode_step_times_ms):.2f}, "
f"mean={decode_step_time_mean:.2f}, "
f"std={statistics.pstdev(decode_step_times_ms):.2f}, "
f"p90={percentile(decode_step_times_ms, 90):.2f}, "
f"p95={percentile(decode_step_times_ms, 95):.2f}, "
f"p99={percentile(decode_step_times_ms, 99):.2f}"
)
else:
print_rank("Decode step time (ms): n/a")
print_rank(f"Peak memory: {peak_mem:.3f} GB")
cache_manager.free_request(request.request_id)

Expand Down
Loading