diff --git a/install.sh b/install.sh index 36ff7b5c..11905594 100755 --- a/install.sh +++ b/install.sh @@ -13,7 +13,7 @@ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" EXTRAS="${PARALLAX_EXTRAS:-}" PYTHON_VERSION="${PARALLAX_PYTHON_VERSION:-3.12}" VENV_DIR="$SCRIPT_DIR/.venv" -VLLM_REF="${VLLM_REF:-main}" +VLLM_REF="${VLLM_REF:-v0.22.0}" show_help() { cat <<'EOF' @@ -30,7 +30,7 @@ Options: Environment: PARALLAX_EXTRAS Same as --extras. PARALLAX_PYTHON_VERSION Same as --python. - VLLM_REF vLLM git branch/tag to clone. Defaults to main. + VLLM_REF vLLM git branch/tag to clone. Defaults to v0.22.0. EOF } @@ -218,15 +218,28 @@ build_vllm_rust_frontend() { local rust_dir local parallax_scripts_dir local target_path + local target_version_path + local existing_version local toolchain parallax_scripts_dir="$(resolve_venv_bin_dir)" target_path="$parallax_scripts_dir/vllm-rs" + target_version_path="$target_path.version" if [[ -f "$target_path" ]]; then - chmod +x "$target_path" - echo "vllm-rs already exists at $target_path, skipping Rust build." - return + existing_version="" + if [[ -f "$target_version_path" ]]; then + existing_version="$(<"$target_version_path")" + fi + if [[ "$existing_version" != "$VLLM_REF" ]]; then + echo "Existing vllm-rs version (${existing_version:-unknown}) does not match $VLLM_REF, rebuilding." + rm -f "$target_path" "$target_version_path" + else + chmod +x "$target_path" + printf '%s\n' "$VLLM_REF" > "$target_version_path" + echo "vllm-rs already exists at $target_path, skipping Rust build." 
def percentile(values, percentile_value):
    """Return the ``percentile_value``-th percentile of ``values``.

    The data is sorted into a new list (the input is not modified) and the
    result is linearly interpolated between the two closest ranks, where
    ``rank = (len(values) - 1) * percentile_value / 100``.

    Args:
        values: Sequence of numbers. May be empty.
        percentile_value: Requested percentile, in the inclusive range 0-100.

    Returns:
        0 when ``values`` is empty, the sole element when it has exactly one
        item, otherwise the interpolated percentile.

    Raises:
        ValueError: If ``percentile_value`` is outside 0-100 (or is NaN).
    """
    # Reject out-of-range requests up front: a value above 100 would index
    # past the end of the sorted list, and a negative one would extrapolate
    # or wrap around through negative indexing and return a wrong number.
    # The negated chained comparison is also true for NaN.
    if not 0 <= percentile_value <= 100:
        raise ValueError(
            f"percentile_value must be between 0 and 100, got {percentile_value!r}"
        )

    if not values:
        return 0

    sorted_values = sorted(values)
    if len(sorted_values) == 1:
        return sorted_values[0]

    # Fractional position of the requested percentile within the sorted data.
    rank = (len(sorted_values) - 1) * percentile_value / 100
    lower = int(rank)
    # Clamp so that percentile_value == 100 does not step past the last item.
    upper = min(lower + 1, len(sorted_values) - 1)
    weight = rank - lower
    return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight
Decode Loop - total_decode_time = 0 + decode_step_times = [] for i in range(args.max_tokens - 1): if is_finished: break @@ -260,13 +276,14 @@ def main(): request.commit_new_token(token_id) decode_step_time = time.perf_counter() - decode_step_start - total_decode_time += decode_step_time + decode_step_times.append(decode_step_time) print_rank(f"Token {i + 2} time: {decode_step_time * 1000:.2f} ms") print_rank("\nGenerated Content:") print_rank(tokenizer.decode(request.output_ids)) # Summary Statistics + total_decode_time = sum(decode_step_times) prompt_tps = request.prompt_len / prefill_time generation_tps = len(request.output_ids) / total_decode_time if total_decode_time > 0 else 0 peak_mem = mx.get_peak_memory() / 1024**3 @@ -274,6 +291,22 @@ def main(): print_rank("-" * 20) print_rank(f"Prompt: {request.prompt_len} tokens, {prompt_tps:.3f} tokens-per-sec") print_rank(f"Generation: {len(request.output_ids)} tokens, {generation_tps:.3f} tokens-per-sec") + if decode_step_times: + decode_step_times_ms = [step_time * 1000 for step_time in decode_step_times] + decode_step_time_mean = sum(decode_step_times_ms) / len(decode_step_times_ms) + print_rank( + "Decode step time (ms): " + f"min={min(decode_step_times_ms):.2f}, " + f"median={percentile(decode_step_times_ms, 50):.2f}, " + f"max={max(decode_step_times_ms):.2f}, " + f"mean={decode_step_time_mean:.2f}, " + f"std={statistics.pstdev(decode_step_times_ms):.2f}, " + f"p90={percentile(decode_step_times_ms, 90):.2f}, " + f"p95={percentile(decode_step_times_ms, 95):.2f}, " + f"p99={percentile(decode_step_times_ms, 99):.2f}" + ) + else: + print_rank("Decode step time (ms): n/a") print_rank(f"Peak memory: {peak_mem:.3f} GB") cache_manager.free_request(request.request_id)