 import ast
-import json
 import logging
 import math
 import os
@@ -34,8 +33,14 @@ def letters_for(n: int):
 
 
 def parse_options(s):
-    lst = ast.literal_eval(s)
-    return list(map(str, lst))
+    if isinstance(s, list):
+        return list(map(str, s))
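+    # Otherwise expect a stringified Python list, e.g. "['A', 'B', 'C']"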
+    try:
+        lst = ast.literal_eval(str(s))
+        return list(map(str, lst))
+    except (ValueError, SyntaxError):
+        return []
 
 
 def norm_letter_dyn(x, letters):
@@ -207,65 +211,50 @@ def _branch_c(q, choices, gold, model, max_tokens, subject, prev_answer, prev_re
 
 
 # ------------ helpers for branch C ------------
-def _load_incorrect_from_branch_a(a_jsonl_path: str, expected_model: str | None) -> dict[int, dict]:
-    bad: dict[int, dict] = {}
-    with open(a_jsonl_path, "r", encoding="utf-8") as f:
-        for line in f:
-            try:
-                rec = json.loads(line)
-            except Exception:
-                continue
-            inp = rec.get("input") or {}
-            out = rec.get("output") or {}
-            if "error" in out:
-                continue
-            if expected_model is not None and (inp.get("model") != expected_model):
-                continue
-            row_id = inp.get("row_id")
-            if row_id is None:
-                continue
-            gold = (inp.get("gold") or "").strip().upper()
-            ans = (out.get("answer") or "").strip().upper()
-            is_correct = out.get("is_correct")
-            if is_correct is None:
-                is_correct = ans == gold
-            if not is_correct:
-                bad[int(row_id)] = {
-                    "preivous_answer": ans,
-                    "thinking": out.get("thinking") or "",
-                }
+def _load_incorrect_from_branch_a(a_parquet_path: str, expected_model: str | None) -> dict[int, dict]:
+    if not a_parquet_path or not os.path.exists(a_parquet_path):
+        return {}
+
+    try:
+        df = pd.read_parquet(a_parquet_path)
+    except Exception:
+        return {}
+
+    bad = {}
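+    # Maps question_id -> {"model_answer", "thinking"} for rows Branch A answered incorrectly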
+    for _, row in df.iterrows():
+        inp, out = row["input"], row["output"]
+        if "error" in out:
+            continue
+        if expected_model and inp.get("model") != expected_model:
+            continue
+
+        # Check correctness (prefer the explicit flag, fall back to string comparison)
+        is_correct = out.get("is_correct")
+        if is_correct is None:
+            is_correct = (out.get("answer") or "").strip().upper() == (inp.get("gold") or "").strip().upper()
+
+        if not is_correct:
+            bad[int(inp["question_id"])] = {
+                "model_answer": out.get("answer"),
+                "thinking": out.get("thinking", ""),
+            }
     return bad
 
 
-def _load_and_clean_existing(out_jsonl: str) -> set[int]:
-    if not os.path.exists(out_jsonl):
+def _load_existing_ids(out_parquet: str) -> set[int]:
+    if not os.path.exists(out_parquet):
+        return set()
+    try:
+        df = pd.read_parquet(out_parquet, columns=["input", "output"])
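+        # Only non-error records count as completed, so failed rows are retried on resume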
+        return {
+            int(row["input"]["question_id"])
+            for _, row in df.iterrows()
+            if "error" not in row["output"]
+        }
+    except Exception:
         return set()
-
-    valid_ids = set()
-    valid_lines = []
-
-    with open(out_jsonl, "r", encoding="utf-8") as f:
-        for line in f:
-            line = line.strip()
-            if not line:
-                continue
-            try:
-                rec = json.loads(line)
-                # Check if output has error
-                if "error" not in rec.get("output", {}):
-                    rid = rec.get("input", {}).get("question_id")
-                    if rid is not None:
-                        valid_ids.add(int(rid))
-                    valid_lines.append(line)
-            except Exception:
-                pass
-
-    # Rewrite file with only valid lines
-    with open(out_jsonl, "w", encoding="utf-8") as f:
-        for line in valid_lines:
-            f.write(line + "\n")
-
-    return valid_ids
 
 
 # ------------ dataset -------------
@@ -316,52 +303,59 @@ def _run_job(job):
 
 def synth_on_dataset(
     in_filename: str,
-    out_jsonl: str,
+    out_filename: str,
     model: str,
     max_tokens: int,
     dump_every: int,
     limit: int | None,
     branch: str,
     chunk_size: int,
-    a_jsonl_path: str | None,
-    temperature: float = 0,  # [warning]: temperature for all branches
+    a_file_path: str | None,
+    temperature: float = 0,
 ):
     assert branch in {"A", "B", "C"}
     if branch == "C":
-        assert a_jsonl_path and os.path.exists(a_jsonl_path), (
-            "Branch C requires a valid path to branch-A results (a_jsonl_path)."
+        assert a_file_path and os.path.exists(a_file_path), (
+            "Branch C requires a valid path to branch-A parquet results (a_file_path)."
         )
 
+    # Read the input dataset (tab-separated)
     df = pd.read_csv(in_filename, sep="\t", dtype=str, keep_default_na=False)
     total_rows = len(df) if limit is None else min(len(df), int(limit))
     total_chunks = max(1, math.ceil(total_rows / max(1, chunk_size)))
 
-    os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
+    os.makedirs(os.path.dirname(out_filename) or ".", exist_ok=True)
 
-    existing_ids = _load_and_clean_existing(out_jsonl)
-    logging.warning(f"Found {len(existing_ids)} valid records in {out_jsonl}. Errors removed.")
+    # Load existing progress
+    existing_ids = _load_existing_ids(out_filename)
+    logging.warning(f"Found {len(existing_ids)} valid records in {out_filename}.")
 
-    # pre-load A-incorrects for branch C
+    # Pre-load A-incorrects for branch C
     a_incorrect_map: dict[int, dict] = {}
    ids_for_c: set[int] = set()
     if branch == "C":
-        a_incorrect_map = _load_incorrect_from_branch_a(a_jsonl_path, expected_model=model)
+        a_incorrect_map = _load_incorrect_from_branch_a(a_file_path, expected_model=model)
         ids_for_c = set(a_incorrect_map.keys())
+        logging.info(f"Loaded {len(ids_for_c)} incorrect answers from Branch A for processing.")
 
     written = 0
     stop = False
+    buffer = []
 
-    with open(out_jsonl, "a", encoding="utf-8") as f, futures.ThreadPoolExecutor(max_workers=chunk_size) as pool:
+    with futures.ThreadPoolExecutor(max_workers=chunk_size) as pool:
         for chunk_idx, chunk in tqdm(enumerate(chunker(df, chunk_size)), total=total_chunks, desc=f"Synth {branch}"):
             if stop:
                 break
 
             args_list = []
             for index, row in chunk.iterrows():
-                if int(row["question_id"]) in existing_ids:
+                # Check if already processed
+                qid = row.get("question_id")
+                if qid and int(qid) in existing_ids:
                     continue
 
-                if limit is not None and written >= limit:
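+                # The cap now counts previously completed records, not just new writes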
+                if limit is not None and (len(existing_ids) + written) >= limit:
                     stop = True
                     break
 
@@ -389,10 +382,17 @@ def synth_on_dataset(
                 prev_ans = None
                 prev_thinking = None
                 if branch == "C":
-                    if index not in ids_for_c:
+                    # Only process rows where Branch A failed
+                    # Use question_id for lookup, NOT the dataframe index
+                    try:
+                        qid_int = int(question_id)
+                    except (ValueError, TypeError):
+                        continue
+
+                    if qid_int not in ids_for_c:
                         continue
-                    prev_ans = a_incorrect_map[index].get("model_answer")
-                    prev_thinking = a_incorrect_map[index].get("thinking")
+                    prev_ans = a_incorrect_map[qid_int].get("model_answer")
+                    prev_thinking = a_incorrect_map[qid_int].get("thinking")
 
                 args_list.append(
                     (
@@ -417,10 +417,37 @@ def synth_on_dataset(
             results = list(pool.map(_run_job, args_list))
 
             for row_id, record_in, record_out in results:
-                f.write(json.dumps({"input": record_in, "output": record_out}, ensure_ascii=False) + "\n")
+                buffer.append({"input": record_in, "output": record_out})
                 written += 1
-                if dump_every > 0 and (written % dump_every == 0):
-                    f.flush()
-
-    print(f"Saved to {out_jsonl}. Rows considered: {len(df)}; written: {written}; branch={branch}; model={model}.")
-    return out_jsonl
+
+                # Dump to parquet periodically
+                if dump_every > 0 and len(buffer) >= dump_every:
+                    try:
+                        new_df = pd.DataFrame(buffer)
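+                        # Parquet can't be appended in place with pandas: merge with the existing file and rewrite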
+                        if os.path.exists(out_filename):
+                            existing_df = pd.read_parquet(out_filename)
+                            combined_df = pd.concat([existing_df, new_df], ignore_index=True)
+                        else:
+                            combined_df = new_df
+
+                        combined_df.to_parquet(out_filename, index=False)
+                        buffer = []  # Clear buffer after successful write
+                    except Exception as e:
+                        logging.error(f"Failed to write parquet batch: {e}")
+
+    # Final flush
+    if buffer:
+        try:
+            new_df = pd.DataFrame(buffer)
+            if os.path.exists(out_filename):
+                existing_df = pd.read_parquet(out_filename)
+                combined_df = pd.concat([existing_df, new_df], ignore_index=True)
+            else:
+                combined_df = new_df
+            combined_df.to_parquet(out_filename, index=False)
+        except Exception as e:
+            logging.error(f"Failed to write final parquet batch: {e}")
+
+    print(f"Saved to {out_filename}. Rows considered: {len(df)}; written: {written}; branch={branch}; model={model}.")
+    return out_filename