-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path solve_u_math.py
More file actions
73 lines (63 loc) · 1.92 KB
/
solve_u_math.py
File metadata and controls
73 lines (63 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
import json
from openai import OpenAI
from datasets import load_dataset
from tqdm import tqdm
from prompts import solve_cot_prompt, _REASONERS
def main():
    """Solve U-Math problems with a CoT prompt and save predictions to JSON.

    Queries an OpenAI-compatible chat endpoint for every item in the
    toloka/u-math test split and writes a {uuid: prediction-text} mapping
    to --output_file.
    """
    # Parse arguments
    parser = argparse.ArgumentParser(
        description="Solve U-Math problems using CoT prompt."
    )
    parser.add_argument(
        "--base_url",
        type=str,
        default="https://api.openai.com/v1",
        help="Base url for OpenAI-compatible endpoint.",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        default="stub",
        help="Your API key for OpenAI-compatible endpoint.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="gpt-4o-mini",
        help="Model name for OpenAI-compatible endpoint.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="predictions_u_math.json",
        help="Output JSON file.",
    )
    args = parser.parse_args()

    # Load the evaluation split; items carry problem_statement, image, uuid.
    dataset = load_dataset("toloka/u-math", split="test")

    # Client for any OpenAI-compatible endpoint (base_url is configurable).
    client = OpenAI(api_key=args.api_key, base_url=args.base_url)

    # Models listed in _REASONERS take `reasoning_effort` instead of the
    # sampling params; for everything else use greedy decoding with a cap.
    params = (
        {'temperature': 0., 'max_tokens': 4096} if args.model not in _REASONERS else
        {'reasoning_effort': 'high'}
    )

    # Predict with CoT prompt, keyed by problem uuid.
    predictions = {}
    for item in tqdm(dataset):
        prompt = solve_cot_prompt(
            problem_statement=item["problem_statement"],
            image=item["image"],
        )
        response = client.chat.completions.create(
            messages=prompt,
            model=args.model,
            **params
        )
        predictions[item["uuid"]] = response.choices[0].message.content

    # Explicit utf-8 + ensure_ascii=False keep non-ASCII math symbols in the
    # predictions readable in the output file (and platform-independent).
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2, ensure_ascii=False)
    print(f"Predictions saved to {args.output_file} as uuid -> prediction json.")


if __name__ == "__main__":
    main()