diff --git a/pyproject.toml b/pyproject.toml index a08d956f..8c4966f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -178,6 +178,9 @@ select = [ "src/olmo_eval/evals/tasks/squad.py" = ["E501"] # Static fewshot data with long string literals "src/olmo_eval/evals/tasks/constants/*" = ["E501"] +# Verbatim tables/templates vendored from the mm_olmo reference implementation +"src/olmo_eval/common/image_qa/*" = ["E501"] +"tests/core/test_image_qa_scorers.py" = ["E501"] [tool.ruff.format] docstring-code-format = true diff --git a/src/olmo_eval/common/image_qa/__init__.py b/src/olmo_eval/common/image_qa/__init__.py new file mode 100644 index 00000000..412da5c6 --- /dev/null +++ b/src/olmo_eval/common/image_qa/__init__.py @@ -0,0 +1,78 @@ +"""Pure-python scoring and prompt utilities for the Molmo2 image-QA benchmarks. + +Everything here is vendored from the mm_olmo reference implementation (no +mm_olmo imports) and is dependency-light so it stays unit-testable without +datasets/torch. +""" + +from olmo_eval.common.image_qa.count_parsing import ( + WORD_TO_NUM, + extract_image_points, + parse_count, +) +from olmo_eval.common.image_qa.math_vista_offline import ( + DEMO_PROMPT, + create_test_prompt, + extract_answer_offline, + extract_answer_quick, + math_vista_score_from_extraction, + math_vista_score_offline, + normalize_extracted_answer, + safe_equal, +) +from olmo_eval.common.image_qa.mmmu_parsing import ( + eval_multi_choice, + eval_open, + mmmu_score, + parse_multi_choice_response, + parse_open_response, +) +from olmo_eval.common.image_qa.prompt_templates import ( + EVAL_LOADER_SEED, + POINT_COUNT_TEMPLATES, + format_mc_question, + pixmo_count_question, +) +from olmo_eval.common.image_qa.vqa_normalization import ( + anls_metric, + clean_prediction, + levenshtein, + preprocess_answer, + real_world_qa_score, + relaxed_correctness, + scifi_relaxed_correctness, + select_mc_option, + vqa_score, +) + +__all__ = [ + "DEMO_PROMPT", + "EVAL_LOADER_SEED", + "POINT_COUNT_TEMPLATES", + "WORD_TO_NUM", + "anls_metric", + "clean_prediction", + "create_test_prompt", + "eval_multi_choice", + "eval_open", + "extract_answer_offline", + "extract_answer_quick", + "extract_image_points", + "format_mc_question", + "levenshtein", + "math_vista_score_from_extraction", + "math_vista_score_offline", + "mmmu_score", + "normalize_extracted_answer", + "parse_count", + "parse_multi_choice_response", + "parse_open_response", + "pixmo_count_question", + "preprocess_answer", + "real_world_qa_score", + "safe_equal", + "relaxed_correctness", + "scifi_relaxed_correctness", + "select_mc_option", + "vqa_score", +] diff --git a/src/olmo_eval/common/image_qa/count_parsing.py b/src/olmo_eval/common/image_qa/count_parsing.py new file mode 100644 index 00000000..a9623177 --- /dev/null +++ b/src/olmo_eval/common/image_qa/count_parsing.py @@ -0,0 +1,146 @@ +"""Count parsing for CountBench QA / PixMo Count (the ``point_count`` style). + +Vendored from ``mm_olmo/olmo/eval/molmo_prediction_evaluators.py`` +(``PointCountEval``) and the universal point-extraction regexes in +``mm_olmo/olmo/preprocessing/point_formatter.py`` (``UnifiedPointFormatter`` +and ``PointFormattingV1``). Behavior is preserved exactly. + +The parse ladder for a predicted count: + 1. last whitespace token as int + 2. last token as a number word ("one" … "twenty") + 3. ``"a total of N"`` regex + 4. a bare "none" → 0 + 5. fall back to counting the points the model emitted +""" + +from __future__ import annotations + +import contextlib +import re + +WORD_TO_NUM = { + "one": 1, + "two": 2, + "three": 3, + "four": 4, + "five": 5, + "six": 6, + "seven": 7, + "eight": 8, + "nine": 9, + "zero": 0, + "ten": 10, + "eleven": 11, + "twelve": 12, + "thirteen": 13, + "fourteen": 14, + "fifteen": 15, + "sixteen": 16, + "seventeen": 17, + "eighteen": 18, + "nineteen": 19, + "twenty": 20, +} + +# --- UnifiedPointFormatter regexes (coordinate_scale="1000", image_sep="\t") --- +_COORD_REGEX = re.compile(r"<(?:points|tracks|bboxes).*? coords=\"([0-9\t:;, .]+)\"/?>") +_FRAME_REGEX = re.compile(r"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)") +_POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})") + + +def _extract_points_unified(text: str, image_w: float, image_h: float) -> list[tuple[float, float]]: + all_points: list[tuple[float, float]] = [] + for coord in _COORD_REGEX.finditer(text): + for point_grp in _FRAME_REGEX.finditer(coord.group(1)): + for triplet in _POINTS_REGEX.finditer(point_grp.group(2)): + x = float(triplet.group(2)) / 1000 * image_w + y = float(triplet.group(3)) / 1000 * image_h + if 0 <= x <= image_w and 0 <= y <= image_h: + all_points.append((x, y)) + return all_points + + +def _extract_points_v1(text: str, image_w: float, image_h: float) -> list[tuple[float, float]]: + """Legacy point-format fallback chain (``PointFormattingV1``).""" + all_points: list[tuple[float, float]] = [] + + def _scaled(x: float, y: float, bound: float, scale: float) -> tuple[float, float] | None: + if max(x, y) > bound: + return None # treat as an invalid output + return x / scale * image_w, y / scale * image_h + + for match in re.finditer(r"Click\(([0-9]+\.[0-9]), ?([0-9]+\.[0-9])\)", text): + point = _scaled(float(match.group(1)), float(match.group(2)), 100, 100.0) + if point is not None: + all_points.append(point) + if all_points: + return all_points + + for match in re.finditer(r"[0-9]+ ([0-9]{3}) ([0-9]{3})", text): + point = _scaled(float(match.group(1)), float(match.group(2)), 1000, 1000.0) + if point is not None: + all_points.append(point) + if all_points: + return all_points + + for match in re.finditer(r"[0-9]+ ([0-9]+\.[0-9]) ([0-9]+\.[0-9])", text): + point = _scaled(float(match.group(1)), float(match.group(2)), 100, 100.0) + if point is not None: + all_points.append(point) + if all_points: + return all_points + + for match in re.finditer(r"\(([0-9]+\.[0-9]),? ?([0-9]+\.[0-9])\)", text): + point = _scaled(float(match.group(1)), float(match.group(2)), 100, 100.0) + if point is not None: + all_points.append(point) + for match in re.finditer( + r'x\d*="\s*([0-9]+(?:\.[0-9]+)?)"\s+y\d*="\s*([0-9]+(?:\.[0-9]+)?)"', text + ): + point = _scaled(float(match.group(1)), float(match.group(2)), 100, 100.0) + if point is not None: + all_points.append(point) + for match in re.finditer(r"(?:\d+|p)\s*=\s*([0-9]{3})\s*,\s*([0-9]{3})", text): + point = _scaled(int(match.group(1)) / 10.0, int(match.group(2)) / 10.0, 100, 100.0) + if point is not None: + all_points.append(point) + return all_points + + +def extract_image_points( + text: str, image_w: float = 100, image_h: float = 100 +) -> list[tuple[float, float]]: + """Universal point extraction: unified format first, then legacy formats.""" + points = _extract_points_unified(text, image_w, image_h) + if points: + return points + return _extract_points_v1(text, image_w, image_h) + + +def parse_count(original_pred: str) -> int: + """Parse the predicted count from a ``point_count``-style response.""" + pred = original_pred.lower().rstrip(".").strip() + pred_int: int | None = None + parts = pred.split() + + if parts: + with contextlib.suppress(ValueError): + pred_int = int(parts[-1].strip(". ")) + + if pred_int is None and parts[-1] in WORD_TO_NUM: + pred_int = WORD_TO_NUM[parts[-1]] + + if pred_int is None: + match = re.match(".*a total of ([0-9]+).*", pred) + if match: + pred_int = int(match.group(1)) + + if pred_int is None: + match = re.match(".*\\bnone\\b.*", pred, re.IGNORECASE) + if match: + pred_int = 0 + + if pred_int is None: + pred_int = len(extract_image_points(pred, 100, 100)) + + return pred_int diff --git a/src/olmo_eval/common/image_qa/math_vista_offline.py b/src/olmo_eval/common/image_qa/math_vista_offline.py new file mode 100644 index 00000000..368b97d3 --- /dev/null +++ b/src/olmo_eval/common/image_qa/math_vista_offline.py @@ -0,0 +1,218 @@ +"""MathVista scoring: offline answer extraction + official normalization. + +Vendored from ``mm_olmo/olmo/eval/math_vista_utils.py`` and the offline +(`use_api=False`) branch of ``math_vista_score`` in ``mm_olmo/olmo/eval/vqa.py``. + +The official MathVista protocol extracts the final answer from the model +response with GPT-4 (``gpt-4-0613``) before comparison; :data:`DEMO_PROMPT` +and :func:`create_test_prompt` are vendored here for the optional GPT-backed +scorer. The offline path replaces only the extraction step (letter matching +for multiple choice, int/float parsing otherwise); normalization and the +final comparison are identical. +""" + +from __future__ import annotations + +import re + +from olmo_eval.common.image_qa.vqa_normalization import levenshtein, select_mc_option + +DEMO_PROMPT = """ +Please read the following example. Then extract the answer from the model response and type it at the end of the prompt. + +Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end. +Question: Which number is missing? + +Model response: The number missing in the sequence is 14. + +Extracted answer: 14 + +Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end. +Question: What is the fraction of females facing the camera? + +Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera. + +Extracted answer: 0.6 + +Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end. +Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $) + +Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy. + +Extracted answer: 1.45 + +Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end. +Question: Between which two years does the line graph saw its maximum peak? + +Model response: The line graph saw its maximum peak between 2007 and 2008. + +Extracted answer: [2007, 2008] + +Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end. +Question: What fraction of the shape is blue? +Choices: +A: 3/11 +B: 8/11 +C: 6/11 +D: 3/5 + +Model response: The correct answer is B: 8/11. + +Extracted answer: B +""" + + +def create_test_prompt(query: str, response: str) -> str: + """Build the GPT-4 answer-extraction prompt (official MathVista).""" + demo = DEMO_PROMPT.strip() + test_prompt = f"{query}\n\n{response}" + return f"{demo}\n\n{test_prompt}\n\nExtracted answer: " + + +def get_most_similar(prediction: str, choices: list[str]) -> str: + """Return the choice closest to ``prediction`` by edit distance.""" + distances = [levenshtein(prediction, choice) for choice in choices] + return choices[distances.index(min(distances))] + + +def normalize_extracted_answer( + extraction, + choices: list[str], + question_type: str, + answer_type: str, + precision, +): + """Normalize the extracted answer to match the answer type (official).""" + if question_type == "multi_choice": + if isinstance(extraction, str): + extraction = extraction.strip() + else: + try: + extraction = str(extraction) + except Exception: + extraction = "" + + # extract "A" from "(A) text" + letter = re.findall(r"([a-zA-Z]):", extraction) + if len(letter) > 0: + extraction = letter[0].upper() + + options = [chr(ord("A") + i) for i in range(len(choices))] + + if extraction in options: + ind = options.index(extraction) + extraction = choices[ind] + else: + extraction = get_most_similar(extraction, choices) + assert extraction in choices + + elif answer_type == "integer": + try: + extraction = str(int(float(extraction))) + except Exception: + extraction = None + + elif answer_type == "float": + try: + extraction = str(round(float(extraction), precision)) + except Exception: + extraction = None + + elif answer_type == "list": + try: + extraction = str(extraction) + except Exception: + extraction = None + + return extraction + + +def safe_equal(prediction, answer) -> bool: + """Compare prediction and answer, tolerating type mismatches.""" + try: + return prediction == answer + except Exception: + return False + + +def extract_answer_offline( + response: str, + question_type: str, + answer_type: str, + choices: list[str], +) -> str: + """Offline answer extraction (no GPT call). + + Applies the official ``extract_answer`` deterministic short-circuits first + (empty response, response verbatim in choices, int/float parsing), then + falls back to the mm_olmo ``use_api=False`` branch: letter matching via + :func:`select_mc_option` for multiple choice, raw response otherwise. + """ + quick = extract_answer_quick(response, question_type, answer_type, choices) + if quick is not None: + return quick + if question_type == "multi_choice": + options = [chr(ord("A") + i) for i in range(len(choices))] + pred_idx = select_mc_option(response, options) + return choices[pred_idx] + return response + + +def extract_answer_quick( + response: str, + question_type: str, + answer_type: str, + choices: list[str], +) -> str | None: + """The deterministic pre-GPT short-circuits of the official ``extract_answer``. + + Returns None when GPT extraction would be required. + """ + if response == "": + return "" + if question_type == "multi_choice" and response in choices: + return response + if answer_type == "integer": + try: + return str(int(response)) + except Exception: + pass + if answer_type == "float": + try: + return str(float(response)) + except Exception: + pass + return None + + +def math_vista_score_offline( + response: str, + *, + question_type: str, + answer_type: str, + choices: list[str], + precision, + target, +) -> bool: + """Score one MathVista example with offline extraction.""" + extraction = extract_answer_offline(response, question_type, answer_type, choices) + prediction = normalize_extracted_answer( + extraction, choices, question_type, answer_type, precision + ) + return safe_equal(prediction, target) + + +def math_vista_score_from_extraction( + extraction, + *, + question_type: str, + answer_type: str, + choices: list[str], + precision, + target, +) -> bool: + """Score from an already-extracted answer (offline or GPT-based).""" + prediction = normalize_extracted_answer( + extraction, choices, question_type, answer_type, precision + ) + return safe_equal(prediction, target) diff --git a/src/olmo_eval/common/image_qa/mmmu_parsing.py b/src/olmo_eval/common/image_qa/mmmu_parsing.py new file mode 100644 index 00000000..2d532f23 --- /dev/null +++ b/src/olmo_eval/common/image_qa/mmmu_parsing.py @@ -0,0 +1,269 @@ +"""MMMU response parsing and scoring. + +Vendored from ``mm_olmo/olmo/eval/mmmu_eval_utils.py`` (itself adapted from +the official MMMU repository) plus the ``mmmu_score`` dispatcher from +``mm_olmo/olmo/eval/vqa.py``. + +One intentional change: the original falls back to ``random.choice`` (with a +module-level ``random.seed(42)``) when a multiple-choice response cannot be +parsed, which makes scores depend on scoring order. Here the fallback uses a +``random.Random`` seeded per call from a stable instance id so results are +deterministic and order-independent. +""" + +from __future__ import annotations + +import random +import re +import string +import zlib + + +def _argmax(values: list[int]) -> int: + """Index of the maximum value, first occurrence (matches np.argmax).""" + best_ix = 0 + for ix in range(1, len(values)): + if values[ix] > values[best_ix]: + best_ix = ix + return best_ix + + +def _fallback_rng(stable_id: str | None) -> random.Random: + seed = 42 if stable_id is None else zlib.crc32(stable_id.encode("utf-8")) + return random.Random(seed) + + +# ----------- Process Multi-choice ------------- +def parse_multi_choice_response( + response: str, + all_choices: list[str], + index2ans: dict[str, str], + stable_id: str | None = None, +) -> str: + """Parse the predicted option letter (e.g. A/B/C/D) from a response.""" + for char in [",", ".", "!", "?", ";", ":", "'"]: + response = response.strip(char) + response = " " + response + " " # add space to avoid partial match + + index_ans = True + ans_with_brack = False + ans_with_last_brack = False + ans_with_dot = False + ans_with_colon = False + candidates = [] + for choice in all_choices: # e.g., (A) (B) (C) (D) + if f"({choice})" in response: + candidates.append(choice) + ans_with_brack = True + + for choice in all_choices: # e.g., A), B), C), D) + if f"{choice})" in response: + candidates.append(choice) + ans_with_last_brack = True + + for choice in all_choices: # e.g., A. B. C. D. + if f"{choice}." in response: + candidates.append(choice) + ans_with_dot = True + + for choice in all_choices: # e.g., A: B: C: D: + if f"{choice}:" in response: + candidates.append(choice) + ans_with_colon = True + + if len(candidates) == 0: + for choice in all_choices: # e.g., A B C D + if response.strip() == choice: + return choice + if f" {choice} " in response: + candidates.append(choice) + + # if all above doesn't get candidates, check if the content is larger than + # 5 tokens and try to parse the example + if len(candidates) == 0 and len(response.split()) > 5: + for index, ans in index2ans.items(): + if ans.lower() in response.lower(): + candidates.append(index) + index_ans = False # it's content ans. + + if len(candidates) == 0: # still not get answer, choose one deterministically. + pred_index = _fallback_rng(stable_id).choice(all_choices) + elif len(candidates) > 1: + start_indexes = [] + if index_ans: + if ans_with_brack: + for can in candidates: + start_indexes.append(response.rfind(f"({can})")) + elif ans_with_last_brack: + for can in candidates: + start_indexes.append(response.rfind(f"{can})")) + elif ans_with_dot: + for can in candidates: + start_indexes.append(response.rfind(f"{can}.")) + elif ans_with_colon: + for can in candidates: + start_indexes.append(response.rfind(f"{can}:")) + else: + for can in candidates: + start_indexes.append(response.rfind(f" {can} ")) + else: + for can in candidates: + start_indexes.append(response.lower().rfind(index2ans[can].lower())) + # get the last one + pred_index = candidates[_argmax(start_indexes)] + else: # if only one candidate, use it. + pred_index = candidates[0] + + return pred_index + + +# ----------- Process Open ------------- +def check_is_number(value: str) -> bool: + """Check if the given string is a number.""" + try: + float(value.replace(",", "")) + return True + except ValueError: + return False + + +def normalize_str(value: str) -> list: + """Normalize a string to lower case, converting to float when possible.""" + value = value.strip() + + if check_is_number(value): + value = value.replace(",", "") + number = round(float(value), 2) + return [number] + value = value.lower() + if len(value) == 1: + return [" " + value, value + " "] # avoid trivial matches + return [value] + + +def extract_numbers(value: str) -> list[str]: + """Extract all forms of numbers from a string with regex.""" + pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b" + pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+" + pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])" + + numbers_with_commas = re.findall(pattern_commas, value) + numbers_scientific = re.findall(pattern_scientific, value) + numbers_simple = re.findall(pattern_simple, value) + + return numbers_with_commas + numbers_scientific + numbers_simple + + +def parse_open_response(response: str) -> list: + """Parse predicted strings/numbers from an open-ended response.""" + + def get_key_subresponses(resp_text: str) -> list[str]: + resp_text = resp_text.strip().strip(".").lower() + sub_responses = re.split(r"\.\s(?=[A-Z])|\n", resp_text) + indicators_of_keys = [ + "could be ", + "so ", + "is ", + "thus ", + "therefore ", + "final ", + "answer ", + "result ", + ] + key_responses = [] + for index, resp in enumerate(sub_responses): + # if last one, accept it's an equation (the entire response can be + # just one sentence with equation) + if index == len(sub_responses) - 1: + indicators_of_keys.extend(["="]) + shortest_key_response = None + for indicator in indicators_of_keys: + if indicator in resp: + if not shortest_key_response: + shortest_key_response = resp.split(indicator)[-1].strip() + else: + if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response): + shortest_key_response = resp.split(indicator)[-1].strip() + + # accept the shortest key response if it's not trivial + if shortest_key_response and shortest_key_response.strip() not in [ + ":", + ",", + ".", + "!", + "?", + ";", + ":", + "'", + ]: + key_responses.append(shortest_key_response) + if len(key_responses) == 0: + return [resp_text] + return key_responses + + key_responses = get_key_subresponses(response) + + pred_list = key_responses.copy() # keep the original string response + for resp in key_responses: + pred_list.extend(extract_numbers(resp)) + + tmp_pred_list = [] + for pred in pred_list: + tmp_pred_list.extend(normalize_str(pred)) + + # remove duplicates + return list(set(tmp_pred_list)) + + +# ----------- Evaluation ------------- +def eval_multi_choice(gold_i: list | str, pred_i: str) -> bool: + """Evaluate a multiple-choice instance.""" + if isinstance(gold_i, list): + return any(answer == pred_i for answer in gold_i) + return gold_i == pred_i + + +def eval_open(gold_i: list | str, pred_i: list) -> bool: + """Evaluate an open-question instance.""" + correct = False + if isinstance(gold_i, list): + norm_answers = [] + for answer in gold_i: + norm_answers.extend(normalize_str(answer)) + else: + norm_answers = normalize_str(gold_i) + for pred in pred_i: # pred is already normalized in parse response phase + if isinstance(pred, str): # if it's a string, then find if ans in the pred_i + for norm_ans in norm_answers: + if isinstance(norm_ans, str) and norm_ans in pred: + if not correct: + correct = True + break + else: # it's a float number + if pred in norm_answers: + if not correct: + correct = True + break + return correct + + +def mmmu_score( + target: list[str] | str, + response: str, + question_type: str, + options: list[str], + stable_id: str | None = None, +) -> float: + """Score one MMMU example following the official protocol.""" + if question_type == "multiple-choice": + options = [opt for opt in options if len(opt) > 0] + all_choices = list(string.ascii_uppercase[: len(options)]) + index2ans = dict(zip(all_choices, options, strict=False)) + parsed_pred = parse_multi_choice_response( + response, all_choices, index2ans, stable_id=stable_id + ) + correct = eval_multi_choice(target, parsed_pred) + else: # open + parsed_pred = parse_open_response(response) + correct = eval_open(target, parsed_pred) + return float(correct) diff --git a/src/olmo_eval/common/image_qa/prompt_templates.py b/src/olmo_eval/common/image_qa/prompt_templates.py new file mode 100644 index 00000000..93cf89de --- /dev/null +++ b/src/olmo_eval/common/image_qa/prompt_templates.py @@ -0,0 +1,137 @@ +"""Prompt construction for the Molmo2 image-QA benchmarks. + +Vendored from ``mm_olmo/olmo/data/data_formatter.py``: + +* :data:`POINT_COUNT_TEMPLATES` — the ``GENERAL_PROMPTS_V1["point_count"]`` + list, verbatim (60 entries; duplicates and typos are intentional — the + template *index* is what the seeded RNG selects). +* :func:`pixmo_count_question` — replicates the per-example template choice of + mm_olmo's eval data pipeline (``DeterministicDataset`` seed arithmetic with + the eval loader seed 691203, then ``rng.randint`` over the template list). +* :func:`format_mc_question` — the eval branch of ``template_options`` + (``"Only return the correct answer option."``). + +Style-prefix rules (``demo_or_style_v2`` system prompts): short-answer styles +(``vqa2``, ``chart_qa``, ``doc_qa``, ``info_qa``, ``text_vqa``) are rendered +as ``"{style}: {question}"``; multiple-choice and counting styles get no +prefix. Tasks bake the prefix into the instance question directly. +""" + +from __future__ import annotations + +import string + +import numpy as np + +EVAL_LOADER_SEED = 691203 +"""DataLoaderConfig seed used by mm_olmo's eval pipeline.""" + +POINT_COUNT_TEMPLATES: list[str] = [ + "How many {label} are there?", + "How many {label}?", + "How many {label}.", + "how many {label}.", + "how many {label}?", + 'How many "{label}" are there in the image?', + "How many {label} are there in the image?", + "Tell me how many {label} there are", + "Tell me how many {label} there are and point to them.", + "how many {label}", + "Tell me where each {label} is.", + "Tell me how many {label} are in the image", + "count {label}", + "count every {label}", + "count each {label}", + "count {label}.", + "Count the {label}.", + "How many {label} do you see?", + "How many {label} are visible?", + "Count all the {label}", + "how mmny {label}?", + "Count every {label} in the picture.", + "Count all the {label}", + "Count each {label}", + "Point to and count the {label} in the picture.", + "Point and count {label}", + "Point to every {label}", + "Locate the {label} and count them", + "Locate every {label} and count them", + "Find all the {label}. How many are there?", + "Find each {label}. How many are there?", + "Point at {label} and then tell me the count.", + "What is the total number of {label} in the image?", + "What is the number of {label}?", + "In this image, how many {label} are there?", + "In all the picture, how many {label} are there?", + "Point at the {label} and then count them.", + "Point to all the visible {label} output the total count.", + "Point to all the {label} visible and output the total count. \nPlease say 'There are none.' if it is not in the image.", + 'Point to all occurrences of "{label}" and output the total count.', + "Show me where the {label} are and output the total count.", + "Where are the {label}? How many are there?", + "Generate list of points showing where the {label} are and output the total count.", + "Object: {label}\nInstruction: Point to the object and output the total count.", + "find any {label} in the picture and output the total count.", + "Can you see any {label} in the image? Point to them and output the total count.", + "Can you point out all {label} in this image? How many are there?", + "If there are any {label} present, indicate their positions and output the total count.", + "How many {label} are there in the image? Point to them and output the total count.", + "How many {label} are there in the image?", + "Give me the count of {label} in the image.", + "How many {label} are visible in the image?", + "How many {label} are there?", + "In the image, how many {label} are there?", + "Can you count the number of {label} in the image?", + "Can you count every {label} in the picture?", + "Can you see any {label} in the image? How many are there?", + "Are there any {label} in the image? How many are there?", + "If you see any {label} in the image, give me the count. Otherwise, say 'There are none.'", + "Object: {label}\nInstruction: How many are there?", +] + + +def _apply_label(template: str, label: str) -> str: + """``apply_keywords`` from mm_olmo: replaces only the first occurrence.""" + res = template.split("{label}", 2) + return res[0] + label + res[1] + + +def pixmo_count_question(label: str, arrow_idx: int, seed: int = EVAL_LOADER_SEED) -> str: + """Reproduce mm_olmo's per-example PixMo-Count question template. + + ``arrow_idx`` is the example's position in the on-disk arrow dataset; the + RNG seed arithmetic matches ``DeterministicDataset.get`` (epoch 0) and the + template pick matches ``apply_keyword_prompt``'s ``rng.randint``. + """ + rng = np.random.RandomState((seed * 195172 + arrow_idx) % (2**32 - 1)) + template = POINT_COUNT_TEMPLATES[rng.randint(0, len(POINT_COUNT_TEMPLATES))] + return _apply_label(template, label.lower()) + + +def format_mc_question( + question: str, + options: list[str], + *, + labelled: bool = True, +) -> tuple[str, str | list[str]]: + """Eval-time multiple-choice templating (``template_options`` eval branch). + + Returns ``(formatted_question, option_names)``. With ``labelled=True`` + options are rendered as ``A. …`` lines and ``option_names`` is the letter + string (e.g. ``"ABCD"``, matching mm_olmo where it is a slice of + ``string.ascii_uppercase``). With ``labelled=False`` (AI2D + ``ai2_diagram_no_letter``) options are listed verbatim and + ``option_names`` is the option list itself. + """ + if labelled: + prefixes = string.ascii_uppercase + # zip-shortest on purpose: prefixes covers up to 26 options + option_text = "\n".join( + f"{prefix}. {opt}" for prefix, opt in zip(prefixes, options, strict=False) + ) + option_names: str | list[str] = prefixes[: len(options)] + else: + option_text = "\n".join(options) + option_names = options + formatted = question + "\nOnly return the correct answer option.\n" + option_text + return formatted, option_names diff --git a/src/olmo_eval/common/image_qa/vqa_normalization.py b/src/olmo_eval/common/image_qa/vqa_normalization.py new file mode 100644 index 00000000..822f8970 --- /dev/null +++ b/src/olmo_eval/common/image_qa/vqa_normalization.py @@ -0,0 +1,438 @@ +"""VQA-family answer normalization and matching metrics. + +Vendored from ``mm_olmo/olmo/eval/vqa.py`` (reference implementation for the +Molmo2 image-QA benchmarks). The normalization tables and several regex +quirks are preserved byte-for-byte so scores match the original evaluation +exactly — do not "fix" them. + +The only intentional change is that ``editdistance.eval`` is replaced by the +pure-python :func:`levenshtein` below to avoid a new dependency. +""" + +from __future__ import annotations + +import re +from collections import Counter +from collections.abc import Sequence + +# --------------------------------------------------------------------------- +# Official VQA v2 normalization tables (verbatim) +# --------------------------------------------------------------------------- + +CONTRACTIONS = { + "aint": "ain't", + "arent": "aren't", + "cant": "can't", + "couldve": "could've", + "couldnt": "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + "didnt": "didn't", + "doesnt": "doesn't", + "dont": "don't", + "hadnt": "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + "hasnt": "hasn't", + "havent": "haven't", + "hed": "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + "hes": "he's", + "howd": "how'd", + "howll": "how'll", + "hows": "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + "Im": "I'm", + "Ive": "I've", + "isnt": "isn't", + "itd": "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + "itll": "it'll", + "let's": "let's", + "maam": "ma'am", + "mightnt": "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + "mightve": "might've", + "mustnt": "mustn't", + "mustve": "must've", + "neednt": "needn't", + "notve": "not've", + "oclock": "o'clock", + "oughtnt": "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + "shant": "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + "shouldve": "should've", + "shouldnt": "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": "somebodyd", + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + "somebodyll": "somebody'll", + "somebodys": "somebody's", + "someoned": "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + "someonell": "someone'll", + "someones": "someone's", + "somethingd": "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + "somethingll": "something'll", + "thats": "that's", + "thered": "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + "therere": "there're", + "theres": "there's", + "theyd": "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + "theyll": "they'll", + "theyre": "they're", + "theyve": "they've", + "twas": "'twas", + "wasnt": "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + "weve": "we've", + "werent": "weren't", + "whatll": "what'll", + "whatre": "what're", + "whats": "what's", + "whatve": "what've", + "whens": "when's", + "whered": "where'd", + "wheres": "where's", + "whereve": "where've", + "whod": "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + "wholl": "who'll", + "whos": "who's", + "whove": "who've", + "whyll": "why'll", + "whyre": "why're", + "whys": "why's", + "wont": "won't", + "wouldve": "would've", + "wouldnt": "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + "yall": "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + "youd": "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + "youll": "you'll", + "youre": "you're", + "youve": "you've", +} + +MANUAL_MAP = { + "none": "0", + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", +} + +ARTICLES = ["a", "an", "the"] + +PUNCT = [ + ";", + r"/", + "[", + "]", + '"', + "{", + "}", + "(", + ")", + "=", + "+", + "\\", + "_", + "-", + ">", + "<", + "@", + "`", + ",", + "?", + "!", +] + +# NOTE: both regex quirks below are upstream VQA-eval bugs preserved on +# purpose: `(?!<=\d)` is a (useless) negative lookahead for the literal text +# "<=", not the intended look-behind, and the original code passes +# ``re.UNICODE`` (== 32) as the *count* argument of ``periodStrip.sub``. +_PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)") +_COMMA_STRIP = re.compile(r"(\d)(\,)(\d)") + + +def process_punctuation(in_text: str) -> str: + out_text = in_text + for p in PUNCT: + if (p + " " in in_text or " " + p in in_text) or ( + re.search(_COMMA_STRIP, in_text) is not None + ): + out_text = out_text.replace(p, "") + else: + out_text = out_text.replace(p, " ") + out_text = _PERIOD_STRIP.sub("", out_text, re.UNICODE) + return out_text + + +def process_digit_article(in_text: str) -> str: + out_text = [] + temp_text = in_text.lower().split() + for word in temp_text: + word = MANUAL_MAP.setdefault(word, word) + if word not in ARTICLES: + out_text.append(word) + for word_id, word in enumerate(out_text): + if word in CONTRACTIONS: + out_text[word_id] = CONTRACTIONS[word] + return " ".join(out_text) + + +_PREPROCESS_CACHE: dict[str, str] = {} + + +def preprocess_answer(ans: str) -> str: + """Official VQA v2 answer normalization (cached).""" + if ans in _PREPROCESS_CACHE: + return _PREPROCESS_CACHE[ans] + out = ans.replace("\n", " ").replace("\t", " ").lower().strip() + preprocessed = process_digit_article(process_punctuation(out)) + _PREPROCESS_CACHE[ans] = preprocessed + return preprocessed + + +# --------------------------------------------------------------------------- +# Edit distance (replaces the `editdistance` dependency) +# --------------------------------------------------------------------------- + + +def levenshtein(a: str, b: str) -> int: + """Plain Levenshtein edit distance (insert/delete/substitute, cost 1).""" + if a == b: + return 0 + if not a: + return len(b) + if not b: + return len(a) + prev = list(range(len(b) + 1)) + for i, ca in enumerate(a, start=1): + cur = [i] + for j, cb in enumerate(b, start=1): + cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb))) + prev = cur + return prev[-1] + + +def _argmin(values: Sequence[float]) -> int: + """Index of the minimum value, first occurrence (matches np.argmin).""" + best_ix = 0 + for ix in range(1, len(values)): + if values[ix] < values[best_ix]: + best_ix = ix + return best_ix + + +# --------------------------------------------------------------------------- +# Prediction cleanup (from mm_olmo VqaEval.__call__) +# --------------------------------------------------------------------------- + + +def clean_prediction(pred: str) -> str: + """Cleanup applied by the original ``VqaEval`` before every metric. + + Strips whitespace; takes the text after the first "Answer:" if present; + for multi-line output keeps the most frequent line; otherwise collapses + inner whitespace. Counting/MathVista evaluators do NOT use this. + """ + pred = pred.strip() + if "Answer:" in pred: + pred = pred.split("Answer:")[1].strip() + elif "\n" in pred: + preds = [" ".join(x.strip().split()) for x in pred.split("\n")] + counts = Counter(preds) + max_count = max(counts.values()) + pred = [x for x in preds if counts[x] == max_count][0] + else: + pred = " ".join(pred.strip().split()) + return pred + + +# --------------------------------------------------------------------------- +# Metrics (verbatim logic) +# --------------------------------------------------------------------------- + + +def vqa_score(target: list[str] | str, pred: str) -> float: + """Official VQA v2 accuracy: min(#matching annotator answers / 3, 1).""" + pred = preprocess_answer(pred) + if isinstance(target, list): + counts = Counter(preprocess_answer(x) for x in target) + return min(counts[pred] / 3.0, 1) + return float(pred == target) + + +def select_mc_option(target: str, options: list[str] | str) -> int: + """Select a multiple-choice option index from the model output. + + Exact match, then unique prefix containment in both directions, then + unique substring, then minimum edit distance. + """ + target = target.lower().strip() + n = len(options) + options = [x.lower().strip() for x in options] + assert len(set(options)) == n + + for ix, option in enumerate(options): + if option == target: + return ix + + contains = [ix for ix, option in enumerate(options) if target.startswith(option)] + if len(contains) == 1: + return contains[0] + + contains = [ix for ix, option in enumerate(options) if option.startswith(target)] + if len(contains) == 1: + return contains[0] + + contains = [ix for ix, option in enumerate(options) if target in option] + if len(contains) == 1: + return contains[0] + + distances = [levenshtein(opt, target) for opt in options] + return _argmin(distances) + + +def anls_metric(target: str, prediction: str, theta: float = 0.5) -> float: + """ANLS for DocVQA/InfographicVQA (case-insensitive, θ=0.5).""" + if not target and not prediction: + # Degenerate case (placeholder test-split answers); mm_olmo would + # divide by zero here. Treat two empty strings as an exact match. + return 1.0 + edit_distance = levenshtein(target.lower(), prediction.lower()) + normalized_ld = edit_distance / max(len(target), len(prediction)) + return 1 - normalized_ld if normalized_ld < theta else 0 + + +def relaxed_correctness(target: str, prediction: str, max_relative_change: float = 0.05) -> bool: + """ChartQA relaxed accuracy: 5% numeric tolerance, exact match otherwise.""" + + def _to_float(text: str) -> float | None: + try: + if text.endswith("%"): + return float(text.rstrip("%")) / 100.0 + return float(text) + except ValueError: + return None + + prediction_float = _to_float(prediction) + target_float = _to_float(target) + if prediction_float is not None and target_float: + relative_change = abs(prediction_float - target_float) / abs(target_float) + return relative_change <= max_relative_change + return prediction.lower() == target.lower() + + +def scifi_relaxed_correctness( + target: str, prediction: str, max_relative_change: float = 0.05 +) -> bool: + """Lenient ChartQA variant: number extraction, word→digit, ÷100, substring.""" + + def _to_float(text: str) -> float | None: + try: + return float(text) + except ValueError: + return None + + def compute_relative_change(target_f: float, prediction_f: float) -> float: + if target_f == 0: + return abs(target_f - prediction_f) + return abs(target_f - prediction_f) / abs(target_f) + + def extract_short_answer(text: str) -> str: + if "answer:" in text: + return text.split("answer:")[1].strip() + return text + + prediction = extract_short_answer(prediction.lower().strip()) + target = extract_short_answer(target.lower().strip()) + + if len(prediction) == 0: + return False + + if prediction[-1] == ".": + prediction = prediction[:-1] + + word_to_num = {k: v for k, v in MANUAL_MAP.items() if k != "none"} + + target_float = _to_float(target) + if target_float is not None: + if "," in prediction: + prediction = prediction.replace(",", "") + + for word, num in word_to_num.items(): + prediction = prediction.replace(word, str(num)) + + match = re.search(r"[-+]?\d*\.\d+|\d+", prediction) + prediction_float = _to_float(match.group()) if match else None + if prediction_float is None: + return False + + relative_change = compute_relative_change(target_float, prediction_float) + + prediction_float_normalized = prediction_float / 100 + relative_change_normalized = compute_relative_change( + target_float, prediction_float_normalized + ) + + return bool( + relative_change <= max_relative_change + or relative_change_normalized <= max_relative_change + ) + + if "[" in target and "," in target: + # target is a list + targets = target.replace("[", "").replace("]", "").split(",") + return all(t.strip().lower() in prediction for t in targets) + + return target.strip().lower() in prediction + + +def real_world_qa_score(target: str, prediction: str, question_type: str) -> float: + """RealWorldQA: A–D letter selection for MC, VQA2-normalized EM otherwise.""" + if question_type == "multiple_choice": + options = ["A", "B", "C", "D"] + pred_idx = select_mc_option(prediction, options) + gt_idx = options.index(target) + return float(pred_idx == gt_idx) + pred = preprocess_answer(prediction) + gt = preprocess_answer(target) + return float(pred == gt) diff --git a/src/olmo_eval/common/scorers/__init__.py b/src/olmo_eval/common/scorers/__init__.py index 551181cd..5f7ce53e 100644 --- a/src/olmo_eval/common/scorers/__init__.py +++ b/src/olmo_eval/common/scorers/__init__.py @@ -15,8 +15,22 @@ SQuADF1Scorer, ) from .code_execution import CodeExecutionScorer, MultiplEScorer +from .dense_caption_judge import DenseCaptionJudgeScorer from .execution import ContextScorer, ExecutionScorer, SandboxRequiredError from .ifeval import IFEvalScorer +from .image_qa import ( + Ai2dScorer, + AnlsScorer, + EmScorer, + MathVistaGptScorer, + MathVistaOfflineScorer, + MmmuScorer, + PointCountScorer, + RealWorldQaScorer, + RelaxedCorrectnessScorer, + ScifiRelaxedScorer, + VqaScoreScorer, +) from .llm_judge import ( JudgeFn, LLMJudgeScorer, @@ -40,12 +54,16 @@ ) __all__ = [ + "Ai2dScorer", + "AnlsScorer", "BitsPerByteScorer", "build_openai_judge_fn", "CodeExecutionScorer", "ContextScorer", + "DenseCaptionJudgeScorer", "ExactMatchFlexScorer", "ExactMatchScorer", + "EmScorer", "ExecutionScorer", "F1Scorer", "IFEvalScorer", @@ -53,13 +71,20 @@ "LLMJudgeScorer", "LogprobScorer", "MathVerifyScorer", + "MathVistaGptScorer", + "MathVistaOfflineScorer", + "MmmuScorer", "MinervaMathScorer", "MultipleChoiceScorer", "MultiplEScorer", "PerplexityScorer", + "PointCountScorer", "ProcessScorer", + "RealWorldQaScorer", + "RelaxedCorrectnessScorer", "RubricJudgeScorer", "SafetyScorer", + "ScifiRelaxedScorer", "SandboxRequiredError", "Scorer", "SQuADF1Scorer", @@ -73,4 +98,5 @@ "TrajectoryEfficiencyScorer", "TrajectoryResponseScorer", "TrajectoryStateScorer", + "VqaScoreScorer", ] diff --git a/src/olmo_eval/common/scorers/dense_caption_judge.py b/src/olmo_eval/common/scorers/dense_caption_judge.py new file mode 100644 index 00000000..fa719396 --- /dev/null +++ b/src/olmo_eval/common/scorers/dense_caption_judge.py @@ -0,0 +1,352 @@ +"""GPT-judge scorer for pixmo-cap dense-caption evaluation. + +Ports the scoring logic from mm_olmo/scripts/gpt_dense_caption_eval.py into +the olmo-eval-internal ContextScorer abstraction. The cache-key scheme is +byte-identical to the legacy Gpt4WithCache so the existing gpt4-cache/ files +are reused for offline/reproducible runs. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import os +import re +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from olmo_eval.common.execution import ScoringContext +from olmo_eval.common.scorers.execution import ContextScorer +from olmo_eval.common.types import Instance, LMOutput + +logger = logging.getLogger(__name__) + +_DEFAULT_CACHE_DIR = "/weka/oe-training-default/mm-olmo/dense_caption_eval/gpt4-cache" + +# Labels that GPT returns instead of Consistent/Inconsistent; skip them silently. +_UNKNOWN_CONSISTENCY_LABELS = [ + "not specified", + "cannot determine", + "not determinable", + "no verification", + "n/a", + "not confirmed", + "neither", + "not stated", + "no judgement", + "unable to determine", + "inconclusive", + "undetermined", + "insufficient information", + "no relevant information", + "no conclusion", + "not clear", + "unknown", + "uncertain", + "ambiguous", + "not addressed", + "not enough information", + "not mentioned", + "not enough info", + "no information", + "not verifiable", + "not applicable", +] +_UNKNOWN_PATTERN = re.compile( + r".*\b(" + "|".join(re.escape(s) for s in _UNKNOWN_CONSISTENCY_LABELS) + r").*$", + re.IGNORECASE, +) + +# Module-level lazy async clients, keyed by model name. +_ASYNC_CLIENTS: dict[str, Any] = {} + + +# --------------------------------------------------------------------------- +# Cache helpers (identical semantics to legacy Gpt4WithCache) +# --------------------------------------------------------------------------- + + +def _compute_hash(s: str) -> str: + return hashlib.sha256(s.encode("utf-8")).hexdigest() + + +def _cache_key(model: str, prompt: str) -> str: + kwargs = {"temperature": 0} + return _compute_hash( + model + "::::" + json.dumps(prompt) + "::::" + json.dumps(kwargs, sort_keys=True) + ) + + +async def _cached_gpt_call( + prompt: str, + *, + model: str, + cache_dir: str, + cache_only: bool, + recompute: bool = False, +) -> str: + """Async GPT call with file-based caching compatible with legacy gpt4-cache/. + + When ``recompute=True`` an existing cache entry is ignored and a fresh API + call is made; the new result overwrites the old cache file. + """ + key = _cache_key(model, prompt) + cache_file = Path(cache_dir) / f"{key}-v1.json" + + if not recompute and cache_file.exists(): + with open(cache_file) as f: + data = json.load(f) + return data["choices"][0]["message"]["content"] + + if cache_only: + raise ValueError(f"Cache miss (cache_only=True) for key {key[:16]}…") + + if model not in _ASYNC_CLIENTS: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "OPENAI_API_KEY is required for DenseCaptionJudgeScorer on a cache miss." + ) + try: + from openai import AsyncOpenAI + except ImportError: + raise ImportError("openai package required: pip install openai") from None + _ASYNC_CLIENTS[model] = AsyncOpenAI(api_key=api_key) + + client = _ASYNC_CLIENTS[model] + response = await client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + temperature=0, + ) + completion = response.model_dump() + + # Atomic write: tmp → rename, identical to legacy Gpt4WithCache. + fd, tmp = tempfile.mkstemp(".tmp", prefix=f"{key}-v1.json", text=True, dir=cache_dir) + os.close(fd) + with open(tmp, "w") as f: + json.dump(completion, f) + os.rename(tmp, str(cache_file)) + + return completion["choices"][0]["message"]["content"] + + +# --------------------------------------------------------------------------- +# GPT prompt builders (verbatim from gpt_dense_caption_eval.py) +# --------------------------------------------------------------------------- + + +def _recall_prompt(mturk_statements: str, caption: str) -> str: + return ( + "Here are statements that annotators gave for an image.\n\n" + + mturk_statements.strip() + + ( + "\n\nNext, consider the following caption of the image. For each statement above," + ' state whether the fact is "Stated" or "Not Stated" in the caption.' + " The output should be in the form\n\n1. Stated\n2. Not Stated\n3. Stated\n\n" + "Do not output anything other than an ordered list of Stated and Not Stated.\n\n" + " Here is the caption: " + ) + + (caption.strip() if caption else "No caption provided.") + ) + + +def _canonical_prompt(caption: str) -> str: + return ( + "Based on the description of the image, come up with a list of the MOST canonical" + " statements that are mentioned in it. Each statement should be broken down as much" + " as possible. The statements should be an ordered list, where each item is separated" + " a newline. For instance, the rseponse may look like:\n\n" + "1. Statement A\n2. Statement B\n3. Statement C\n\n\n" + f"\n\n\nHere is the image description: {caption}" + ) + + +def _consistency_prompt(num_transcripts: int, transcripts_str: str, statements_str: str) -> str: + return ( + f"Here are {num_transcripts} captions people gave for an image using their voice.\n\n" + + transcripts_str + + ( + "\n\nHere are statements that a captioning model made about the image." + ' For each statement, state whether it\'s "Consistent" or "Inconsistent"' + " with the statements provided above. The output should be in the form\n\n" + "1. Consistent\n2. Inconsistent\n3. Consistent\n\n" + "Do not output anything other than an ordered list of Consistent and Inconsistent.\n\n" + ) + + statements_str + ) + + +# --------------------------------------------------------------------------- +# Parse helpers (verbatim logic from gpt_dense_caption_eval.py) +# --------------------------------------------------------------------------- + + +def parse_recall_output(text: str) -> tuple[int, int]: + """Parse GPT stated/not-stated output. + + Returns (num_covered, num_statements) counting only unambiguous lines. + Mirrors eval_recall() lines 323–346 in gpt_dense_caption_eval.py. + """ + lines = [x.strip() for x in text.split("\n") if x.strip()] + valid_scores: list[bool] = [] + for line in lines: + if re.fullmatch(r".*\bnot st[a-z]+$", line, flags=re.IGNORECASE): + valid_scores.append(False) + elif " stated" in line.lower(): + valid_scores.append(True) + # else: ambiguous line — skip (like legacy code) + return int(sum(valid_scores)), len(valid_scores) + + +def parse_consistency_output(text: str) -> tuple[int, int]: + """Parse GPT consistent/inconsistent output. + + Returns (num_consistent, num_valid) counting only unambiguous lines. + Mirrors eval_consistency() lines 403–461 in gpt_dense_caption_eval.py. + """ + lines = [x.strip() for x in text.split("\n") if x.strip()] + valid_scores: list[bool] = [] + for line in lines: + inconsistent: bool | None = None + if re.fullmatch( + r".*[^a-z]((i?inconsis?ten(t|cy)?)|incorrect|inconsistence|iconsistent" + r"|inconsisent|incomplete|contradictory).*", + line, + flags=re.IGNORECASE, + ): + inconsistent = True + if re.fullmatch( + r".*[^a-z](consistent(ly)?|constistent|correct).*$", + line, + flags=re.IGNORECASE, + ): + # both matched — treat as ambiguous (None); otherwise consistent (False) + inconsistent = None if inconsistent else False + if inconsistent is None: + if not _UNKNOWN_PATTERN.match(line): + logger.warning("Unexpected consistency label: %r", line) + continue + valid_scores.append(inconsistent) + num_consistent = sum(not x for x in valid_scores) + return num_consistent, len(valid_scores) + + +# --------------------------------------------------------------------------- +# Scorer +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class DenseCaptionJudgeScorer(ContextScorer): + """GPT-as-judge scorer for pixmo-cap dense-caption evaluation. + + Runs up to three GPT calls per example (recall stated-check, canonical + statements, consistency check) and stashes all per-example results in + ``output.metadata["dense_caption_result"]``. The primary ``float`` + return value is the raw recall ratio (0–1) for that output, or 0.0 if + the example is invalid. + + Cache keys are byte-identical to the legacy ``Gpt4WithCache`` in + ``mm_olmo/scripts/gpt_dense_caption_eval.py``, so existing ``gpt4-cache/`` + entries are reused automatically. + + ``instance.metadata`` must contain: + - ``mturk_statements`` (str): canonical_statements string from + ``mturk-eval-statements/{sha256(url)}.json``. + - ``transcripts`` (list[dict]): dicts with a ``"whisperTranscript"`` + key, from ``final-data.json``. + """ + + name: str = "dense_caption_judge" + + model: str = "gpt-4o-2024-05-13" + cache_dir: str = _DEFAULT_CACHE_DIR + cache_only: bool = False + recompute: bool = False + target_metrics: tuple[str, ...] = ("recall", "consistency") + + async def ascore_with_context( + self, + instance: Instance, + output: LMOutput, + context: ScoringContext, + ) -> float: + caption = (output.extracted_answer or output.text or "").strip() + mturk_statements: str = instance.metadata.get("mturk_statements", "") + transcripts: list[dict] = instance.metadata.get("transcripts", []) + transcripts_str = "\n\n".join( + t["whisperTranscript"] for t in transcripts if "whisperTranscript" in t + ) + + result: dict = {} + + if "recall" in self.target_metrics: + try: + raw = await _cached_gpt_call( + _recall_prompt(mturk_statements, caption), + model=self.model, + cache_dir=self.cache_dir, + cache_only=self.cache_only, + recompute=self.recompute, + ) + num_covered, num_statements = parse_recall_output(raw) + recall_valid = num_statements > 0 + result["recall"] = num_covered / num_statements if recall_valid else 0.0 + result["recall_at_10"] = ( + min(num_covered, 10) / min(num_statements, 10) if recall_valid else 0.0 + ) + result["num_statements"] = num_statements + result["num_covered"] = num_covered + result["recall_valid"] = recall_valid + except Exception as exc: + logger.warning( + "Recall scoring failed for %s: %s", + instance.metadata.get("url", "?"), + exc, + ) + result.update( + recall=0.0, + recall_at_10=0.0, + num_statements=0, + num_covered=0, + recall_valid=False, + ) + + if "consistency" in self.target_metrics: + try: + statements_str = await _cached_gpt_call( + _canonical_prompt(caption), + model=self.model, + cache_dir=self.cache_dir, + cache_only=self.cache_only, + recompute=self.recompute, + ) + cons_raw = await _cached_gpt_call( + _consistency_prompt(len(transcripts), transcripts_str, statements_str), + model=self.model, + cache_dir=self.cache_dir, + cache_only=self.cache_only, + recompute=self.recompute, + ) + num_consistent, num_valid = parse_consistency_output(cons_raw) + consistency_valid = num_valid > 0 + result["consistency"] = num_consistent / num_valid if consistency_valid else 0.0 + result["num_consistent"] = num_consistent + result["consistency_valid"] = consistency_valid + except Exception as exc: + logger.warning( + "Consistency scoring failed for %s: %s", + instance.metadata.get("url", "?"), + exc, + ) + result.update(consistency=0.0, num_consistent=0, consistency_valid=False) + + if output.metadata is None: + output.metadata = {} + output.metadata["dense_caption_result"] = result + + return result.get("recall", 0.0) if result.get("recall_valid", False) else 0.0 diff --git a/src/olmo_eval/common/scorers/image_qa.py b/src/olmo_eval/common/scorers/image_qa.py new file mode 100644 index 00000000..8cc8dd8e --- /dev/null +++ b/src/olmo_eval/common/scorers/image_qa.py @@ -0,0 +1,335 @@ +"""Scorers for the Molmo2 image-QA benchmarks. + +Each scorer mirrors one metric of the mm_olmo reference evaluators +(``VqaEval``, ``PointCountEval``, ``MathVistaEval``) using the vendored +functions in :mod:`olmo_eval.common.image_qa`. Scorers read what they need +from ``instance.metadata``: + +============================ ================================================= +Scorer Required instance metadata +============================ ================================================= +``VqaScoreScorer`` ``answers`` (list[str]) +``AnlsScorer`` / ``EmScorer`` ``answers`` +``RelaxedCorrectnessScorer`` ``answers`` +``ScifiRelaxedScorer`` ``answers`` +``MmmuScorer`` ``answer``, ``question_type``, ``options``, + ``example_id`` +``RealWorldQaScorer`` ``answer``, ``question_type`` +``MathVistaOfflineScorer`` ``answer``, ``question_type``, ``answer_type``, + ``choices``, ``precision`` +``MathVistaGptScorer`` same as offline, plus ``query`` +``PointCountScorer`` ``count`` +``Ai2dScorer`` ``answer_idx``, ``option_names``, ``abc_label``, + ``has_transparent_box`` +============================ ================================================= + +VQA-family scorers apply the original ``VqaEval`` prediction cleanup +(:func:`clean_prediction`); the counting and MathVista scorers intentionally +do not, matching mm_olmo. +""" + +from __future__ import annotations + +import logging +import os +import tempfile +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from olmo_eval.common.image_qa import ( + anls_metric, + clean_prediction, + extract_answer_quick, + math_vista_score_from_extraction, + math_vista_score_offline, + mmmu_score, + parse_count, + real_world_qa_score, + relaxed_correctness, + scifi_relaxed_correctness, + select_mc_option, + vqa_score, +) +from olmo_eval.common.image_qa.math_vista_offline import create_test_prompt +from olmo_eval.common.scorers.base import Scorer +from olmo_eval.common.scorers.execution import ContextScorer +from olmo_eval.common.types import Instance, LMOutput + +if TYPE_CHECKING: + from olmo_eval.common.execution import ScoringContext + +logger = logging.getLogger(__name__) + + +def _response_text(output: LMOutput) -> str: + answer = output.extracted_answer + if isinstance(answer, str) and answer: + return answer + return output.text or "" + + +def _answers(instance: Instance) -> list[str]: + answers = instance.metadata.get("answers") + if answers is None: + answer = instance.metadata.get("answer") + answers = [] if answer is None else [answer] + if isinstance(answers, str): + answers = [answers] + return list(answers) + + +@dataclass(frozen=True, slots=True) +class VqaScoreScorer(Scorer): + """Official VQA v2 accuracy against the reference answer list.""" + + name: str = "vqa_score" + + def score(self, instance: Instance, output: LMOutput) -> float: + pred = clean_prediction(_response_text(output)) + return float(vqa_score(_answers(instance), pred)) + + +@dataclass(frozen=True, slots=True) +class AnlsScorer(Scorer): + """ANLS (DocVQA / InfographicVQA), max over reference answers.""" + + name: str = "ansl" + + def score(self, instance: Instance, output: LMOutput) -> float: + pred = clean_prediction(_response_text(output)) + answers = _answers(instance) + if not answers: + return 0.0 + return float(max(anls_metric(ref, pred) for ref in answers)) + + +@dataclass(frozen=True, slots=True) +class EmScorer(Scorer): + """Case-insensitive exact match against any reference answer.""" + + name: str = "em" + + def score(self, instance: Instance, output: LMOutput) -> float: + pred = clean_prediction(_response_text(output)) + return float(pred.lower() in [x.lower() for x in _answers(instance)]) + + +@dataclass(frozen=True, slots=True) +class RelaxedCorrectnessScorer(Scorer): + """ChartQA relaxed accuracy, max over reference answers.""" + + name: str = "relaxed_correctness" + + def score(self, instance: Instance, output: LMOutput) -> float: + pred = clean_prediction(_response_text(output)) + answers = _answers(instance) + if not answers: + return 0.0 + return float(max(relaxed_correctness(ans, pred) for ans in answers)) + + +@dataclass(frozen=True, slots=True) +class ScifiRelaxedScorer(Scorer): + """Lenient ChartQA relaxed accuracy, max over reference answers.""" + + name: str = "scifi_relaxed_correctness" + + def score(self, instance: Instance, output: LMOutput) -> float: + pred = clean_prediction(_response_text(output)) + answers = _answers(instance) + if not answers: + return 0.0 + return float(max(scifi_relaxed_correctness(ans, pred) for ans in answers)) + + +@dataclass(frozen=True, slots=True) +class MmmuScorer(Scorer): + """Official MMMU scoring (multiple-choice parsing or open matching).""" + + name: str = "mmmu_score" + + def score(self, instance: Instance, output: LMOutput) -> float: + pred = clean_prediction(_response_text(output)) + meta = instance.metadata + return mmmu_score( + _answers(instance), + pred, + question_type=meta["question_type"], + options=meta.get("options") or [], + stable_id=str(meta.get("example_id", "")), + ) + + +@dataclass(frozen=True, slots=True) +class RealWorldQaScorer(Scorer): + """RealWorldQA: A–D letter match for MC, normalized EM otherwise.""" + + name: str = "real_world_qa_score" + + def score(self, instance: Instance, output: LMOutput) -> float: + pred = clean_prediction(_response_text(output)) + meta = instance.metadata + return float(real_world_qa_score(meta["answer"], pred, meta["question_type"])) + + +@dataclass(frozen=True, slots=True) +class MathVistaOfflineScorer(Scorer): + """MathVista scoring with offline (no-GPT) answer extraction.""" + + name: str = "score" + + def score(self, instance: Instance, output: LMOutput) -> float: + pred = _response_text(output).strip() + meta = instance.metadata + try: + correct = math_vista_score_offline( + pred, + question_type=meta["question_type"], + answer_type=meta["answer_type"], + choices=list(meta.get("choices") or []), + precision=meta.get("precision"), + target=meta["answer"], + ) + except Exception as exc: + logger.warning( + "MathVista offline scoring failed for %s: %s", meta.get("example_id"), exc + ) + return 0.0 + return float(correct) + + +@dataclass(frozen=True, slots=True) +class PointCountScorer(Scorer): + """Counting accuracy for the ``point_count`` style (CountBench/PixMo Count). + + Stores ``{correct, close, valid, pred_count}`` in + ``output.metadata["point_count_result"]`` for the per-count metrics and + returns ``correct``. + """ + + name: str = "point_count" + + def score(self, instance: Instance, output: LMOutput) -> float: + gt = int(instance.metadata["count"]) + pred_count = parse_count(_response_text(output)) + result = { + "correct": float(gt == pred_count), + "close": float(abs(gt - pred_count) <= 1), + "valid": 1.0, + "pred_count": pred_count, + } + if output.metadata is None: + output.metadata = {} + output.metadata["point_count_result"] = result + return result["correct"] + + +@dataclass(frozen=True, slots=True) +class Ai2dScorer(Scorer): + """AI2D multiple-choice scoring with opaque/transparent routing metadata. + + Stores ``{is_correct, abc_label, has_transparent_box}`` in + ``output.metadata["ai2d_result"]`` so the two AI2D metrics can route each + abc-label question to exactly one of the opaque/transparent variants. + """ + + name: str = "mc_ai2d" + + def score(self, instance: Instance, output: LMOutput) -> float: + pred = clean_prediction(_response_text(output)) + meta = instance.metadata + options = list(meta["option_names"]) + pred_idx = select_mc_option(pred, options) + is_correct = float(pred_idx == meta["answer_idx"]) + if output.metadata is None: + output.metadata = {} + output.metadata["ai2d_result"] = { + "is_correct": is_correct, + "abc_label": bool(meta["abc_label"]), + "has_transparent_box": bool(meta["has_transparent_box"]), + } + return is_correct + + +_PROCESS_GPT_CACHE_DIR: list[str] = [] + + +def _default_gpt_cache_dir() -> str: + """Per-run GPT cache dir: env override or a fresh process-local temp dir. + + Never defaults to any pre-existing shared cache; the shared mm_olmo + ``gpt4-cache`` must not be read or written by this scorer. + """ + env_dir = os.environ.get("MATHVISTA_GPT_CACHE_DIR") + if env_dir: + return env_dir + if not _PROCESS_GPT_CACHE_DIR: + _PROCESS_GPT_CACHE_DIR.append(tempfile.mkdtemp(prefix="mathvista-gpt-cache-")) + return _PROCESS_GPT_CACHE_DIR[0] + + +@dataclass(frozen=True) +class MathVistaGptScorer(ContextScorer): + """MathVista scoring with the official GPT-4 answer extraction. + + Follows the official protocol: deterministic short-circuits first, then a + ``gpt-4-0613`` extraction call (requires ``OPENAI_API_KEY``). Responses + are cached under ``cache_dir`` (env ``MATHVISTA_GPT_CACHE_DIR`` or a fresh + per-process temp dir) so a user's own re-runs are cheap; existing shared + caches are never touched. + """ + + name: str = "score" + + model: str = "gpt-4-0613" + cache_dir: str | None = field(default_factory=_default_gpt_cache_dir) + cache_only: bool = False + recompute: bool = False + + async def ascore_with_context( + self, + instance: Instance, + output: LMOutput, + context: ScoringContext, + ) -> float: + from olmo_eval.common.scorers.dense_caption_judge import _cached_gpt_call + + pred = _response_text(output).strip() + meta = instance.metadata + choices = list(meta.get("choices") or []) + question_type = meta["question_type"] + answer_type = meta["answer_type"] + + extraction = extract_answer_quick(pred, question_type, answer_type, choices) + if extraction is None: + try: + extraction = await _cached_gpt_call( + create_test_prompt(meta["query"], pred), + model=self.model, + cache_dir=self.cache_dir or _default_gpt_cache_dir(), + cache_only=self.cache_only, + recompute=self.recompute, + ) + except Exception as exc: + logger.warning( + "MathVista GPT extraction failed for %s: %s", meta.get("example_id"), exc + ) + return 0.0 + + if output.metadata is None: + output.metadata = {} + output.metadata["math_vista_extraction"] = extraction + + try: + correct = math_vista_score_from_extraction( + extraction, + question_type=question_type, + answer_type=answer_type, + choices=choices, + precision=meta.get("precision"), + target=meta["answer"], + ) + except Exception as exc: + logger.warning("MathVista GPT scoring failed for %s: %s", meta.get("example_id"), exc) + return 0.0 + return float(correct) diff --git a/src/olmo_eval/evals/tasks/ai2d.py b/src/olmo_eval/evals/tasks/ai2d.py new file mode 100644 index 00000000..6095390c --- /dev/null +++ b/src/olmo_eval/evals/tasks/ai2d.py @@ -0,0 +1,86 @@ +"""AI2D (validation by default; ``ai2d:test`` for the official test split). + +Mirrors mm_olmo's ``AI2DConfig(boxes="both")`` (task name +``ai2_diagram_v2_mix_transparent``): loads the prepared arrow dataset at +``torch_datasets/academic_datasets/ai2d``, where every abc-label question +appears twice — once with opaque answer boxes drawn on the diagram and once +with transparent ones (``has_transparent_box``). + +Formatting follows ``AI2DConfig.format_example``: when a question's answer +options are (almost all) the on-diagram letters themselves, options are +listed without ``A./B.`` prefixes and the model must answer with the option +text (``ai2_diagram_no_letter``); otherwise standard lettered options are +used. Multiple-choice style — no style tag. + +Reference (Molmo2-4B ck2000, val): mc_ai2d_opaque=0.8537, +mc_ai2d_transparent=0.9481. +""" + +from __future__ import annotations + +from collections.abc import Iterator + +from olmo_eval.common.image_qa import format_mc_question +from olmo_eval.common.scorers.image_qa import Ai2dScorer +from olmo_eval.common.types import Instance, SamplingParams, Split +from olmo_eval.evals.tasks.common import register, register_variant +from olmo_eval.evals.tasks.common.image_qa_base import ( + Ai2dMetric, + ImageQATask, + lazy_hf_image, + torch_datasets_dir, +) + +_SCORER = Ai2dScorer() +_OPAQUE = Ai2dMetric(name="mc_ai2d_opaque", scorer=_SCORER, transparent=False) +_TRANSPARENT = Ai2dMetric(name="mc_ai2d_transparent", scorer=_SCORER, transparent=True) + + +@register("ai2d") +class Ai2dTask(ImageQATask): + sampling_params = SamplingParams(temperature=0.0, max_tokens=32) + metrics = (_OPAQUE, _TRANSPARENT) + primary_metric = _OPAQUE + split = Split.VALIDATION + + def _build_instances(self) -> Iterator[Instance]: + import datasets + + ds = datasets.load_from_disk(str(torch_datasets_dir() / "academic_datasets" / "ai2d")) + ds = ds[self.config.split.value] + ds_nodecode = ds.cast_column("image", datasets.Image(decode=False)) + + for idx in range(len(ds_nodecode)): + ex = ds_nodecode[idx] + options = ex["answer_texts"] + answer_idx = ex["correct_answer"] + if ex["abc_label"] and sum(ex["option_is_abc"]) >= (len(options) - 1): + # ai2_diagram_no_letter: unlabelled options, abc options uppercased + unlabelled = [ + opt.upper() if abc else opt + for opt, abc in zip(options, ex["option_is_abc"], strict=True) + ] + question, option_names = format_mc_question( + ex["question"], unlabelled, labelled=False + ) + gold = unlabelled[answer_idx] + else: + question, option_names = format_mc_question(ex["question"], options) + gold = option_names[answer_idx] + yield Instance( + question=question, + gold_answer=gold, + metadata={ + "example_id": ex["question_id"], + "image_id": ex["image_id"], + "abc_label": ex["abc_label"], + "has_transparent_box": ex["has_transparent_box"], + "answer_idx": answer_idx, + "option_names": option_names, + "options": options, + "image": lazy_hf_image(ds_nodecode, idx, "image"), + }, + ) + + +register_variant("ai2d", "test", split=Split.TEST) diff --git a/src/olmo_eval/evals/tasks/chart_qa.py b/src/olmo_eval/evals/tasks/chart_qa.py new file mode 100644 index 00000000..16b8b2c7 --- /dev/null +++ b/src/olmo_eval/evals/tasks/chart_qa.py @@ -0,0 +1,70 @@ +"""ChartQA (human + augmented parts). + +Mirrors mm_olmo's ``ChartQaConfig`` (``parts="both"``): loads +``torch_datasets/chartqa/{split}/{split}_{human,augmented}.json`` (human part +first), prompts with the ``chart_qa`` style tag, and scores relaxed +correctness / scifi relaxed correctness / exact match with ``_human`` / +``_aug`` breakdowns. + +Reference (Molmo2-4B ck2000, val): relaxed_correctness=0.8380, em=0.7490, +scifi_relaxed_correctness=0.8516. +""" + +from __future__ import annotations + +import json +from collections.abc import Iterator + +from olmo_eval.common.metrics.base import Metric +from olmo_eval.common.scorers.image_qa import ( + EmScorer, + RelaxedCorrectnessScorer, + ScifiRelaxedScorer, +) +from olmo_eval.common.types import Instance, SamplingParams, Split +from olmo_eval.evals.tasks.common import register, register_variant +from olmo_eval.evals.tasks.common.image_qa_base import ( + ChartQaSubsetMetric, + ImageQATask, + torch_datasets_dir, +) + +_RELAXED = RelaxedCorrectnessScorer() +_SCIFI = ScifiRelaxedScorer() +_EM = EmScorer() + +_METRICS: tuple[Metric, ...] = tuple( + ChartQaSubsetMetric(name=scorer.name + suffix, scorer=scorer, subset=subset) + for scorer in (_RELAXED, _SCIFI, _EM) + for subset, suffix in (("all", ""), ("human", "_human"), ("aug", "_aug")) +) + + +@register("chart_qa") +class ChartQaTask(ImageQATask): + sampling_params = SamplingParams(temperature=0.0, max_tokens=12) + metrics = _METRICS + primary_metric = _METRICS[0] # relaxed_correctness (overall) + split = Split.VALIDATION + + def _build_instances(self) -> Iterator[Instance]: + split = "val" if self.config.split == Split.VALIDATION else "test" + src_dir = torch_datasets_dir() / "chartqa" / split + for part in ("human", "augmented"): + with open(src_dir / f"{split}_{part}.json") as f: + data = json.load(f) + for ex_id, ex in enumerate(data): + label = ex["label"] + yield Instance( + question=f"chart_qa: {ex['query']}", + gold_answer=label if isinstance(label, str) else label[0], + metadata={ + "answers": label, + "is_human": part == "human", + "example_id": ex_id, + "image_path": str(src_dir / "png" / ex["imgname"]), + }, + ) + + +register_variant("chart_qa", "test", split=Split.TEST) diff --git a/src/olmo_eval/evals/tasks/common/image_qa_base.py b/src/olmo_eval/evals/tasks/common/image_qa_base.py new file mode 100644 index 00000000..d562cd23 --- /dev/null +++ b/src/olmo_eval/evals/tasks/common/image_qa_base.py @@ -0,0 +1,274 @@ +"""Shared base task and metrics for the Molmo2 image-QA benchmarks. + +The 11 benchmark task modules (``chart_qa``, ``vqa2``, ``doc_qa``, ``info_qa``, +``text_vqa``, ``real_world_qa``, ``mmmu``, ``math_vista``, ``countbench_qa``, +``pixmo_count``, ``ai2d``) build on: + +* :class:`ImageQATask` — caches instances, formats CHAT requests, resolves the + mm_olmo data root. +* Generic metrics (:class:`MeanScorerMetric`, :class:`ChartQaSubsetMetric`, + :class:`PointCountMetric`, :class:`PointCountPerCountMetric`, + :class:`PointCountCategoryAverageMetric`, :class:`Ai2dMetric`). + +Conventions: + +* ``instance.question`` is the **fully formatted** prompt text — style prefix + (e.g. ``"vqa2: "``) and multiple-choice option block already baked in, so + inference scripts pass it through verbatim (prompt parity is the task's + responsibility). +* The image is stored in ``instance.metadata["image_path"]`` (filesystem path) + or ``instance.metadata["image"]`` (a PIL image or a zero-arg callable + returning one — use :func:`load_instance_image` to resolve either form). +* Data is read from ``$MOLMO_DATA_DIR`` (default + ``/weka/oe-training-default/mm-olmo``) and never written — loaders error + out rather than build caches. +""" + +from __future__ import annotations + +import os +from abc import abstractmethod +from collections.abc import Iterator, Sequence +from dataclasses import dataclass +from pathlib import Path + +from olmo_eval.common.metrics.base import Metric +from olmo_eval.common.scorers.base import Scorer +from olmo_eval.common.types import Instance, LMRequest, RequestType, Response +from olmo_eval.evals.tasks.common.base import Task + +DEFAULT_MOLMO_DATA_DIR = "/weka/oe-training-default/mm-olmo" + + +def molmo_data_dir() -> Path: + """Root of the mm-olmo data tree (read-only).""" + return Path(os.environ.get("MOLMO_DATA_DIR", DEFAULT_MOLMO_DATA_DIR)) + + +def torch_datasets_dir() -> Path: + return molmo_data_dir() / "torch_datasets" + + +def rebase_data_path(path: str) -> str: + """Rebase an absolute path recorded on another machine onto the current root. + + Cached manifests (e.g. ``vqa2/molmo_val.json``) store absolute image paths + from the machine that built them; if the stored path does not exist locally + but contains ``torch_datasets/``, re-anchor it under the current data root. + """ + if os.path.exists(path): + return path + marker = "torch_datasets/" + if marker in path: + suffix = path.split(marker, 1)[1] + return str(torch_datasets_dir() / suffix) + return path + + +def lazy_hf_image(dataset, index: int, column: str = "image"): + """A zero-arg callable that decodes one image cell of a no-decode HF dataset. + + ``dataset`` should have ``column`` cast to ``datasets.Image(decode=False)`` + so building instances never decodes pixels; the callable decodes exactly + one image when the inference script needs it. + """ + + def _load(): + import io + + from PIL import Image + + rec = dataset[index][column] + if isinstance(rec, dict): + if rec.get("bytes"): + return Image.open(io.BytesIO(rec["bytes"])) + if rec.get("path"): + return Image.open(rec["path"]) + return rec + + return _load + + +def load_instance_image(instance: Instance): + """Resolve an instance's image to a PIL image (or None if imageless).""" + image = instance.metadata.get("image") + if image is not None: + return image() if callable(image) else image + path = instance.metadata.get("image_path") + if path is not None: + from PIL import Image + + return Image.open(path) + return None + + +class ImageQATask(Task): + """Base class for the Molmo2 image-QA benchmark tasks.""" + + @property + def instances(self) -> Iterator[Instance]: + if self._instances_cache is None: + instances = list(self._build_instances()) + limit = self.config.limit + if limit is not None: + instances = instances[:limit] + self._instances_cache = instances + yield from self._instances_cache + + @abstractmethod + def _build_instances(self) -> Iterator[Instance]: + """Yield all instances for ``self.config.split`` (before ``limit``).""" + ... + + def format_request(self, instance: Instance) -> LMRequest: + return LMRequest( + request_type=RequestType.CHAT, + messages=({"role": "user", "content": instance.question},), + ) + + +# --------------------------------------------------------------------------- +# Generic metrics +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class MeanScorerMetric(Metric): + """Mean of a scorer's per-response score (the mm_olmo ``global_mean``).""" + + name: str # type: ignore[misc] + scorer: Scorer # type: ignore[misc] + + def compute(self, responses: Sequence[Response]) -> float: + if not responses: + return 0.0 + scorer_name = self.scorer().name + return sum(r.scores.get(scorer_name, 0.0) for r in responses) / len(responses) + + +@dataclass(frozen=True) +class ChartQaSubsetMetric(Metric): + """ChartQA metric over all / human / augmented examples. + + Subset membership comes from ``instance.metadata["is_human"]``, matching + the ``_human`` / ``_aug`` breakdowns of mm_olmo's ``VqaEval``. + """ + + name: str # type: ignore[misc] + scorer: Scorer # type: ignore[misc] + subset: str = "all" # all | human | aug + + def compute(self, responses: Sequence[Response]) -> float: + scorer_name = self.scorer().name + vals = [r.scores.get(scorer_name, 0.0) for r in responses if self._in_subset(r)] + return sum(vals) / len(vals) if vals else 0.0 + + def _in_subset(self, response: Response) -> bool: + if self.subset == "all": + return True + is_human = bool(response.instance.metadata.get("is_human")) + return is_human if self.subset == "human" else not is_human + + +def _point_count_results(responses: Sequence[Response]) -> Iterator[tuple[Response, dict]]: + for response in responses: + for output in response.outputs: + if output.metadata and "point_count_result" in output.metadata: + yield response, output.metadata["point_count_result"] + + +@dataclass(frozen=True) +class PointCountMetric(Metric): + """Mean of one field (``correct`` / ``close`` / ``valid``) of the count result.""" + + name: str # type: ignore[misc] + scorer: Scorer # type: ignore[misc] + kind: str = "correct" + + def compute(self, responses: Sequence[Response]) -> float: + vals = [result[self.kind] for _, result in _point_count_results(responses)] + return sum(vals) / len(vals) if vals else 0.0 + + +@dataclass(frozen=True) +class PointCountPerCountMetric(Metric): + """Counting accuracy restricted to examples with ground-truth count ``k``.""" + + name: str # type: ignore[misc] + scorer: Scorer # type: ignore[misc] + k: int = 0 + + def compute(self, responses: Sequence[Response]) -> float: + vals = [ + result["correct"] + for response, result in _point_count_results(responses) + if int(response.instance.metadata["count"]) == self.k + ] + return sum(vals) / len(vals) if vals else 0.0 + + +@dataclass(frozen=True) +class PointCountCategoryAverageMetric(Metric): + """Macro average of per-count accuracies over the counts present.""" + + name: str # type: ignore[misc] + scorer: Scorer # type: ignore[misc] + + def compute(self, responses: Sequence[Response]) -> float: + by_count: dict[int, list[float]] = {} + for response, result in _point_count_results(responses): + by_count.setdefault(int(response.instance.metadata["count"]), []).append( + result["correct"] + ) + if not by_count: + return 0.0 + per_count = [sum(v) / len(v) for v in by_count.values()] + return sum(per_count) / len(per_count) + + +@dataclass(frozen=True) +class Ai2dMetric(Metric): + """AI2D accuracy split by box rendering. + + abc-label questions count toward exactly one variant (transparent or + opaque, per ``has_transparent_box``); questions without abc labels count + toward both — matching ``mc_ai2d_opaque`` / ``mc_ai2d_transparent`` in + mm_olmo's ``VqaEval``. + """ + + name: str # type: ignore[misc] + scorer: Scorer # type: ignore[misc] + transparent: bool = False + + def compute(self, responses: Sequence[Response]) -> float: + vals: list[float] = [] + for response in responses: + for output in response.outputs: + if not output.metadata or "ai2d_result" not in output.metadata: + continue + result = output.metadata["ai2d_result"] + if result["abc_label"]: + if self.transparent and not result["has_transparent_box"]: + continue + if not self.transparent and result["has_transparent_box"]: + continue + vals.append(result["is_correct"]) + return sum(vals) / len(vals) if vals else 0.0 + + +# Counts present in the CountBench QA / PixMo Count eval sets. +POINT_COUNT_KS: tuple[int, ...] = tuple(range(2, 11)) + + +def point_count_metrics(scorer: Scorer) -> tuple[Metric, ...]: + """The full mm_olmo ``PointCountEval`` metric family for one shared scorer.""" + return ( + PointCountMetric(name="correct", scorer=scorer, kind="correct"), + PointCountMetric(name="close", scorer=scorer, kind="close"), + PointCountMetric(name="valid", scorer=scorer, kind="valid"), + *( + PointCountPerCountMetric(name=f"correct_{k}", scorer=scorer, k=k) + for k in POINT_COUNT_KS + ), + PointCountCategoryAverageMetric(name="per_category_average", scorer=scorer), + ) diff --git a/src/olmo_eval/evals/tasks/countbench_qa.py b/src/olmo_eval/evals/tasks/countbench_qa.py new file mode 100644 index 00000000..ae03f7dc --- /dev/null +++ b/src/olmo_eval/evals/tasks/countbench_qa.py @@ -0,0 +1,55 @@ +"""CountBench QA (490 examples, counts 2–10). + +Mirrors mm_olmo's ``CountBenchQaConfig``: loads the prepared arrow dataset at +``torch_datasets/academic_datasets/countbench_qa`` (CountBench images/counts +merged with the PaliGemma paired questions). The dataset has a single test +set; the natural-language counting question is used verbatim (``point_count`` +style — no style tag). + +Reference (Molmo2-4B ck2000): correct=0.9408. +""" + +from __future__ import annotations + +from collections.abc import Iterator + +from olmo_eval.common.scorers.image_qa import PointCountScorer +from olmo_eval.common.types import Instance, SamplingParams, Split +from olmo_eval.evals.tasks.common import register +from olmo_eval.evals.tasks.common.image_qa_base import ( + ImageQATask, + lazy_hf_image, + point_count_metrics, + torch_datasets_dir, +) + +_SCORER = PointCountScorer() +_METRICS = point_count_metrics(_SCORER) + + +@register("countbench_qa") +class CountBenchQaTask(ImageQATask): + sampling_params = SamplingParams(temperature=0.0, max_tokens=192) + metrics = _METRICS + primary_metric = _METRICS[0] # correct + split = Split.TEST # single prepared set + + def _build_instances(self) -> Iterator[Instance]: + import datasets + + ds = datasets.load_from_disk( + str(torch_datasets_dir() / "academic_datasets" / "countbench_qa") + ) + ds_nodecode = ds.cast_column("image", datasets.Image(decode=False)) + for idx in range(len(ds_nodecode)): + ex = ds_nodecode[idx] + yield Instance( + question=str(ex["question"]), + gold_answer=str(ex["count"]), + metadata={ + "count": ex["count"], + "example_id": ex["example_id"], + "image_url": ex["image_url"], + "image": lazy_hf_image(ds_nodecode, idx, "image"), + }, + ) diff --git a/src/olmo_eval/evals/tasks/dense_caption.py b/src/olmo_eval/evals/tasks/dense_caption.py new file mode 100644 index 00000000..dd8c241a --- /dev/null +++ b/src/olmo_eval/evals/tasks/dense_caption.py @@ -0,0 +1,244 @@ +"""Pixmo-cap dense-caption evaluation task. + +Loads the pixmo-cap test split (2730 images), joins judge reference data from +two on-disk sources, runs a GPT-4o recall+consistency judge, and aggregates +into recall / recall_at_10 / consistency / num_statements / avg metrics. + +Environment variables (optional — defaults point to Weka paths): + DENSE_CAPTION_EVAL_DIR root of dense_caption_eval/ (contains final-data.json, + mturk-eval-statements/, gpt4-cache/) + MOLMO_DATA_DIR parent of torch_datasets/ (contains pixmo_images/ and + pixmo_datasets/dense-caption-eval/test.jsonl) +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import os +from collections.abc import Iterator, Sequence +from dataclasses import dataclass +from pathlib import Path + +from olmo_eval.common.metrics.base import Metric +from olmo_eval.common.scorers.base import Scorer +from olmo_eval.common.scorers.dense_caption_judge import DenseCaptionJudgeScorer +from olmo_eval.common.types import Instance, LMRequest, RequestType, Response, SamplingParams +from olmo_eval.evals.tasks.common import Task, register, register_variant + +logger = logging.getLogger(__name__) + +_DEFAULT_EVAL_DIR = "/weka/oe-training-default/mm-olmo/dense_caption_eval" +_DEFAULT_DATA_HOME = "/weka/oe-training-default/mm-olmo/torch_datasets" + +# Shared scorer instance — all 5 metrics hold a reference so _get_scorers() +# deduplicates to a single GPT-judge call per example (via Scorer.__call__). +_JUDGE = DenseCaptionJudgeScorer() + + +# --------------------------------------------------------------------------- +# Metric definitions +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class DenseCaptionRecallMetric(Metric): + """Mean recall (×100) over valid examples.""" + + name: str = "recall" + scorer: type[Scorer] | Scorer = _JUDGE + + def compute(self, responses: Sequence[Response]) -> float: + vals = [ + o.metadata["dense_caption_result"]["recall"] + for r in responses + for o in r.outputs + if o.metadata + and "dense_caption_result" in o.metadata + and o.metadata["dense_caption_result"].get("recall_valid") + ] + return (sum(vals) / len(vals) * 100) if vals else 0.0 + + +@dataclass(frozen=True, slots=True) +class DenseCaptionConsistencyMetric(Metric): + """Mean consistency (×100) over valid examples.""" + + name: str = "consistency" + scorer: type[Scorer] | Scorer = _JUDGE + + def compute(self, responses: Sequence[Response]) -> float: + vals = [ + o.metadata["dense_caption_result"]["consistency"] + for r in responses + for o in r.outputs + if o.metadata + and "dense_caption_result" in o.metadata + and o.metadata["dense_caption_result"].get("consistency_valid") + ] + return (sum(vals) / len(vals) * 100) if vals else 0.0 + + +@dataclass(frozen=True, slots=True) +class DenseCaptionRecallAt10Metric(Metric): + """Mean recall-at-10 (×100) over valid examples.""" + + name: str = "recall_at_10" + scorer: type[Scorer] | Scorer = _JUDGE + + def compute(self, responses: Sequence[Response]) -> float: + vals = [ + o.metadata["dense_caption_result"]["recall_at_10"] + for r in responses + for o in r.outputs + if o.metadata + and "dense_caption_result" in o.metadata + and o.metadata["dense_caption_result"].get("recall_valid") + ] + return (sum(vals) / len(vals) * 100) if vals else 0.0 + + +@dataclass(frozen=True, slots=True) +class DenseCaptionNumStatementsMetric(Metric): + """Mean number of mturk statements per valid example (raw, not ×100).""" + + name: str = "num_statements" + scorer: type[Scorer] | Scorer = _JUDGE + + def compute(self, responses: Sequence[Response]) -> float: + vals = [ + o.metadata["dense_caption_result"]["num_statements"] + for r in responses + for o in r.outputs + if o.metadata + and "dense_caption_result" in o.metadata + and o.metadata["dense_caption_result"].get("recall_valid") + ] + return (sum(vals) / len(vals)) if vals else 0.0 + + +@dataclass(frozen=True, slots=True) +class DenseCaptionAvgMetric(Metric): + """Primary metric: (mean_recall + mean_consistency) / 2 × 100.""" + + name: str = "avg" + scorer: type[Scorer] | Scorer = _JUDGE + + def compute(self, responses: Sequence[Response]) -> float: + results = [ + o.metadata["dense_caption_result"] + for r in responses + for o in r.outputs + if o.metadata and "dense_caption_result" in o.metadata + ] + recall_vals = [r["recall"] for r in results if r.get("recall_valid")] + cons_vals = [r["consistency"] for r in results if r.get("consistency_valid")] + mean_recall = sum(recall_vals) / len(recall_vals) if recall_vals else 0.0 + mean_cons = sum(cons_vals) / len(cons_vals) if cons_vals else 0.0 + return (mean_recall + mean_cons) / 2.0 * 100 + + +_DEFAULT_METRICS: tuple[Metric, ...] = ( + DenseCaptionRecallMetric(), + DenseCaptionConsistencyMetric(), + DenseCaptionRecallAt10Metric(), + DenseCaptionNumStatementsMetric(), + DenseCaptionAvgMetric(), +) +_AVG_METRIC = DenseCaptionAvgMetric() + + +# --------------------------------------------------------------------------- +# Task +# --------------------------------------------------------------------------- + + +def _sha256(s: str) -> str: + return hashlib.sha256(s.encode()).hexdigest() + + +@register("dense_caption") +class DenseCaptionEval(Task): + """Pixmo-cap dense-caption GPT-judge evaluation. + + Data comes from three on-disk sources: + * ``final-data.json`` — whisper transcripts (consistency reference) + * ``mturk-eval-statements/{sha256(url)}.json`` — canonical statements + (recall reference) + * ``torch_datasets/pixmo_datasets/dense-caption-eval/test.jsonl`` — image + paths and URLs + + The model inference request is a CHAT message "Describe this image." + The local image path is stored in ``instance.metadata["image_path"]`` for + an inference script to load. + """ + + sampling_params = SamplingParams(temperature=0.0, max_tokens=448) + metrics = _DEFAULT_METRICS + primary_metric = _AVG_METRIC + + @property + def instances(self) -> Iterator[Instance]: + if self._instances_cache is None: + self._instances_cache = list(self._build_instances()) + yield from self._instances_cache + + def _build_instances(self) -> Iterator[Instance]: + eval_dir = Path(os.environ.get("DENSE_CAPTION_EVAL_DIR", _DEFAULT_EVAL_DIR)) + data_home = Path(os.environ.get("MOLMO_DATA_DIR", _DEFAULT_DATA_HOME)) + test_jsonl = data_home / "pixmo_datasets" / "dense-caption-eval" / "test.jsonl" + + with open(eval_dir / "final-data.json") as f: + final_data = json.load(f) + url_to_transcripts: dict[str, list[dict]] = { + ex["image"]: ex["transcripts"] for ex in final_data + } + + limit = self.config.limit + count = 0 + with open(test_jsonl) as f: + for line in f: + if limit is not None and count >= limit: + break + rec = json.loads(line) + url: str = rec["url"] + image_id: str = rec.get("image_id", _sha256(url)) + image_name: str = rec.get("image", image_id) + image_path = data_home / "pixmo_images" / image_name + + transcripts = url_to_transcripts.get(url) + if transcripts is None: + logger.warning("No transcripts for %s — skipping", url) + continue + + mturk_file = eval_dir / "mturk-eval-statements" / f"{_sha256(url)}.json" + if not mturk_file.exists(): + logger.warning("No mturk file for %s — skipping", url) + continue + with open(mturk_file) as f2: + mturk_data = json.load(f2) + mturk_statements: str = mturk_data["canonical_statements"] + + yield Instance( + question="Describe this image.", + gold_answer=None, + metadata={ + "id": image_id, + "url": url, + "image_path": str(image_path), + "transcripts": transcripts, + "mturk_statements": mturk_statements, + }, + ) + count += 1 + + def format_request(self, instance: Instance) -> LMRequest: + return LMRequest( + request_type=RequestType.CHAT, + messages=({"role": "user", "content": instance.question},), + ) + + +# "pixmo_cap" is an alias for "dense_caption" with no overrides. +register_variant("dense_caption", "pixmo_cap") diff --git a/src/olmo_eval/evals/tasks/doc_qa.py b/src/olmo_eval/evals/tasks/doc_qa.py new file mode 100644 index 00000000..e5291b70 --- /dev/null +++ b/src/olmo_eval/evals/tasks/doc_qa.py @@ -0,0 +1,62 @@ +"""DocVQA (validation by default; ``doc_qa:test`` for the test split). + +Mirrors mm_olmo's ``DocQaConfig``: loads +``torch_datasets/docqa/val_v1.0_withQT.json`` (manual RRC Task 1 download), +prompts with the ``doc_qa`` style tag, scores ANLS (primary) + exact match. + +The ``doc_qa:test`` variant loads ``test_v1.0.json``, whose answers are not +public — like mm_olmo, instances carry placeholder ``[""]`` answers, so the +computed metrics are meaningless; run it to produce predictions for an RRC +evaluation-server submission (``questionId`` is in ``metadata["example_id"]``). +""" + +from __future__ import annotations + +import json +from collections.abc import Iterator + +from olmo_eval.common.scorers.image_qa import AnlsScorer, EmScorer +from olmo_eval.common.types import Instance, SamplingParams, Split +from olmo_eval.evals.tasks.common import register, register_variant +from olmo_eval.evals.tasks.common.image_qa_base import ( + ImageQATask, + MeanScorerMetric, + torch_datasets_dir, +) + +_ANLS_METRIC = MeanScorerMetric(name="ansl", scorer=AnlsScorer()) +_EM_METRIC = MeanScorerMetric(name="em", scorer=EmScorer()) + + +@register("doc_qa") +class DocQaTask(ImageQATask): + sampling_params = SamplingParams(temperature=0.0, max_tokens=12) + metrics = (_ANLS_METRIC, _EM_METRIC) + primary_metric = _ANLS_METRIC + split = Split.VALIDATION + + def _build_instances(self) -> Iterator[Instance]: + src_dir = torch_datasets_dir() / "docqa" + if self.config.split == Split.TEST: + src = src_dir / "test_v1.0.json" + else: + src = src_dir / "val_v1.0_withQT.json" + with open(src) as f: + data = json.load(f) + for ex in data["data"]: + # The test split has no public answers; mm_olmo injects [""]. + answers = ex.get("answers") or [""] + yield Instance( + question=f"doc_qa: {ex['question']}", + gold_answer=answers[0] if answers[0] else None, + metadata={ + "answers": answers, + "example_id": ex["questionId"], + "doc_id": ex["docId"], + "question_types": ex.get("question_types") or [""], + "image_path": str(src_dir / ex["image"]), + }, + ) + + +register_variant("doc_qa", "test", split=Split.TEST) diff --git a/src/olmo_eval/evals/tasks/info_qa.py b/src/olmo_eval/evals/tasks/info_qa.py new file mode 100644 index 00000000..2f624880 --- /dev/null +++ b/src/olmo_eval/evals/tasks/info_qa.py @@ -0,0 +1,61 @@ +"""InfographicVQA (validation by default; ``info_qa:test`` for the test split). + +Mirrors mm_olmo's ``InfoQaConfig``: loads +``torch_datasets/info_qa/infographicsVQA_val_v1.0_withQT.json`` (manual RRC +Task 3 download), prompts with the ``info_qa`` style tag, scores ANLS +(primary) + exact match. + +The ``info_qa:test`` variant loads ``infographicsVQA_test_v1.0.json``, whose +answers are not public — instances carry placeholder ``[""]`` answers, so the +computed metrics are meaningless; run it to produce predictions for an RRC +evaluation-server submission (``questionId`` is in ``metadata["example_id"]``). +""" + +from __future__ import annotations + +import json +from collections.abc import Iterator + +from olmo_eval.common.scorers.image_qa import AnlsScorer, EmScorer +from olmo_eval.common.types import Instance, SamplingParams, Split +from olmo_eval.evals.tasks.common import register, register_variant +from olmo_eval.evals.tasks.common.image_qa_base import ( + ImageQATask, + MeanScorerMetric, + torch_datasets_dir, +) + +_ANLS_METRIC = MeanScorerMetric(name="ansl", scorer=AnlsScorer()) +_EM_METRIC = MeanScorerMetric(name="em", scorer=EmScorer()) + + +@register("info_qa") +class InfoQaTask(ImageQATask): + sampling_params = SamplingParams(temperature=0.0, max_tokens=12) + metrics = (_ANLS_METRIC, _EM_METRIC) + primary_metric = _ANLS_METRIC + split = Split.VALIDATION + + def _build_instances(self) -> Iterator[Instance]: + src_dir = torch_datasets_dir() / "info_qa" + if self.config.split == Split.TEST: + src = src_dir / "infographicsVQA_test_v1.0.json" + else: + src = src_dir / "infographicsVQA_val_v1.0_withQT.json" + with open(src) as f: + data = json.load(f) + for ex in data["data"]: + # The test split has no public answers; placeholder [""] like mm_olmo. + answers = ex.get("answers") or [""] + yield Instance( + question=f"info_qa: {ex['question']}", + gold_answer=answers[0] if answers[0] else None, + metadata={ + "answers": answers, + "example_id": ex["questionId"], + "image_path": str(src_dir / "images" / ex["image_local_name"]), + }, + ) + + +register_variant("info_qa", "test", split=Split.TEST) diff --git a/src/olmo_eval/evals/tasks/math_vista.py b/src/olmo_eval/evals/tasks/math_vista.py new file mode 100644 index 00000000..dfef4dc3 --- /dev/null +++ b/src/olmo_eval/evals/tasks/math_vista.py @@ -0,0 +1,73 @@ +"""MathVista testmini (1,000 examples). + +Mirrors mm_olmo's ``MathVistaConfig(simplify_question=True)`` (task name +``math_vista_v2``): the requested validation split maps to HF ``testmini``; +the ``Question:``/``Hint:`` boilerplate is stripped from the query. +Multiple-choice questions are templated with lettered options (no style tag); +free-form questions are prompted with the ``vqa2`` style tag. + +Scoring: the default task scores **offline** (deterministic extraction — no +API key needed). The official protocol extracts answers with GPT-4 +(``gpt-4-0613``); use the ``math_vista:gpt`` variant for that (requires +``OPENAI_API_KEY``; responses cached per-run, see ``MathVistaGptScorer``). + +Reference (Molmo2-4B ck2000, GPT extraction): score=0.5670. +""" + +from __future__ import annotations + +from collections.abc import Iterator + +from olmo_eval.common.image_qa import format_mc_question +from olmo_eval.common.scorers.image_qa import MathVistaGptScorer, MathVistaOfflineScorer +from olmo_eval.common.types import Instance, SamplingParams, Split +from olmo_eval.evals.tasks.common import register, register_variant +from olmo_eval.evals.tasks.common.image_qa_base import ImageQATask, MeanScorerMetric, lazy_hf_image + +_OFFLINE_METRIC = MeanScorerMetric(name="score", scorer=MathVistaOfflineScorer()) +_GPT_METRIC = MeanScorerMetric(name="score", scorer=MathVistaGptScorer()) + + +@register("math_vista") +class MathVistaTask(ImageQATask): + sampling_params = SamplingParams(temperature=0.0, max_tokens=32) + metrics = (_OFFLINE_METRIC,) + primary_metric = _OFFLINE_METRIC + split = Split.VALIDATION # maps to the HF "testmini" split + + def _build_instances(self) -> Iterator[Instance]: + import datasets + + split = "testmini" if self.config.split == Split.VALIDATION else self.config.split.value + ds = datasets.load_dataset("AI4Math/MathVista", split=split) + ds_nodecode = ds.cast_column("decoded_image", datasets.Image(decode=False)) + + for idx in range(len(ds_nodecode)): + ex = ds_nodecode[idx] + # simplify_question=True: strip the "Question:"/"Hint:" wrappers + question = ex["question"].split("Question:")[-1] + question = question.split("Hint:")[0].strip() + + metadata = { + "example_id": ex["pid"], + "answer": ex["answer"], + "precision": ex["precision"], + "query": ex["question"], + "choices": ex["choices"], + "question_type": ex["question_type"], + "answer_type": ex["answer_type"], + "image": lazy_hf_image(ds_nodecode, idx, "decoded_image"), + } + if ex["question_type"] == "multi_choice": + question, option_names = format_mc_question(question, ex["choices"]) + metadata["option_names"] = option_names + else: + question = f"vqa2: {question}" + yield Instance( + question=question, + gold_answer=ex["answer"], + metadata=metadata, + ) + + +register_variant("math_vista", "gpt", metrics=(_GPT_METRIC,), primary_metric=_GPT_METRIC) diff --git a/src/olmo_eval/evals/tasks/mmmu.py b/src/olmo_eval/evals/tasks/mmmu.py new file mode 100644 index 00000000..7ca6c415 --- /dev/null +++ b/src/olmo_eval/evals/tasks/mmmu.py @@ -0,0 +1,111 @@ +"""MMMU validation (all 30 subjects, 900 examples). + +Mirrors mm_olmo's ``MMMUConfig()`` (task name ``mmmu_test`` — the validation +split is evaluated, as is standard practice). Multiple-choice questions are +templated with lettered options and no style tag; open questions are prompted +with the ``vqa2`` style tag. Following LLaVA (and mm_olmo), the image input +is dropped when more than one answer option embeds an image tag. + +Set ``HF_DATASETS_CACHE`` to a local cache for offline loading. + +Reference (Molmo2-4B ck2000): mmmu_score=0.5089. +""" + +from __future__ import annotations + +import ast +import re +from collections.abc import Iterator +from typing import ClassVar + +from olmo_eval.common.image_qa import format_mc_question +from olmo_eval.common.scorers.image_qa import MmmuScorer +from olmo_eval.common.types import Instance, SamplingParams, Split +from olmo_eval.evals.tasks.common import register +from olmo_eval.evals.tasks.common.image_qa_base import ImageQATask, MeanScorerMetric, lazy_hf_image + +_METRIC = MeanScorerMetric(name="mmmu_score", scorer=MmmuScorer()) + +_IMG_OPTION_RE = re.compile(r"") + + +@register("mmmu") +class MmmuTask(ImageQATask): + sampling_params = SamplingParams(temperature=0.0, max_tokens=12) + metrics = (_METRIC,) + primary_metric = _METRIC + split = Split.VALIDATION + + SUBJECTS: ClassVar[list[str]] = [ + "Accounting", + "Agriculture", + "Architecture_and_Engineering", + "Art", + "Art_Theory", + "Basic_Medical_Science", + "Biology", + "Chemistry", + "Clinical_Medicine", + "Computer_Science", + "Design", + "Diagnostics_and_Laboratory_Medicine", + "Economics", + "Electronics", + "Energy_and_Power", + "Finance", + "Geography", + "History", + "Literature", + "Manage", + "Marketing", + "Materials", + "Math", + "Mechanical_Engineering", + "Music", + "Pharmacy", + "Physics", + "Psychology", + "Public_Health", + "Sociology", + ] + + def _build_instances(self) -> Iterator[Instance]: + import datasets + + parts = [ + datasets.load_dataset("MMMU/MMMU", name=name, split=self.config.split.value) + for name in self.SUBJECTS + ] + ds = datasets.concatenate_datasets(parts) + image_cols = [f"image_{i}" for i in range(1, 8)] + ds_nodecode = ds + for col in image_cols: + ds_nodecode = ds_nodecode.cast_column(col, datasets.Image(decode=False)) + + for idx in range(len(ds_nodecode)): + ex = ds_nodecode[idx] + is_mc = ex["question_type"] == "multiple-choice" + metadata = { + "answer": ex["answer"], + "example_id": ex["id"], + "question_type": ex["question_type"], + } + if is_mc: + options = ast.literal_eval(ex["options"]) + metadata["options"] = options + # Following LLaVA, drop the image when multiple options embed + # image paths (the images are the answer options themselves). + n_img_options = sum(_IMG_OPTION_RE.match(opt) is not None for opt in options) + if n_img_options <= 1 and ex["image_1"] is not None: + metadata["image"] = lazy_hf_image(ds_nodecode, idx, "image_1") + question, option_names = format_mc_question(ex["question"], options) + metadata["option_names"] = option_names + else: + if ex["image_1"] is not None: + metadata["image"] = lazy_hf_image(ds_nodecode, idx, "image_1") + question = f"vqa2: {ex['question']}" + yield Instance( + question=question, + gold_answer=ex["answer"], + metadata=metadata, + ) diff --git a/src/olmo_eval/evals/tasks/pixmo_count.py b/src/olmo_eval/evals/tasks/pixmo_count.py new file mode 100644 index 00000000..27fca059 --- /dev/null +++ b/src/olmo_eval/evals/tasks/pixmo_count.py @@ -0,0 +1,65 @@ +"""PixMo Count (validation by default; ``pixmo_count:test`` for the test split). + +Mirrors mm_olmo's ``PixMoCountConfig(counting=True)`` (task name +``pixmo_count_counting``): loads the prepared arrow dataset at +``torch_datasets/pixmo_datasets/count`` and asks an RNG-templated counting +question per example (``point_count`` style — no style tag). + +The question template is selected per example by the seeded RNG of mm_olmo's +eval data pipeline, which depends on the example's **arrow-order index** — +instances are therefore built strictly in arrow order (verified to reproduce +all 540 released validation prompts exactly). + +Reference (Molmo2-4B ck2000, val): correct=0.9093. +""" + +from __future__ import annotations + +from collections.abc import Iterator + +from olmo_eval.common.image_qa import pixmo_count_question +from olmo_eval.common.scorers.image_qa import PointCountScorer +from olmo_eval.common.types import Instance, SamplingParams, Split +from olmo_eval.evals.tasks.common import register, register_variant +from olmo_eval.evals.tasks.common.image_qa_base import ( + ImageQATask, + point_count_metrics, + rebase_data_path, + torch_datasets_dir, +) + +_SCORER = PointCountScorer() +_METRICS = point_count_metrics(_SCORER) + + +@register("pixmo_count") +class PixmoCountTask(ImageQATask): + sampling_params = SamplingParams(temperature=0.0, max_tokens=192) + metrics = _METRICS + primary_metric = _METRICS[0] # correct + split = Split.VALIDATION + + def _build_instances(self) -> Iterator[Instance]: + import datasets + + ds = datasets.load_from_disk(str(torch_datasets_dir() / "pixmo_datasets" / "count")) + ds = ds[self.config.split.value] + # Arrow order is load-bearing: the per-example question template is + # picked by an RNG seeded with the arrow index. + for idx in range(len(ds)): + ex = ds[idx] + yield Instance( + question=pixmo_count_question(ex["label"], idx), + gold_answer=str(ex["count"]), + metadata={ + "count": ex["count"], + "label": ex["label"], + "arrow_idx": idx, + "example_id": ex["image_url"], + "image_url": ex["image_url"], + "image_path": rebase_data_path(ex["image"]), + }, + ) + + +register_variant("pixmo_count", "test", split=Split.TEST) diff --git a/src/olmo_eval/evals/tasks/real_world_qa.py b/src/olmo_eval/evals/tasks/real_world_qa.py new file mode 100644 index 00000000..77c0b8ca --- /dev/null +++ b/src/olmo_eval/evals/tasks/real_world_qa.py @@ -0,0 +1,65 @@ +"""RealWorldQA (xai-org/RealworldQA, test split — the only split). + +Mirrors mm_olmo's ``RealWorldQaConfig(mode="no_instruction")`` (task name +``real_world_qa_no_instruction``): each question embeds one of two x.ai +instruction suffixes which determine the question type. Short-answer +questions are truncated to their first line and prompted with the ``vqa2`` +style tag; multiple-choice questions keep the full original prompt (embedded +options + letter instruction) with no style tag. + +Set ``HF_DATASETS_CACHE`` to a local cache for offline loading. + +Reference (Molmo2-4B ck2000): real_world_qa_score=0.7542. +""" + +from __future__ import annotations + +from collections.abc import Iterator + +from olmo_eval.common.scorers.image_qa import RealWorldQaScorer +from olmo_eval.common.types import Instance, SamplingParams, Split +from olmo_eval.evals.tasks.common import register +from olmo_eval.evals.tasks.common.image_qa_base import ImageQATask, MeanScorerMetric, lazy_hf_image + +_METRIC = MeanScorerMetric(name="real_world_qa_score", scorer=RealWorldQaScorer()) + +_SHORT_ANSWER_INSTRUCTION = "Please answer directly with a single word or number." +_MC_INSTRUCTION = ( + "Please answer directly with only the letter of the correct option and nothing else." +) + + +@register("real_world_qa") +class RealWorldQaTask(ImageQATask): + sampling_params = SamplingParams(temperature=0.0, max_tokens=12) + metrics = (_METRIC,) + primary_metric = _METRIC + split = Split.TEST # the dataset's only split + + def _build_instances(self) -> Iterator[Instance]: + import datasets + + ds = datasets.load_dataset("xai-org/RealworldQA", split="test") + ds_nodecode = ds.cast_column("image", datasets.Image(decode=False)) + for idx in range(len(ds_nodecode)): + ex = ds_nodecode[idx] + prompt: str = ex["question"] + if _SHORT_ANSWER_INSTRUCTION in prompt: + question_type = "short_answer" + first_line = prompt.split("\n")[0] + question = f"vqa2: {first_line}" + else: + assert _MC_INSTRUCTION in prompt, prompt + question_type = "multiple_choice" + question = prompt + yield Instance( + question=question, + gold_answer=ex["answer"], + metadata={ + "answer": ex["answer"], + "question_type": question_type, + "example_id": idx, + "original_question": prompt, + "image": lazy_hf_image(ds_nodecode, idx, "image"), + }, + ) diff --git a/src/olmo_eval/evals/tasks/text_vqa.py b/src/olmo_eval/evals/tasks/text_vqa.py new file mode 100644 index 00000000..8eeee4b7 --- /dev/null +++ b/src/olmo_eval/evals/tasks/text_vqa.py @@ -0,0 +1,51 @@ +"""TextVQA validation. + +Mirrors mm_olmo's ``TextVqaConfig``: loads +``torch_datasets/text_vqa/TextVQA_0.5.1_val.json`` (val images live under +``train_images/``), prompts with the ``text_vqa`` style tag, scores the +official VQA accuracy. +""" + +from __future__ import annotations + +import json +from collections.abc import Iterator + +from olmo_eval.common.scorers.image_qa import VqaScoreScorer +from olmo_eval.common.types import Instance, SamplingParams, Split +from olmo_eval.evals.tasks.common import register +from olmo_eval.evals.tasks.common.image_qa_base import ( + ImageQATask, + MeanScorerMetric, + torch_datasets_dir, +) + +_METRIC = MeanScorerMetric(name="vqa_score", scorer=VqaScoreScorer()) + +_NUM_ANSWERS_PER_QUESTION = 10 + + +@register("text_vqa") +class TextVqaTask(ImageQATask): + sampling_params = SamplingParams(temperature=0.0, max_tokens=12) + metrics = (_METRIC,) + primary_metric = _METRIC + split = Split.VALIDATION + + def _build_instances(self) -> Iterator[Instance]: + src_dir = torch_datasets_dir() / "text_vqa" + with open(src_dir / "TextVQA_0.5.1_val.json") as f: + data = json.load(f)["data"] + for item in data: + answers = item.get("answers", [""] * _NUM_ANSWERS_PER_QUESTION) + image_subfolder = "train_images" if item["set_name"] != "test" else "test_images" + yield Instance( + question=f"text_vqa: {item['question']}", + gold_answer=answers[0] if answers else None, + metadata={ + "answers": answers, + "example_id": item["question_id"], + "image_id": item["image_id"], + "image_path": str(src_dir / image_subfolder / f"{item['image_id']}.jpg"), + }, + ) diff --git a/src/olmo_eval/evals/tasks/vqa2.py b/src/olmo_eval/evals/tasks/vqa2.py new file mode 100644 index 00000000..0db99f76 --- /dev/null +++ b/src/olmo_eval/evals/tasks/vqa2.py @@ -0,0 +1,100 @@ +"""VQA v2.0 validation (8,192-question subsample). + +Mirrors mm_olmo's ``Vqa2Config(multi_question=False, sample=8192)`` +(task name ``coco_2014_vqa_8192``): reads the prebuilt +``torch_datasets/vqa2/molmo_val.json`` manifest (never rebuilt here — the +loader is strictly read-only), flattens to one example per question, and +subsamples 8,192 questions with ``np.random.RandomState(9123)``. + +Reference (Molmo2-4B ck2000): vqa_score=0.8582. +""" + +from __future__ import annotations + +import json +from collections.abc import Iterator + +import numpy as np + +from olmo_eval.common.scorers.image_qa import VqaScoreScorer +from olmo_eval.common.types import Instance, SamplingParams, Split +from olmo_eval.evals.tasks.common import register +from olmo_eval.evals.tasks.common.image_qa_base import ( + ImageQATask, + MeanScorerMetric, + molmo_data_dir, + rebase_data_path, + torch_datasets_dir, +) + +_METRIC = MeanScorerMetric(name="vqa_score", scorer=VqaScoreScorer()) + + +def _resolve_coco_image(path: str) -> str: + """Resolve a COCO image path from the manifest. + + Prefers the recorded ``vqa2//`` location; falls back to the shared + ``images/coco//`` tree (the ``vqa2`` image dirs are symlinks that + may be broken on some mounts). + """ + import os + + path = rebase_data_path(path) + if os.path.exists(path): + return path + name = os.path.basename(path) # e.g. COCO_val2014_000000123456.jpg + coco_split = name.split("_")[1] + fallback = molmo_data_dir() / "images" / "coco" / coco_split / name + return str(fallback) + + +_SAMPLE_SIZE = 8192 +_SAMPLE_SEED = 9123 + + +@register("vqa2") +class Vqa2Task(ImageQATask): + sampling_params = SamplingParams(temperature=0.0, max_tokens=12) + metrics = (_METRIC,) + primary_metric = _METRIC + split = Split.VALIDATION + + def _build_instances(self) -> Iterator[Instance]: + src = torch_datasets_dir() / "vqa2" / "molmo_val.json" + if not src.exists(): + raise FileNotFoundError( + f"{src} not found. This loader only reads the manifest cached by the " + "original mm_olmo pipeline; it never (re)builds it." + ) + with open(src) as f: + data = json.load(f) + + flattened = [] + for item in data: + for q in item["messages"]: + flattened.append( + { + "question": q["question"], + "answers": q["answers"], + "image": item["image"], + "image_id": item["image_id"], + "question_id": q["question_id"], + } + ) + # shuffle the list of dicts in place; the seeded RandomState ordering must + # match mm_olmo exactly for dump parity, so keep np shuffle (not random.shuffle). + np.random.RandomState(_SAMPLE_SEED).shuffle(flattened) # ty: ignore[invalid-argument-type] + flattened = flattened[:_SAMPLE_SIZE] + + for ex in flattened: + answers = ex["answers"] + yield Instance( + question=f"vqa2: {ex['question']}", + gold_answer=answers[0] if answers else None, + metadata={ + "answers": answers, + "image_id": ex["image_id"], + "example_id": ex["question_id"], + "image_path": _resolve_coco_image(ex["image"]), + }, + ) diff --git a/tests/core/test_dense_caption_judge.py b/tests/core/test_dense_caption_judge.py new file mode 100644 index 00000000..8d22f372 --- /dev/null +++ b/tests/core/test_dense_caption_judge.py @@ -0,0 +1,298 @@ +"""Tests for DenseCaptionJudgeScorer and dense_caption task. + +Two levels: +1. Unit tests for parse helpers — no I/O, no GPU, no API key needed. +2. Metric aggregation tests — synthetic scored outputs, verify ×100 scaling + and valid-filter semantics. +3. Scorer integration tests with mocked GPT calls — verify end-to-end + metadata plumbing and error handling. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import patch + +import pytest + +from olmo_eval.common.execution import ScoringContext +from olmo_eval.common.scorers.dense_caption_judge import ( + DenseCaptionJudgeScorer, + parse_consistency_output, + parse_recall_output, +) +from olmo_eval.common.types import Instance, LMOutput, Response +from olmo_eval.evals.tasks.dense_caption import ( + DenseCaptionAvgMetric, + DenseCaptionConsistencyMetric, + DenseCaptionNumStatementsMetric, + DenseCaptionRecallAt10Metric, + DenseCaptionRecallMetric, +) + +# --------------------------------------------------------------------------- +# 1. Unit tests — parse_recall_output +# --------------------------------------------------------------------------- + + +class TestParseRecallOutput: + def test_basic_stated_not_stated(self): + text = "1. Stated\n2. Not Stated\n3. Stated" + covered, total = parse_recall_output(text) + assert total == 3 + assert covered == 2 + + def test_gpt_misspelling_not_stted(self): + # Legacy regex allows any suffix after "not st" + text = "1. Not stated\n2. Not stted\n3. Stated" + covered, total = parse_recall_output(text) + assert total == 3 + assert covered == 1 + + def test_numbered_with_prefix(self): + # "Not Stated (missing)" does NOT match the end-anchored regex but + # does contain " stated", so the legacy logic counts it as Stated. + text = "1. The fact is Stated.\n2. Not Stated (missing)" + covered, total = parse_recall_output(text) + assert total == 2 + assert covered == 2 + + def test_clean_not_stated_at_end(self): + text = "1. Stated\n2. Not Stated" + covered, total = parse_recall_output(text) + assert total == 2 + assert covered == 1 + + def test_ambiguous_lines_skipped(self): + text = "1. Stated\n2. Unclear\n3. Not Stated" + covered, total = parse_recall_output(text) + assert total == 2 # "Unclear" is skipped + assert covered == 1 + + def test_empty_input(self): + covered, total = parse_recall_output("") + assert covered == 0 + assert total == 0 + + def test_case_insensitive(self): + text = "1. STATED\n2. NOT STATED" + covered, total = parse_recall_output(text) + assert total == 2 + assert covered == 1 + + +# --------------------------------------------------------------------------- +# 2. Unit tests — parse_consistency_output +# --------------------------------------------------------------------------- + + +class TestParseConsistencyOutput: + def test_basic_consistent_inconsistent(self): + text = "1. Consistent\n2. Inconsistent\n3. Consistent" + consistent, total = parse_consistency_output(text) + assert total == 3 + assert consistent == 2 + + def test_fuzzy_misspellings(self): + text = "1. inconsisent\n2. constistent\n3. Inconsistent" + consistent, total = parse_consistency_output(text) + assert total == 3 + assert consistent == 1 # only constistent is consistent + + def test_unknown_labels_skipped(self): + text = "1. Consistent\n2. Not specified\n3. Inconsistent\n4. ambiguous" + consistent, total = parse_consistency_output(text) + assert total == 2 # unknown labels skipped + assert consistent == 1 + + def test_empty_input(self): + consistent, total = parse_consistency_output("") + assert consistent == 0 + assert total == 0 + + def test_case_insensitive(self): + text = "1. CONSISTENT\n2. INCONSISTENT" + consistent, total = parse_consistency_output(text) + assert total == 2 + assert consistent == 1 + + +# --------------------------------------------------------------------------- +# 3. Metric aggregation tests +# --------------------------------------------------------------------------- + + +def _make_response(result: dict) -> Response: + """Build a synthetic Response with a pre-filled dense_caption_result.""" + instance = Instance(question="Describe this image.", gold_answer=None, metadata={}) + output = LMOutput(text="A caption.") + output.metadata = {"dense_caption_result": result} + return Response(instance=instance, request=None, outputs=[output], scores={}) # type: ignore[arg-type] + + +class TestDenseCaptionMetrics: + def _valid_result( + self, + recall: float = 0.5, + consistency: float = 0.8, + num_covered: int = 5, + num_statements: int = 10, + num_consistent: int = 8, + consistency_valid: bool = True, + ) -> dict: + return dict( + recall=recall, + recall_at_10=min(num_covered, 10) / min(num_statements, 10), + num_statements=num_statements, + num_covered=num_covered, + recall_valid=True, + consistency=consistency, + num_consistent=num_consistent, + consistency_valid=consistency_valid, + ) + + def test_recall_metric_valid(self): + responses = [ + _make_response(self._valid_result(recall=0.4)), + _make_response(self._valid_result(recall=0.6)), + ] + score = DenseCaptionRecallMetric().compute(responses) + assert abs(score - 50.0) < 1e-6 + + def test_recall_metric_filters_invalid(self): + invalid = dict( + recall=0.0, + recall_valid=False, + consistency=0.0, + consistency_valid=False, + num_statements=0, + num_covered=0, + recall_at_10=0.0, + num_consistent=0, + ) + responses = [_make_response(self._valid_result(recall=0.6)), _make_response(invalid)] + score = DenseCaptionRecallMetric().compute(responses) + assert abs(score - 60.0) < 1e-6 + + def test_consistency_metric(self): + responses = [ + _make_response(self._valid_result(consistency=0.7)), + _make_response(self._valid_result(consistency=0.9)), + ] + score = DenseCaptionConsistencyMetric().compute(responses) + assert abs(score - 80.0) < 1e-6 + + def test_recall_at_10(self): + result = self._valid_result(num_covered=8, num_statements=20) + # recall_at_10 = min(8,10)/min(20,10) = 8/10 = 0.8 + responses = [_make_response(result)] + score = DenseCaptionRecallAt10Metric().compute(responses) + assert abs(score - 80.0) < 1e-6 + + def test_num_statements_not_scaled(self): + responses = [ + _make_response(self._valid_result(num_statements=20)), + _make_response(self._valid_result(num_statements=10)), + ] + score = DenseCaptionNumStatementsMetric().compute(responses) + assert abs(score - 15.0) < 1e-6 + + def test_avg_metric(self): + responses = [ + _make_response(self._valid_result(recall=0.4, consistency=0.6)), + _make_response(self._valid_result(recall=0.6, consistency=0.8)), + ] + # mean_recall=0.5, mean_cons=0.7, avg=(0.5+0.7)/2*100=60 + score = DenseCaptionAvgMetric().compute(responses) + assert abs(score - 60.0) < 1e-6 + + def test_empty_responses(self): + for metric in [ + DenseCaptionRecallMetric(), + DenseCaptionConsistencyMetric(), + DenseCaptionRecallAt10Metric(), + DenseCaptionNumStatementsMetric(), + DenseCaptionAvgMetric(), + ]: + assert metric.compute([]) == 0.0 + + +# --------------------------------------------------------------------------- +# 4. Scorer unit test with mocked GPT calls +# --------------------------------------------------------------------------- + + +class TestDenseCaptionJudgeScorerMocked: + @pytest.mark.anyio + async def test_scorer_sets_metadata_and_returns_recall(self, tmp_path): + scorer = DenseCaptionJudgeScorer( + cache_dir=str(tmp_path), + cache_only=False, + ) + instance = Instance( + question="Describe this image.", + gold_answer=None, + metadata={ + "url": "http://example.com/img.jpg", + "mturk_statements": "1. The sky is blue.\n2. There is a tree.", + "transcripts": [{"whisperTranscript": "A sunny day with a tree."}], + }, + ) + output = LMOutput(text="The sky is blue and there is a tall tree.") + output.metadata = {} + + # GPT recall response: 2 Stated + # GPT canonical response: "1. Sky is blue.\n2. Tree is present." + # GPT consistency response: "1. Consistent\n2. Consistent" + call_responses = [ + "1. Stated\n2. Stated", # recall check + "1. Sky is blue.\n2. Tree is present.", # canonical statements + "1. Consistent\n2. Consistent", # consistency check + ] + call_iter = iter(call_responses) + + async def fake_gpt(*args: Any, **kwargs: Any) -> str: + return next(call_iter) + + with patch( + "olmo_eval.common.scorers.dense_caption_judge._cached_gpt_call", + side_effect=fake_gpt, + ): + score = await scorer.ascore_with_context(instance, output, ScoringContext()) + + assert output.metadata is not None + result = output.metadata["dense_caption_result"] + assert result["recall_valid"] is True + assert result["num_statements"] == 2 + assert result["num_covered"] == 2 + assert abs(result["recall"] - 1.0) < 1e-6 + assert abs(result["consistency"] - 1.0) < 1e-6 + assert abs(score - 1.0) < 1e-6 # primary return is recall + + @pytest.mark.anyio + async def test_scorer_handles_gpt_error_gracefully(self, tmp_path): + scorer = DenseCaptionJudgeScorer(cache_dir=str(tmp_path), cache_only=False) + instance = Instance( + question="Describe this image.", + gold_answer=None, + metadata={ + "url": "http://example.com/img.jpg", + "mturk_statements": "1. Something.", + "transcripts": [{"whisperTranscript": "Something."}], + }, + ) + output = LMOutput(text="A caption.") + output.metadata = {} + + async def failing_gpt(*args: Any, **kwargs: Any) -> str: + raise RuntimeError("API error") + + with patch( + "olmo_eval.common.scorers.dense_caption_judge._cached_gpt_call", + side_effect=failing_gpt, + ): + score = await scorer.ascore_with_context(instance, output, ScoringContext()) + + assert score == 0.0 + result = output.metadata["dense_caption_result"] + assert result["recall_valid"] is False diff --git a/tests/core/test_image_qa_scorers.py b/tests/core/test_image_qa_scorers.py new file mode 100644 index 00000000..a1d4dc47 --- /dev/null +++ b/tests/core/test_image_qa_scorers.py @@ -0,0 +1,572 @@ +"""Tests for the vendored image-QA scoring/prompt utilities. + +Pure unit tests — no I/O, no GPU, no API key. Expected values mirror the +mm_olmo reference implementation (``olmo/eval/vqa.py``, +``molmo_prediction_evaluators.PointCountEval``, ``mmmu_eval_utils.py``, +``math_vista_utils.py``, ``data_formatter.py``). +""" + +from __future__ import annotations + +import pytest + +from olmo_eval.common.image_qa import ( + POINT_COUNT_TEMPLATES, + anls_metric, + clean_prediction, + extract_image_points, + format_mc_question, + levenshtein, + math_vista_score_offline, + mmmu_score, + parse_count, + parse_multi_choice_response, + parse_open_response, + pixmo_count_question, + preprocess_answer, + real_world_qa_score, + relaxed_correctness, + scifi_relaxed_correctness, + select_mc_option, + vqa_score, +) + +# --------------------------------------------------------------------------- +# levenshtein +# --------------------------------------------------------------------------- + + +class TestLevenshtein: + @pytest.mark.parametrize( + ("a", "b", "expected"), + [ + ("", "", 0), + ("abc", "abc", 0), + ("abc", "", 3), + ("", "abc", 3), + ("kitten", "sitting", 3), + ("flaw", "lawn", 2), + ("gumbo", "gambol", 2), + ("a", "b", 1), + ], + ) + def test_reference_distances(self, a: str, b: str, expected: int) -> None: + assert levenshtein(a, b) == expected + assert levenshtein(b, a) == expected + + +# --------------------------------------------------------------------------- +# VQA v2 normalization + vqa_score +# --------------------------------------------------------------------------- + + +class TestVqaScore: + def test_three_matches_is_full_credit(self) -> None: + answers = ["red", "red", "red", "blue"] + ["green"] * 6 + assert vqa_score(answers, "red") == 1.0 + + def test_one_match_is_third_credit(self) -> None: + answers = ["red"] + ["blue"] * 9 + assert vqa_score(answers, "red") == pytest.approx(1 / 3) + + def test_number_word_normalization(self) -> None: + # "two" and "2" normalize to the same answer + assert vqa_score(["2"] * 10, "two") == 1.0 + + def test_article_removal(self) -> None: + assert vqa_score(["dog"] * 10, "a dog") == 1.0 + + def test_contraction_normalization(self) -> None: + assert preprocess_answer("dont") == "don't" + + def test_punctuation_stripped(self) -> None: + assert vqa_score(["yes"] * 10, "yes.") == 1.0 + + def test_no_match(self) -> None: + assert vqa_score(["red"] * 10, "blue") == 0.0 + + +# --------------------------------------------------------------------------- +# clean_prediction (VqaEval cleanup) +# --------------------------------------------------------------------------- + + +class TestCleanPrediction: + def test_answer_prefix_split(self) -> None: + assert clean_prediction("Reasoning blah. Answer: 42") == "42" + + def test_multiline_majority_vote(self) -> None: + assert clean_prediction("cat\ndog\ncat") == "cat" + + def test_multiline_tie_takes_first(self) -> None: + assert clean_prediction("dog\ncat") == "dog" + + def test_whitespace_collapse(self) -> None: + assert clean_prediction(" a b ") == "a b" + + +# --------------------------------------------------------------------------- +# ANLS +# --------------------------------------------------------------------------- + + +class TestAnls: + def test_exact(self) -> None: + assert anls_metric("hello", "hello") == 1.0 + + def test_case_insensitive(self) -> None: + assert anls_metric("Hello", "hello") == 1.0 + + def test_below_threshold_scores_zero(self) -> None: + # distance 3 over max-len 5 = 0.6 >= 0.5 -> 0 + assert anls_metric("abcde", "abxyz") == 0 + + def test_above_threshold_partial(self) -> None: + # distance 1 over max-len 5 = 0.2 -> 0.8 + assert anls_metric("abcde", "abcdx") == pytest.approx(0.8) + + +# --------------------------------------------------------------------------- +# ChartQA relaxed correctness +# --------------------------------------------------------------------------- + + +class TestRelaxedCorrectness: + def test_numeric_within_5pct(self) -> None: + assert relaxed_correctness("100", "104") + assert not relaxed_correctness("100", "106") + + def test_percent_parsing(self) -> None: + assert relaxed_correctness("0.07", "7%") + + def test_non_numeric_exact(self) -> None: + assert relaxed_correctness("Yes", "yes") + assert not relaxed_correctness("Yes", "yes!") + + def test_zero_target_falls_back_to_exact(self) -> None: + # target "0" is falsy as float -> exact-match branch + assert relaxed_correctness("0", "0") + assert not relaxed_correctness("0", "0.001") + + +class TestScifiRelaxedCorrectness: + def test_answer_prefix(self) -> None: + assert scifi_relaxed_correctness("42", "the answer: 42") + + def test_word_to_number(self) -> None: + assert scifi_relaxed_correctness("3", "three") + + def test_comma_removal(self) -> None: + assert scifi_relaxed_correctness("1000", "1,000") + + def test_div_100_normalization(self) -> None: + assert scifi_relaxed_correctness("0.5", "50") + + def test_list_target(self) -> None: + assert scifi_relaxed_correctness("[2007, 2008]", "between 2007 and 2008") + assert not scifi_relaxed_correctness("[2007, 2008]", "between 2007 and 2009") + + def test_string_containment(self) -> None: + assert scifi_relaxed_correctness("cat", "it is a cat indeed") + + def test_empty_prediction(self) -> None: + assert not scifi_relaxed_correctness("1", "") + + +# --------------------------------------------------------------------------- +# select_mc_option +# --------------------------------------------------------------------------- + + +class TestSelectMcOption: + OPTIONS = ["A", "B", "C", "D"] + + def test_exact(self) -> None: + assert select_mc_option("b", self.OPTIONS) == 1 + + def test_target_starts_with_option(self) -> None: + assert select_mc_option("C. some text", self.OPTIONS) == 2 + + def test_option_starts_with_target(self) -> None: + options = ["apple pie", "banana split", "cherry cake"] + assert select_mc_option("banana", options) == 1 + + def test_containment(self) -> None: + options = ["the red car", "the blue boat", "the green tree"] + assert select_mc_option("blue", options) == 1 + + def test_edit_distance_fallback(self) -> None: + options = ["alpha", "beta", "gamma"] + assert select_mc_option("btea", options) == 1 + + def test_full_option_text(self) -> None: + options = ["moon", "none of the above", "earth", "sun"] + assert select_mc_option("B. none of the above", ["A", "B", "C", "D"]) == 1 + assert select_mc_option("none of the above", options) == 1 + + +# --------------------------------------------------------------------------- +# RealWorldQA +# --------------------------------------------------------------------------- + + +class TestRealWorldQa: + def test_mc_letter(self) -> None: + assert real_world_qa_score("B", "B", "multiple_choice") == 1.0 + assert real_world_qa_score("B", "C", "multiple_choice") == 0.0 + + def test_short_answer_normalized(self) -> None: + assert real_world_qa_score("two", "2", "short_answer") == 1.0 + + +# --------------------------------------------------------------------------- +# MMMU parsing + scoring +# --------------------------------------------------------------------------- + + +class TestMmmuParsing: + CHOICES = ["A", "B", "C", "D"] + INDEX2ANS = {"A": "moon", "B": "sun", "C": "earth", "D": "mars"} + + def test_paren_format(self) -> None: + assert parse_multi_choice_response("The answer is (B)", self.CHOICES, self.INDEX2ANS) == "B" + + def test_dot_format(self) -> None: + assert parse_multi_choice_response("B. sun", self.CHOICES, self.INDEX2ANS) == "B" + + def test_bare_letter(self) -> None: + assert parse_multi_choice_response("B", self.CHOICES, self.INDEX2ANS) == "B" + + def test_content_match_long_response(self) -> None: + resp = "I believe from the diagram that it must be the sun shining" + assert parse_multi_choice_response(resp, self.CHOICES, self.INDEX2ANS) == "B" + + def test_multiple_candidates_takes_last(self) -> None: + resp = "(A) is wrong, the answer is (C)" + assert parse_multi_choice_response(resp, self.CHOICES, self.INDEX2ANS) == "C" + + def test_unparseable_is_deterministic(self) -> None: + first = parse_multi_choice_response("?!", self.CHOICES, self.INDEX2ANS, stable_id="x1") + for _ in range(3): + assert ( + parse_multi_choice_response("?!", self.CHOICES, self.INDEX2ANS, stable_id="x1") + == first + ) + + def test_open_number_extraction(self) -> None: + preds = parse_open_response("So the result is 14.") + assert 14.0 in preds + + def test_open_comma_number(self) -> None: + preds = parse_open_response("The total is 1,234") + assert 1234.0 in preds + + def test_mmmu_score_mc(self) -> None: + score = mmmu_score( + ["B"], + "The answer is (B)", + question_type="multiple-choice", + options=["moon", "sun", "earth", "mars"], + ) + assert score == 1.0 + + def test_mmmu_score_open(self) -> None: + assert mmmu_score(["14"], "The answer is 14", question_type="open", options=[]) == 1.0 + assert mmmu_score(["14"], "The answer is 15", question_type="open", options=[]) == 0.0 + + +# --------------------------------------------------------------------------- +# MathVista offline scoring +# --------------------------------------------------------------------------- + + +class TestMathVistaOffline: + def test_mc(self) -> None: + assert math_vista_score_offline( + "B", + question_type="multi_choice", + answer_type="text", + choices=["3/11", "8/11", "6/11", "3/5"], + precision=None, + target="8/11", + ) + + def test_mc_full_text(self) -> None: + assert math_vista_score_offline( + "8/11", + question_type="multi_choice", + answer_type="text", + choices=["3/11", "8/11", "6/11", "3/5"], + precision=None, + target="8/11", + ) + + def test_integer(self) -> None: + assert math_vista_score_offline( + "14", + question_type="free_form", + answer_type="integer", + choices=[], + precision=None, + target="14", + ) + + def test_float_precision(self) -> None: + assert math_vista_score_offline( + "0.59999", + question_type="free_form", + answer_type="float", + choices=[], + precision=1, + target="0.6", + ) + + def test_wrong_integer(self) -> None: + assert not math_vista_score_offline( + "13", + question_type="free_form", + answer_type="integer", + choices=[], + precision=None, + target="14", + ) + + +# --------------------------------------------------------------------------- +# Count parsing (PointCountEval ladder) +# --------------------------------------------------------------------------- + + +class TestParseCount: + def test_last_token_int(self) -> None: + assert parse_count("There are 7") == 7 + + def test_trailing_period(self) -> None: + assert parse_count("Counting shows 12.") == 12 + + def test_number_word(self) -> None: + assert parse_count("there are three") == 3 + + def test_a_total_of(self) -> None: + pred = ( + 'Counting the people' + " shows a total of 8." + ) + assert parse_count(pred) == 8 + + def test_none_means_zero(self) -> None: + assert parse_count("There are none.") == 0 + + def test_points_fallback(self) -> None: + pred = 'cats' + assert parse_count(pred) == 3 + + def test_no_points_no_number(self) -> None: + assert parse_count("I cannot tell") == 0 + + def test_extract_image_points_unified(self) -> None: + text = 'x' + assert len(extract_image_points(text, 100, 100)) == 2 + + def test_extract_image_points_out_of_bounds_filtered(self) -> None: + # 4-digit coords > 1000 scale past the 100x100 bounds and are dropped + text = 'x' + assert len(extract_image_points(text, 100, 100)) == 0 + + +# --------------------------------------------------------------------------- +# Prompt templates +# --------------------------------------------------------------------------- + + +class TestPromptTemplates: + def test_template_count(self) -> None: + assert len(POINT_COUNT_TEMPLATES) == 60 + + # Pinned (arrow_idx, label) -> question triples captured from the released + # Molmo2-4B predictions-ck2000-pixmo_count_counting-validation dump; the + # full 540-prompt parity is asserted in the dump-parity test suite. + @pytest.mark.parametrize( + ("arrow_idx", "label", "expected"), + [ + ( + 0, + "cows", + "How many cows are there in the image? Point to them and output the total count.", + ), + (1, "people", "How many people do you see?"), + (17, "people", "how many people."), + ( + 539, + "people", + "Can you see any people in the image? Point to them and output the total count.", + ), + ], + ) + def test_pixmo_count_question_pinned(self, arrow_idx: int, label: str, expected: str) -> None: + assert pixmo_count_question(label, arrow_idx) == expected + + def test_pixmo_count_label_lowercased(self) -> None: + assert pixmo_count_question("People", 1) == pixmo_count_question("people", 1) + + def test_format_mc_question_labelled(self) -> None: + text, option_names = format_mc_question("What is X?", ["moon", "sun"]) + assert text == "What is X?\nOnly return the correct answer option.\nA. moon\nB. sun" + assert option_names == "AB" + + def test_format_mc_question_unlabelled(self) -> None: + text, option_names = format_mc_question("What is X?", ["P", "Q"], labelled=False) + assert text == "What is X?\nOnly return the correct answer option.\nP\nQ" + assert option_names == ["P", "Q"] + + +# --------------------------------------------------------------------------- +# Metric aggregation (synthetic responses) +# --------------------------------------------------------------------------- + +from olmo_eval.common.scorers.image_qa import ( # noqa: E402 + Ai2dScorer, + PointCountScorer, + RelaxedCorrectnessScorer, +) +from olmo_eval.common.types import Instance, LMOutput, Response # noqa: E402 +from olmo_eval.evals.tasks.common.image_qa_base import ( # noqa: E402 + Ai2dMetric, + ChartQaSubsetMetric, + MeanScorerMetric, + PointCountCategoryAverageMetric, + PointCountMetric, + PointCountPerCountMetric, +) + + +def _response(metadata: dict, score_name: str, score: float, output_metadata: dict | None = None): + instance = Instance(question="q", metadata=metadata) + output = LMOutput(text="", metadata=output_metadata or {}) + response = Response(instance=instance, request=None, outputs=[output]) + response.scores[score_name] = score + return response + + +class TestMeanScorerMetric: + def test_mean(self) -> None: + scorer = RelaxedCorrectnessScorer() + metric = MeanScorerMetric(name="relaxed_correctness", scorer=scorer) + responses = [ + _response({}, scorer.name, 1.0), + _response({}, scorer.name, 0.0), + ] + assert metric.compute(responses) == pytest.approx(0.5) + + def test_empty(self) -> None: + metric = MeanScorerMetric(name="x", scorer=RelaxedCorrectnessScorer()) + assert metric.compute([]) == 0.0 + + +class TestChartQaSubsetMetric: + def test_subset_split(self) -> None: + scorer = RelaxedCorrectnessScorer() + responses = [ + _response({"is_human": True}, scorer.name, 1.0), + _response({"is_human": True}, scorer.name, 0.0), + _response({"is_human": False}, scorer.name, 1.0), + _response({"is_human": False}, scorer.name, 1.0), + ] + m_all = ChartQaSubsetMetric(name="relaxed_correctness", scorer=scorer, subset="all") + m_human = ChartQaSubsetMetric( + name="relaxed_correctness_human", scorer=scorer, subset="human" + ) + m_aug = ChartQaSubsetMetric(name="relaxed_correctness_aug", scorer=scorer, subset="aug") + assert m_all.compute(responses) == pytest.approx(0.75) + assert m_human.compute(responses) == pytest.approx(0.5) + assert m_aug.compute(responses) == pytest.approx(1.0) + + +class TestPointCountMetrics: + def _responses(self): + scorer = PointCountScorer() + rows = [ + # (gt count, correct, close) + (2, 1.0, 1.0), + (2, 0.0, 1.0), + (3, 1.0, 1.0), + (5, 0.0, 0.0), + ] + responses = [] + for count, correct, close in rows: + responses.append( + _response( + {"count": count}, + scorer.name, + correct, + output_metadata={ + "point_count_result": { + "correct": correct, + "close": close, + "valid": 1.0, + "pred_count": 0, + } + }, + ) + ) + return scorer, responses + + def test_correct_close_valid(self) -> None: + scorer, responses = self._responses() + assert PointCountMetric(name="correct", scorer=scorer, kind="correct").compute( + responses + ) == pytest.approx(0.5) + assert PointCountMetric(name="close", scorer=scorer, kind="close").compute( + responses + ) == pytest.approx(0.75) + assert PointCountMetric(name="valid", scorer=scorer, kind="valid").compute( + responses + ) == pytest.approx(1.0) + + def test_per_count(self) -> None: + scorer, responses = self._responses() + assert PointCountPerCountMetric(name="correct_2", scorer=scorer, k=2).compute( + responses + ) == pytest.approx(0.5) + assert PointCountPerCountMetric(name="correct_3", scorer=scorer, k=3).compute( + responses + ) == pytest.approx(1.0) + # absent count -> 0.0 + assert ( + PointCountPerCountMetric(name="correct_9", scorer=scorer, k=9).compute(responses) == 0.0 + ) + + def test_per_category_average(self) -> None: + scorer, responses = self._responses() + # means: k=2 -> 0.5, k=3 -> 1.0, k=5 -> 0.0; macro avg = 0.5 + assert PointCountCategoryAverageMetric(name="per_category_average", scorer=scorer).compute( + responses + ) == pytest.approx(0.5) + + +class TestAi2dMetric: + def _response(self, is_correct: float, abc_label: bool, transparent_box: bool): + return _response( + {}, + "mc_ai2d", + is_correct, + output_metadata={ + "ai2d_result": { + "is_correct": is_correct, + "abc_label": abc_label, + "has_transparent_box": transparent_box, + } + }, + ) + + def test_routing(self) -> None: + scorer = Ai2dScorer() + responses = [ + self._response(1.0, abc_label=False, transparent_box=False), # both + self._response(0.0, abc_label=True, transparent_box=False), # opaque only + self._response(1.0, abc_label=True, transparent_box=True), # transparent only + ] + opaque = Ai2dMetric(name="mc_ai2d_opaque", scorer=scorer, transparent=False) + transparent = Ai2dMetric(name="mc_ai2d_transparent", scorer=scorer, transparent=True) + assert opaque.compute(responses) == pytest.approx(0.5) # (1.0 + 0.0) / 2 + assert transparent.compute(responses) == pytest.approx(1.0) # (1.0 + 1.0) / 2 diff --git a/tests/evals/tasks/test_image_qa_dump_parity.py b/tests/evals/tasks/test_image_qa_dump_parity.py new file mode 100644 index 00000000..4af44a7b --- /dev/null +++ b/tests/evals/tasks/test_image_qa_dump_parity.py @@ -0,0 +1,330 @@ +"""Parity tests against the released mm_olmo Molmo2-4B prediction dumps. + +For every image-QA benchmark this re-scores the predictions saved by the +*original* mm_olmo evaluation (``predictions-ck2000-*``) with the new +task/scorer/metric stack and asserts: + +1. **Prompt parity** — the user-turn text of each saved prompt equals the + ``instance.question`` produced by the new task (style prefixes, MC + formatting, and the PixMo-Count RNG templates must all match exactly). +2. **Metric parity** — the recomputed metrics equal the reference + ``metrics.json`` values within a small tolerance. + +The dumps are reference ground truth and are opened **read-only**; nothing +in this test writes to them. + +Opt-in: + + RUN_DUMP_PARITY_TESTS=1 \ + HF_DATASETS_CACHE=/weka/oe-training-default/mm-olmo/hf_datasets \ + HF_DATASETS_OFFLINE=1 \ + pytest tests/evals/tasks/test_image_qa_dump_parity.py -v + +``MOLMO2_PREDICTIONS_ROOT`` overrides the dump location (default: the +released Molmo2-4B directory). +""" + +from __future__ import annotations + +import asyncio +import json +import os +from pathlib import Path + +import pytest + +from olmo_eval.common.types import Instance, LMOutput, Response + +if not os.environ.get("RUN_DUMP_PARITY_TESTS"): + pytest.skip( + "Set RUN_DUMP_PARITY_TESTS=1 (and HF_DATASETS_CACHE for the HF-hub tasks) " + "to run dump-parity tests", + allow_module_level=True, + ) + +from olmo_eval.evals.tasks.common.registry import get_task # noqa: E402 + +DEFAULT_PREDICTIONS_ROOT = "/weka/oe-training-default/mm-olmo/released-models-molmo2-1225/Molmo2-4B" + +# Per-task plumbing: (task spec, dump dir name, join key fn, metric tolerance) +TOLERANCE_DEFAULT = 2e-4 +TOLERANCE_MMMU = 2e-3 + + +def _root() -> Path: + return Path(os.environ.get("MOLMO2_PREDICTIONS_ROOT", DEFAULT_PREDICTIONS_ROOT)) + + +def _load_dump(dump_name: str) -> tuple[list[dict], dict[str, float]]: + dump_dir = _root() / f"predictions-ck2000-{dump_name}" + if not dump_dir.exists(): + pytest.skip(f"reference dump not found: {dump_dir}") + with open(dump_dir / "predictions.json") as f: + rows = json.load(f) + with open(dump_dir / "metrics.json") as f: + metrics = json.load(f)["metrics"] + return rows, {k: v for k, v in metrics.items() if isinstance(v, (int, float))} + + +def _user_text(prompt: str) -> str: + """Extract the user-turn text from a decoded native prompt.""" + text = prompt.split("<|im_start|>user\n", 1)[1] + return text.split("<|im_end|>", 1)[0] + + +def _score_against_dump(task, joined: list[tuple[Instance, str]]) -> dict[str, float]: + responses = [ + Response( + instance=instance, + request=task.format_request(instance), + outputs=[LMOutput(text=prediction)], + ) + for instance, prediction in joined + ] + responses = asyncio.run(task.score_responses(responses)) + nested = task.compute_metrics(responses) + return {name: next(iter(by_scorer.values())) for name, by_scorer in nested.items()} + + +def _assert_metrics(mine: dict[str, float], ref: dict[str, float], tol: float) -> None: + compared = 0 + for name, value in mine.items(): + if name not in ref: + continue + assert value == pytest.approx(ref[name], abs=tol), ( + f"{name}: recomputed {value:.6f} != reference {ref[name]:.6f}" + ) + compared += 1 + assert compared > 0, "no overlapping metric names with the reference" + + +# --------------------------------------------------------------------------- +# Simple joined tasks: example_id-keyed, full prompt parity +# --------------------------------------------------------------------------- + + +def _join_by(instances, rows, instance_key, row_key): + by_key = {instance_key(inst): inst for inst in instances} + assert len(by_key) == len(instances), "join keys are not unique" + joined = [] + for row in rows: + key = row_key(row) + assert key in by_key, f"dump row {key!r} has no matching instance" + joined.append((by_key[key], row)) + assert len(joined) == len(rows) + return joined + + +@pytest.mark.parametrize( + ("spec", "dump_name", "tol"), + [ + ("chart_qa", "chart_qa-validation", TOLERANCE_DEFAULT), + ("vqa2", "coco_2014_vqa_8192-validation", TOLERANCE_DEFAULT), + ("doc_qa", "doc_qa-validation", TOLERANCE_DEFAULT), + ("info_qa", "info_qa-validation", TOLERANCE_DEFAULT), + ("text_vqa", "text_vqa-validation", TOLERANCE_DEFAULT), + ("mmmu", "mmmu_test-validation", TOLERANCE_MMMU), + ("ai2d", "ai2_diagram_v2_mix_transparent-validation", TOLERANCE_DEFAULT), + ("countbench_qa", "countbench_qa-huggingface", TOLERANCE_DEFAULT), + ("pixmo_count", "pixmo_count_counting-validation", TOLERANCE_DEFAULT), + ], +) +def test_dump_parity(spec: str, dump_name: str, tol: float) -> None: + rows, ref = _load_dump(dump_name) + task = get_task(spec) + instances = list(task.instances) + assert len(instances) == len(rows) + + if spec == "chart_qa": + joined = _join_by( + instances, + rows, + lambda inst: (inst.metadata["example_id"], inst.metadata["is_human"]), + lambda row: (row["example_id"], row["is_human"]), + ) + elif spec == "pixmo_count": + joined = _join_by( + instances, + rows, + lambda inst: inst.metadata["image_url"], + lambda row: row["image_url"], + ) + elif spec == "countbench_qa": + # image_url is not unique in CountBench; the dump saves the integer + # example_id under "image_id". + joined = _join_by( + instances, + rows, + lambda inst: inst.metadata["example_id"], + lambda row: row["image_id"], + ) + else: + joined = _join_by( + instances, + rows, + lambda inst: inst.metadata["example_id"], + lambda row: row["example_id"], + ) + + # 1. Prompt parity + mismatches = [ + (instance.metadata.get("example_id"), _user_text(row["prompt"]), instance.question) + for instance, row in joined + if _user_text(row["prompt"]) != instance.question + ] + assert not mismatches, ( + f"{len(mismatches)}/{len(joined)} prompt mismatches; first: {mismatches[0]}" + ) + + # 2. Metric parity + mine = _score_against_dump(task, [(inst, row["prediction"]) for inst, row in joined]) + _assert_metrics(mine, ref, tol) + + +# --------------------------------------------------------------------------- +# Unlabeled test-split variants (eval-server submissions): answers are not +# public, so only prompt parity is asserted against the native test dumps. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("spec", "dump_name"), + [ + ("doc_qa:test", "doc_qa-test-base_native_test"), + ("info_qa:test", "info_qa-test-base_native_test"), + ], +) +def test_dump_prompt_parity_unlabeled_test_split(spec: str, dump_name: str) -> None: + rows, _ = _load_dump(dump_name) + task = get_task(spec) + instances = list(task.instances) + assert len(instances) == len(rows) + + joined = _join_by( + instances, + rows, + lambda inst: inst.metadata["example_id"], + lambda row: row["example_id"], + ) + mismatches = [ + (instance.metadata["example_id"], _user_text(row["prompt"]), instance.question) + for instance, row in joined + if _user_text(row["prompt"]) != instance.question + ] + assert not mismatches, ( + f"{len(mismatches)}/{len(joined)} prompt mismatches; first: {mismatches[0]}" + ) + + +# --------------------------------------------------------------------------- +# RealWorldQA: the dump's `prompt` field is the *original* HF question (it is +# overwritten by metadata["prompt"] in SavePredictions), so prompt parity is +# checked against the documented derivation instead of the decoded input. +# --------------------------------------------------------------------------- + + +def test_dump_parity_real_world_qa() -> None: + rows, ref = _load_dump("real_world_qa_no_instruction-test") + task = get_task("real_world_qa") + instances = list(task.instances) + assert len(instances) == len(rows) + + # RealWorldQA has duplicate question texts, so join as a multiset keyed by + # (question, answer, question_type) — duplicates beyond that are + # interchangeable for scoring purposes. + pools: dict[tuple, list[Instance]] = {} + for inst in instances: + key = ( + inst.metadata["original_question"], + inst.metadata["answer"], + inst.metadata["question_type"], + ) + pools.setdefault(key, []).append(inst) + joined = [] + for row in rows: + key = (row["prompt"], row["answer"], row["question_type"]) + assert pools.get(key), f"dump row has no matching instance: {key[0][:80]!r}" + joined.append((pools[key].pop(), row)) + assert len(joined) == len(rows) + + for instance, row in joined: + original = row["prompt"] + if row["question_type"] == "short_answer": + expected = f"vqa2: {original.split(chr(10))[0]}" + else: + expected = original + assert instance.question == expected, instance.metadata["example_id"] + + mine = _score_against_dump(task, [(inst, row["prediction"]) for inst, row in joined]) + _assert_metrics(mine, ref, TOLERANCE_DEFAULT) + + +# --------------------------------------------------------------------------- +# MathVista: prompt parity is exact; the reference `score` (0.5670) used GPT-4 +# answer extraction, so the offline score is only asserted as a sanity band. +# The `math_vista:gpt` variant can be asserted against the reference with +# RUN_MATHVISTA_GPT_PARITY=1 + OPENAI_API_KEY (fresh API calls, own cache). +# --------------------------------------------------------------------------- + + +def test_dump_parity_math_vista_offline() -> None: + rows, ref = _load_dump("math_vista_v2-validation") + task = get_task("math_vista") + instances = list(task.instances) + assert len(instances) == len(rows) + + joined = _join_by( + instances, + rows, + lambda inst: inst.metadata["example_id"], + lambda row: row["example_id"], + ) + + mismatches = [ + (instance.metadata["example_id"], _user_text(row["prompt"]), instance.question) + for instance, row in joined + if _user_text(row["prompt"]) != instance.question + ] + assert not mismatches, ( + f"{len(mismatches)}/{len(joined)} prompt mismatches; first: {mismatches[0]}" + ) + + mine = _score_against_dump(task, [(inst, row["prediction"]) for inst, row in joined]) + # Offline extraction is not the GPT protocol that produced ref["score"]; + # assert a sanity band and report the delta. + assert mine["score"] >= 0.50, f"offline MathVista score suspiciously low: {mine['score']}" + print(f"math_vista offline={mine['score']:.4f} vs GPT reference={ref['score']:.4f}") + + +@pytest.mark.skipif( + not os.environ.get("RUN_MATHVISTA_GPT_PARITY"), + reason="Set RUN_MATHVISTA_GPT_PARITY=1 + OPENAI_API_KEY for GPT parity (~1000 API calls)", +) +def test_dump_parity_math_vista_gpt() -> None: + rows, ref = _load_dump("math_vista_v2-validation") + task = get_task("math_vista:gpt") + instances = list(task.instances) + + joined = _join_by( + instances, + rows, + lambda inst: inst.metadata["example_id"], + lambda row: row["example_id"], + ) + + from olmo_eval.common.execution import ScoringContext + + responses = [ + Response( + instance=instance, + request=task.format_request(instance), + outputs=[LMOutput(text=row["prediction"])], + ) + for instance, row in joined + ] + responses = asyncio.run(task.score_responses(responses, ScoringContext())) + nested = task.compute_metrics(responses) + score = next(iter(nested["score"].values())) + assert score == pytest.approx(ref["score"], abs=0.01), ( + f"GPT-extraction score {score:.4f} vs reference {ref['score']:.4f}" + ) diff --git a/tests/evals/tasks/test_image_qa_pipeline.py b/tests/evals/tasks/test_image_qa_pipeline.py new file mode 100644 index 00000000..2fc5618d --- /dev/null +++ b/tests/evals/tasks/test_image_qa_pipeline.py @@ -0,0 +1,93 @@ +"""Real-data pipeline tests for the 11 Molmo2 image-QA tasks. + +Opt-in (reads the shared mm-olmo data tree, strictly read-only): + + RUN_REAL_DATASET_TESTS=1 \ + HF_DATASETS_CACHE=/weka/oe-training-default/mm-olmo/hf_datasets \ + HF_DATASETS_OFFLINE=1 \ + pytest tests/evals/tasks/test_image_qa_pipeline.py -v + +Each task is checked for (a) the exact instance count of the original +mm_olmo eval split, and (b) oracle behavior: gold-answer responses score near +1.0 on the primary metric while corrupted responses score much lower. +""" + +from __future__ import annotations + +import asyncio +import os + +import pytest + +from olmo_eval.common.types import LMOutput, Response + +if not os.environ.get("RUN_REAL_DATASET_TESTS"): + pytest.skip( + "Set RUN_REAL_DATASET_TESTS=1 (and HF_DATASETS_CACHE for the HF-hub tasks) " + "to run real-data image-QA pipeline tests", + allow_module_level=True, + ) + +from olmo_eval.evals.tasks.common.registry import get_task # noqa: E402 + +# (task spec, expected instance count) +EXPECTED_COUNTS = [ + ("chart_qa", 1920), + ("vqa2", 8192), + ("doc_qa", 5349), + ("info_qa", 2801), + ("text_vqa", 5000), + ("real_world_qa", 765), + ("mmmu", 900), + ("math_vista", 1000), + ("countbench_qa", 490), + ("pixmo_count", 540), + ("ai2d", 1980), +] + +# Unlabeled test-split variants (predictions for eval-server submission): +# instance counts only — their metrics are computed against placeholder answers. +TEST_VARIANT_COUNTS = [ + ("doc_qa:test", 5188), + ("info_qa:test", 3288), +] + +CORRUPTED_TEXT = "the wrong answer entirely 424242" + + +def _score(task, texts): + responses = [] + for instance, text in zip(task.instances, texts, strict=True): + responses.append( + Response( + instance=instance, + request=task.format_request(instance), + outputs=[LMOutput(text=text)], + ) + ) + responses = asyncio.run(task.score_responses(responses)) + metrics = task.compute_metrics(responses) + primary = task.config.get_primary_metric() + return next(iter(metrics[primary.name].values())) + + +@pytest.mark.parametrize(("spec", "expected"), EXPECTED_COUNTS + TEST_VARIANT_COUNTS) +def test_instance_count(spec: str, expected: int) -> None: + task = get_task(spec) + assert len(list(task.instances)) == expected + + +@pytest.mark.parametrize(("spec", "_"), EXPECTED_COUNTS) +def test_oracle_beats_corrupted(spec: str, _: int) -> None: + # A slice is enough to separate oracle from corrupted decisively. + task = get_task(spec, {"limit": 64}) + instances = list(task.instances) + golds = [inst.gold_answer if inst.gold_answer is not None else "" for inst in instances] + + oracle = _score(task, golds) + corrupted = _score(task, [CORRUPTED_TEXT] * len(instances)) + + assert oracle >= 0.85, f"{spec}: oracle primary metric unexpectedly low ({oracle})" + assert corrupted <= oracle - 0.3, ( + f"{spec}: corrupted ({corrupted}) not clearly below oracle ({oracle})" + )