Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ This project provides Python scripts to automatically build [llama.cpp](https://

### Installation
```bash
# install Python dependencies
# install Python dependencies (includes Transformers for reranking)
pip install -r requirements.txt

# run the first build immediately (includes reranking support)
Expand Down Expand Up @@ -49,9 +49,13 @@ python loadmodel.py ./models/bge-embedding-model.gguf --embed --local

#### 3. Reranking Models
```bash
# start reranking server

# start reranking server (GGUF models)
python loadmodel.py gpustack/bge-reranker-v2-m3-GGUF:Q8_0 --rerank

# start Qwen3 reranker server
python loadmodel.py Qwen/Qwen3-Reranker-4B --rerank

# local reranker model
python loadmodel.py ./models/bge-reranker-v2-m3-Q8_0.gguf --rerank --local
```
Expand Down
8 changes: 8 additions & 0 deletions loadmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import shlex
import shutil
import subprocess
import sys
from pathlib import Path

from dotenv import load_dotenv
Expand Down Expand Up @@ -114,6 +115,13 @@ def main():
elif args.rerank:
mode = "rerank"

if mode == "rerank":
script = SCRIPT_DIR / "reranker.py"
cmd = [str(script), args.model, "--serve", "--host", args.host, "--port", str(args.port)]
print("Starting reranker server:\n", shlex.join([sys.executable] + cmd))
os.execv(sys.executable, [sys.executable] + cmd)
return

model_path = resolve_model(args.model, args.local)
server = find_llama_server()
cmd = [str(server)] + build_args(model_path, mode, args.host, args.port, args.ctx_size, args.threads, args.gpu_layers, args.pooling, args.verbose)
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@ requests
python-dotenv
colorama
huggingface-hub
torch
transformers
sentencepiece
84 changes: 84 additions & 0 deletions reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!/usr/bin/env python3
"""Generic reranker using HuggingFace Transformers.

The script can score a list of documents for a query or run a minimal
HTTP API compatible with the llama.cpp ``/rerank`` endpoint.
"""

import argparse
import json
from typing import List
from http.server import BaseHTTPRequestHandler, HTTPServer

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def load_model(repo_or_path: str):
    """Return a ``(tokenizer, model)`` pair for a sequence-classification reranker.

    ``repo_or_path`` is handed unchanged to both ``from_pretrained`` calls, so
    it may be anything those loaders accept (a hub repo id or a local
    directory). The model is put in evaluation mode before being returned.

    NOTE: ``trust_remote_code=True`` lets Transformers run Python code shipped
    with the model repository; only point this at sources you trust.
    """
    # Both loaders receive the same options; keep them in one place.
    shared_kwargs = {"trust_remote_code": True}

    tokenizer = AutoTokenizer.from_pretrained(repo_or_path, **shared_kwargs)
    model = AutoModelForSequenceClassification.from_pretrained(repo_or_path, **shared_kwargs)

    # Inference only: switch off training-time behaviour such as dropout.
    model.eval()
    return tokenizer, model


def rerank(query: str, docs: List[str], tokenizer, model, device: str = "cpu") -> List[tuple[str, float]]:
    """Score every document against ``query`` and return them best-first.

    Each document is joined to the query as ``"<query>\\n<doc>"``, the whole
    batch is tokenized (padded and truncated) and run through ``model`` once
    without gradient tracking. The first logit column is used as the score.

    Args:
        query: The search query.
        docs: Candidate documents. Any iterable of strings is accepted; it is
            materialized once so generators work too.
        tokenizer: Callable tokenizer returning a batch with a ``.to(device)``
            method (e.g. a Transformers tokenizer with ``return_tensors="pt"``).
        model: Model whose output exposes ``.logits``; expected to already
            live on ``device``.
        device: Torch device the encoded batch is moved to.

    Returns:
        ``(document, score)`` tuples sorted by descending score. Documents
        with equal scores keep their input order. An empty input yields an
        empty list without touching the tokenizer or the model.
    """
    # Materialize once: the input is iterated twice below (pair building and
    # the final zip), which would silently drop everything for a generator.
    docs = list(docs)
    if not docs:
        # Nothing to score; avoid feeding an empty batch to the tokenizer.
        return []

    pairs = [f"{query}\n{doc}" for doc in docs]
    encoded = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt").to(device)

    with torch.no_grad():
        # Column 0 of the classification head is treated as the relevance score.
        logits = model(**encoded).logits[:, 0]

    scores = logits.cpu().tolist()
    # sorted() is stable, so ties stay in their original order.
    return sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)


def main() -> None:
    """Run the reranker either as a one-shot CLI or as a small HTTP server.

    CLI mode (a query is given and ``--serve`` is not): prints the documents
    ranked by score, one per line.

    Server mode (``--serve`` or no query): serves ``POST /rerank``. The
    request body is a JSON object with ``query`` (string), ``documents``
    (list of strings) and optional ``top_n`` (non-negative integer; omitted
    or ``null`` means "all"). The response is ``{"results": [...]}`` where
    each entry carries ``document`` and ``score`` plus ``index`` (position of
    the document in the request) and ``relevance_score`` (same value as
    ``score``), the field names used by llama.cpp-style rerank responses.
    Invalid requests get a 400 instead of an unhandled exception.
    """
    parser = argparse.ArgumentParser(description="Rerank documents with Transformers")
    parser.add_argument("model", help="model repo or local path")
    parser.add_argument("query", nargs="?", help="query text")
    parser.add_argument("documents", nargs="*", help="documents to rank")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="torch device")
    parser.add_argument("--serve", action="store_true", help="run an HTTP server instead of CLI output")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    serve = args.serve or not args.query
    if not serve and not args.documents:
        # Fail fast, before the (slow) model load, when there is nothing to rank.
        parser.error("at least one document is required when a query is given")

    tokenizer, model = load_model(args.model)
    model.to(args.device)

    if not serve:
        ranked = rerank(args.query, args.documents, tokenizer, model, args.device)
        for idx, (doc, score) in enumerate(ranked, 1):
            print(f"{idx}. {doc} (score={score:.4f})")
        return

    class Handler(BaseHTTPRequestHandler):
        """Request handler exposing ``POST /rerank``; closes over the loaded model."""

        def _send_json(self, status: int, payload: dict) -> None:
            """Serialize ``payload`` and send it with an explicit Content-Length."""
            body = json.dumps(payload).encode()
            self.send_response(status)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def do_POST(self):
            if self.path != "/rerank":
                self.send_error(404)
                return

            try:
                length = int(self.headers.get("Content-Length", 0))
                if length < 0:
                    # read(-1) would block until the client closes the socket.
                    raise ValueError("negative Content-Length")
                # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
                data = json.loads(self.rfile.read(length))
            except ValueError:
                self.send_error(400, "request body must be valid JSON")
                return
            if not isinstance(data, dict):
                self.send_error(400, "request body must be a JSON object")
                return

            query = data.get("query", "")
            docs = data.get("documents", [])
            if not isinstance(query, str):
                self.send_error(400, "'query' must be a string")
                return
            if not isinstance(docs, list) or not all(isinstance(d, str) for d in docs):
                self.send_error(400, "'documents' must be a list of strings")
                return

            top_n = data.get("top_n")
            if top_n is None:
                top_n = len(docs)
            if not isinstance(top_n, int) or top_n < 0:
                self.send_error(400, "'top_n' must be a non-negative integer")
                return

            # Skip the model entirely for an empty document list.
            ranked = rerank(query, docs, tokenizer, model, args.device)[:top_n] if docs else []

            # Recover each document's request position. Duplicate documents
            # are handed out in input order, matching the stable sort used
            # for ranking.
            positions: dict = {}
            for position, doc in enumerate(docs):
                positions.setdefault(doc, []).append(position)

            results = []
            for doc, score in ranked:
                results.append(
                    {
                        "document": doc,
                        "score": score,
                        "index": positions[doc].pop(0),
                        "relevance_score": score,
                    }
                )
            self._send_json(200, {"results": results})

    server = HTTPServer((args.host, args.port), Handler)
    print(f"Reranker serving on {args.host}:{args.port}")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the server; exit quietly.
        pass
    finally:
        server.server_close()


# Run the CLI/server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()