
Commit 0a56823

SemyonEpanov authored and fxlrnrpt committed
Refactor synth_aug_mmlu.py to use Parquet instead of JSONL
1 parent 7729986 commit 0a56823

File tree: 1 file changed, +106 -80 lines


src/core/distillation/synth_aug_mmlu.py

Lines changed: 106 additions & 80 deletions
```diff
@@ -1,5 +1,4 @@
 import ast
-import json
 import logging
 import math
 import os
```
```diff
@@ -34,8 +33,13 @@ def letters_for(n: int):
 
 
 def parse_options(s):
-    lst = ast.literal_eval(s)
-    return list(map(str, lst))
+    if isinstance(s, list):
+        return list(map(str, s))
+    try:
+        lst = ast.literal_eval(str(s))
+        return list(map(str, lst))
+    except:
+        return []
 
 
 def norm_letter_dyn(x, letters):
```
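
The new `parse_options` accepts both a real list (as pandas can return once options round-trip through Parquet) and a stringified list (as read from the TSV input). A minimal sketch of the three paths, assuming the module is importable at this hypothetical path; the inputs are invented:

```python
# Hypothetical import path; inputs invented for illustration.
from src.core.distillation.synth_aug_mmlu import parse_options

print(parse_options(["yes", "no", 3]))    # ['yes', 'no', '3'] — lists pass straight through
print(parse_options("['yes', 'no', 3]"))  # ['yes', 'no', '3'] — strings go through ast.literal_eval
print(parse_options("not a list"))        # [] — parse failures are swallowed to an empty list
```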
```diff
@@ -207,65 +211,48 @@ def _branch_c(q, choices, gold, model, max_tokens, subject, prev_answer, prev_re
 
 
 # ------------ helpers for branch C ------------
-def _load_incorrect_from_branch_a(a_jsonl_path: str, expected_model: str | None) -> dict[int, dict]:
-    bad: dict[int, dict] = {}
-    with open(a_jsonl_path, "r", encoding="utf-8") as f:
-        for line in f:
-            try:
-                rec = json.loads(line)
-            except Exception:
-                continue
-            inp = rec.get("input") or {}
-            out = rec.get("output") or {}
-            if "error" in out:
-                continue
-            if expected_model is not None and (inp.get("model") != expected_model):
-                continue
-            row_id = inp.get("row_id")
-            if row_id is None:
-                continue
-            gold = (inp.get("gold") or "").strip().upper()
-            ans = (out.get("answer") or "").strip().upper()
-            is_correct = out.get("is_correct")
-            if is_correct is None:
-                is_correct = ans == gold
-            if not is_correct:
-                bad[int(row_id)] = {
-                    "preivous_answer": ans,
-                    "thinking": out.get("thinking") or "",
-                }
+def _load_incorrect_from_branch_a(a_parquet_path: str, expected_model: str | None) -> dict[int, dict]:
+    if not a_parquet_path or not os.path.exists(a_parquet_path):
+        return {}
+
+    try:
+        df = pd.read_parquet(a_parquet_path)
+    except Exception:
+        return {}
+
+    bad = {}
+    for _, row in df.iterrows():
+        inp, out = row["input"], row["output"]
+        if "error" in out:
+            continue
+        if expected_model and inp.get("model") != expected_model:
+            continue
+
+        # Check correctness (prefer explicit flag, fallback to string comparison)
+        is_correct = out.get("is_correct")
+        if is_correct is None:
+            is_correct = (out.get("answer") or "").strip().upper() == (inp.get("gold") or "").strip().upper()
+
+        if not is_correct:
+            bad[int(inp["question_id"])] = {
+                "model_answer": out.get("answer"),
+                "thinking": out.get("thinking", ""),
+            }
     return bad
 
 
-def _load_and_clean_existing(out_jsonl: str) -> set[int]:
-    if not os.path.exists(out_jsonl):
+def _load_existing_ids(out_parquet: str) -> set[int]:
+    if not os.path.exists(out_parquet):
+        return set()
+    try:
+        df = pd.read_parquet(out_parquet, columns=["input", "output"])
+        return {
+            int(row["input"]["question_id"])
+            for _, row in df.iterrows()
+            if "error" not in row["output"]
+        }
+    except Exception:
         return set()
-
-    valid_ids = set()
-    valid_lines = []
-
-    with open(out_jsonl, "r", encoding="utf-8") as f:
-        for line in f:
-            line = line.strip()
-            if not line:
-                continue
-            try:
-                rec = json.loads(line)
-                # Check if output has error
-                if "error" not in rec.get("output", {}):
-                    rid = rec.get("input", {}).get("question_id")
-                    if rid is not None:
-                        valid_ids.add(int(rid))
-                        valid_lines.append(line)
-            except Exception:
-                pass
-
-    # Rewrite file with only valid lines
-    with open(out_jsonl, "w", encoding="utf-8") as f:
-        for line in valid_lines:
-            f.write(line + "\n")
-
-    return valid_ids
 
 
 # ------------ dataset -------------
```
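
Both helpers assume each Parquet row holds `input` and `output` as nested struct columns, which pandas returns as plain dicts on read. A minimal round-trip sketch of that layout, assuming pyarrow (or fastparquet) is installed; the record values are invented:

```python
import pandas as pd

# Records shaped like {"input": {...}, "output": {...}}, as synth_on_dataset
# writes them; the field values here are invented for illustration.
records = [
    {"input": {"question_id": 7, "model": "m1", "gold": "B"},
     "output": {"answer": "C", "is_correct": False, "thinking": "..."}},
    {"input": {"question_id": 8, "model": "m1", "gold": "A"},
     "output": {"answer": "A", "is_correct": True, "thinking": "..."}},
]
pd.DataFrame(records).to_parquet("branch_a.parquet", index=False)

df = pd.read_parquet("branch_a.parquet")
row = df.iloc[0]
# Struct columns come back as Python dicts, so dict-style access works:
print(row["input"]["question_id"], row["output"]["is_correct"])  # 7 False
```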
```diff
@@ -316,52 +303,58 @@ def _run_job(job):
 
 def synth_on_dataset(
     in_filename: str,
-    out_jsonl: str,
+    out_filename: str,
     model: str,
     max_tokens: int,
     dump_every: int,
     limit: int | None,
     branch: str,
     chunk_size: int,
-    a_jsonl_path: str | None,
-    temperature: float = 0,  # [warning]: temperature for all branches
+    a_file_path: str | None,
+    temperature: float = 0,
 ):
     assert branch in {"A", "B", "C"}
     if branch == "C":
-        assert a_jsonl_path and os.path.exists(a_jsonl_path), (
-            "Branch C requires a valid path to branch-A results (a_jsonl_path)."
+        assert a_file_path and os.path.exists(a_file_path), (
+            "Branch C requires a valid path to branch-A parquet results (a_file_path)."
         )
 
+    # Read input dataset (TSV/CSV)
     df = pd.read_csv(in_filename, sep="\t", dtype=str, keep_default_na=False)
     total_rows = len(df) if limit is None else min(len(df), int(limit))
     total_chunks = max(1, math.ceil(total_rows / max(1, chunk_size)))
 
-    os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
+    os.makedirs(os.path.dirname(out_filename) or ".", exist_ok=True)
 
-    existing_ids = _load_and_clean_existing(out_jsonl)
-    logging.warning(f"Found {len(existing_ids)} valid records in {out_jsonl}. Errors removed.")
+    # Load existing progress
+    existing_ids = _load_existing_ids(out_filename)
+    logging.warning(f"Found {len(existing_ids)} valid records in {out_filename}.")
 
-    # pre-load A-incorrects for branch C
+    # Pre-load A-incorrects for branch C
     a_incorrect_map: dict[int, dict] = {}
     ids_for_c: set[int] = set()
     if branch == "C":
-        a_incorrect_map = _load_incorrect_from_branch_a(a_jsonl_path, expected_model=model)
+        a_incorrect_map = _load_incorrect_from_branch_a(a_file_path, expected_model=model)
         ids_for_c = set(a_incorrect_map.keys())
+        logging.info(f"Loaded {len(ids_for_c)} incorrect answers from Branch A for processing.")
 
     written = 0
     stop = False
+    buffer = []
 
-    with open(out_jsonl, "a", encoding="utf-8") as f, futures.ThreadPoolExecutor(max_workers=chunk_size) as pool:
+    with futures.ThreadPoolExecutor(max_workers=chunk_size) as pool:
         for chunk_idx, chunk in tqdm(enumerate(chunker(df, chunk_size)), total=total_chunks, desc=f"Synth {branch}"):
             if stop:
                 break
 
             args_list = []
             for index, row in chunk.iterrows():
-                if int(row["question_id"]) in existing_ids:
+                # Check if already processed
+                qid = row.get("question_id")
+                if qid and int(qid) in existing_ids:
                     continue
 
-                if limit is not None and written >= limit:
+                if limit is not None and (len(existing_ids) + written) >= limit:
                     stop = True
                     break
 
```
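Note the changed limit check: `written >= limit` became `(len(existing_ids) + written) >= limit`, so `limit` now caps the total number of valid records across resumed runs rather than the output of a single run. A toy illustration with invented numbers:

```python
# Invented numbers: limit=100 with 80 valid records already on disk.
limit = 100
existing_ids = set(range(80))
written = 0
while not (limit is not None and (len(existing_ids) + written) >= limit):
    written += 1  # stand-in for synthesizing and buffering one record
print(written)    # 20 — a resumed run only tops the file up to the limit
```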
```diff
@@ -389,10 +382,17 @@ def synth_on_dataset(
                 prev_ans = None
                 prev_thinking = None
                 if branch == "C":
-                    if index not in ids_for_c:
+                    # Only process rows where Branch A failed
+                    # Use question_id for lookup, NOT the dataframe index
+                    try:
+                        qid_int = int(question_id)
+                    except (ValueError, TypeError):
+                        continue
+
+                    if qid_int not in ids_for_c:
                         continue
-                    prev_ans = a_incorrect_map[index].get("model_answer")
-                    prev_thinking = a_incorrect_map[index].get("thinking")
+                    prev_ans = a_incorrect_map[qid_int].get("model_answer")
+                    prev_thinking = a_incorrect_map[qid_int].get("thinking")
 
                 args_list.append(
                     (
```
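
The fix here keys `a_incorrect_map` by `question_id` instead of the dataframe index: `iterrows()` yields positional labels that generally do not match the dataset's own IDs, so the old lookup could pair a row with the wrong Branch-A record. A toy demonstration with invented data:

```python
import pandas as pd

# Invented rows: question_id is a data column, not the index.
df = pd.DataFrame({"question_id": ["101", "205"], "question": ["q1", "q2"]})
for index, row in df.iterrows():
    print(index, row["question_id"])  # 0 101 / 1 205 — index and ID disagree
```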
```diff
@@ -417,10 +417,36 @@ def synth_on_dataset(
             results = list(pool.map(_run_job, args_list))
 
             for row_id, record_in, record_out in results:
-                f.write(json.dumps({"input": record_in, "output": record_out}, ensure_ascii=False) + "\n")
+                buffer.append({"input": record_in, "output": record_out})
                 written += 1
-                if dump_every > 0 and (written % dump_every == 0):
-                    f.flush()
-
-    print(f"Saved to {out_jsonl}. Rows considered: {len(df)}; written: {written}; branch={branch}; model={model}.")
-    return out_jsonl
+
+            # Dump to parquet periodically
+            if dump_every > 0 and len(buffer) >= dump_every:
+                try:
+                    new_df = pd.DataFrame(buffer)
+                    if os.path.exists(out_filename):
+                        existing_df = pd.read_parquet(out_filename)
+                        combined_df = pd.concat([existing_df, new_df], ignore_index=True)
+                    else:
+                        combined_df = new_df
+
+                    combined_df.to_parquet(out_filename, index=False)
+                    buffer = []  # Clear buffer after successful write
+                except Exception as e:
+                    logging.error(f"Failed to write parquet batch: {e}")
+
+    # Final flush
+    if buffer:
+        try:
+            new_df = pd.DataFrame(buffer)
+            if os.path.exists(out_filename):
+                existing_df = pd.read_parquet(out_filename)
+                combined_df = pd.concat([existing_df, new_df], ignore_index=True)
+            else:
+                combined_df = new_df
+            combined_df.to_parquet(out_filename, index=False)
+        except Exception as e:
+            logging.error(f"Failed to write final parquet batch: {e}")
+
+    print(f"Saved to {out_filename}. Rows considered: {len(df)}; written: {written}; branch={branch}; model={model}.")
+    return out_filename
```
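
Parquet files cannot be appended in place the way a JSONL file can, which is why both flush sites re-read the existing file, concat, and rewrite it. The duplicated logic could be factored into one helper; a sketch, where `_append_parquet` is a hypothetical name and not part of this commit:

```python
import os
import pandas as pd

def _append_parquet(path: str, records: list[dict]) -> None:
    # Hypothetical helper mirroring the flush logic above: Parquet has no
    # in-place append, so read existing rows, concat, and rewrite the file.
    new_df = pd.DataFrame(records)
    if os.path.exists(path):
        new_df = pd.concat([pd.read_parquet(path), new_df], ignore_index=True)
    new_df.to_parquet(path, index=False)
```

Since each flush rewrites the whole file, flush cost grows with output size; for large runs an incremental writer such as `pyarrow.parquet.ParquetWriter` could append row groups without re-reading.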
