Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/olmo_eval/common/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,28 @@ class LMOutput:
"""Output from a language model.

Supports both text generation and tool calling outputs.

The provider_extras dict holds provider-specific fields (e.g., has_reasoning
flag from reasoning models). Only providers that need it populate this field.
"""

text: str
logprobs: list[LogProbEntry] | None = None
extracted_answer: Any = None
metadata: dict[str, Any] = field(default_factory=dict)
tool_calls: list[ToolCall] | None = None
provider_extras: dict[str, Any] = field(default_factory=dict)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a dict for provider-specific (in this case litellm) properties from the response. The only thing added to it is has_reasoning. I suppose other providers could also implement this on LMOutput (e.g. an embedded thinking tag). Mainly, I am wondering whether it is okay to add this dict (it can be renamed, or maybe it can just live in metadata, but that looks like it's for metric calls).


@property
def has_tool_calls(self) -> bool:
    """Report whether at least one tool call is attached to this output."""
    # ``None`` and an empty list are both falsy, so a single truthiness test
    # covers the "missing" and "present but empty" cases alike.
    return bool(self.tool_calls)

@property
def has_reasoning(self) -> bool:
    """Check if this output contains reasoning content.

    Reads the ``"has_reasoning"`` entry of ``provider_extras``; a missing
    key counts as "no reasoning".

    Returns:
        ``True`` when the stored flag is truthy, ``False`` otherwise.
    """
    # provider_extras is an untyped dict, so coerce whatever a provider
    # stored there to a real bool to honour the annotated return type.
    return bool(self.provider_extras.get("has_reasoning", False))


@dataclass(slots=True)
class Response:
Expand Down
183 changes: 183 additions & 0 deletions src/olmo_eval/evals/tasks/response_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"""Response tests for verifying model response properties.

Simple tests to verify models respond correctly with expected properties.
"""

from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass
from typing import ClassVar

from olmo_eval.common.formatters import CompletionFormatter
from olmo_eval.common.metrics import AccuracyMetric
from olmo_eval.common.scorers import Scorer, SubstringRecallScorer, ToolCallScorer
from olmo_eval.common.types import (
Instance,
LMOutput,
LMRequest,
RequestType,
SamplingParams,
ToolSchema,
)
from olmo_eval.evals.tasks.common import Task, register


@dataclass(frozen=True, slots=True)
class NonEmptyResponseScorer(Scorer):
    """Scorer that rewards any response containing non-whitespace text."""

    name: ClassVar[str] = "non_empty_response"

    def score(self, instance: Instance, output: LMOutput) -> float:
        """Return 1.0 when ``output.text`` has visible characters, else 0.0."""
        body = output.text
        if not body:
            # Missing or empty text can never count as a response.
            return 0.0
        # Whitespace-only text strips down to "" and is scored as empty too.
        return 1.0 if body.strip() else 0.0


@dataclass(frozen=True, slots=True)
class ReasoningResponseScorer(Scorer):
    """Scorer that rewards outputs flagged as carrying reasoning content.

    Used to confirm that a reasoning model's chain-of-thought was surfaced
    on the response, as reported by ``output.has_reasoning``.
    """

    name: ClassVar[str] = "reasoning_present"

    def score(self, instance: Instance, output: LMOutput) -> float:
        """Return 1.0 if ``output.has_reasoning`` is truthy, otherwise 0.0."""
        if output.has_reasoning:
            return 1.0
        return 0.0


# =============================================================================
# Content Verification Response Test
# =============================================================================


@register("response_match")
class ResponseContentVerify(Task):
    """Task that checks model responses for expected content.

    Two modes, selected by ``config.data_source``:

    - No data source (default): asks "Who are you?" and checks for a
      non-empty response.
    - Adhoc data source: loads prompts and expected substrings from a file
      and checks that each response contains its expected substring.

    Data file format (JSONL):
        {"question": "Who are you?", "expected_substring": "OLMo"}
    """

    sampling_params = SamplingParams(temperature=0.0, max_tokens=1024)
    formatter = CompletionFormatter(template="User: {question}\nAssistant:")
    metrics = (
        AccuracyMetric(scorer=SubstringRecallScorer),
        AccuracyMetric(scorer=NonEmptyResponseScorer),
    )
    primary_metric = AccuracyMetric(scorer=SubstringRecallScorer)

    def process_doc(self, doc: dict, index: int = 0) -> Instance:
        """Convert one data-file record into an ``Instance``.

        ``doc`` must contain ``"question"``; ``"expected_substring"`` is
        optional and defaults to the empty string. ``index`` is embedded in
        the instance id.
        """
        prompt = doc["question"]
        expected = doc.get("expected_substring", "")
        info = {"id": f"response_match_{index}", "check_type": "substring"}
        return Instance(question=prompt, gold_answer=expected, metadata=info)

    @property
    def instances(self) -> Iterator[Instance]:
        """Yield the built-in default prompt, or the data source's instances."""
        if self.config.data_source is None:
            # No data file configured: fall back to the single default prompt.
            yield Instance(
                question="Who are you?",
                gold_answer="",
                metadata={"id": "response_match_default", "check_type": "substring"},
            )
            return
        yield from self._load_instances()

    def format_request(self, instance: Instance) -> LMRequest:
        """Build the request, preferring the formatter set on the config."""
        configured = self.config.formatter
        if configured is None:
            # Without a configured formatter the raw question is the prompt.
            return LMRequest(request_type=self.request_type, prompt=instance.question)
        return configured.format(instance, self.get_fewshot())


# =============================================================================
# Tool Calling Response Test
# =============================================================================

# Argument schema (JSON-schema style) for the weather tool: a required
# free-text location plus an optional temperature unit.
_WEATHER_TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        },
        "unit": {
            "type": "string",
            "enum": ["celsius", "fahrenheit"],
        },
    },
    "required": ["location"],
}

# Weather tool schema used by the tool-calling response test.
_WEATHER_TOOL = ToolSchema(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters=_WEATHER_TOOL_PARAMETERS,
)


@register("response_toolcall")
class ResponseToolCall(Task):
    """Response test: can the model make tool calls?

    The model is given a tool schema and asked a weather question; the test
    expects it to invoke the ``get_current_weather`` tool in reply.
    """

    sampling_params = SamplingParams(temperature=0.0)
    metrics = (AccuracyMetric(scorer=ToolCallScorer),)

    @property
    def instances(self) -> Iterator[Instance]:
        """Yield the single weather question used by this check."""
        # Only the tool name is listed as an expectation.
        expected_call = {"name": "get_current_weather"}
        yield Instance(
            question="What's the weather like in Seattle?",
            gold_answer="",
            expected_tool_calls=(expected_call,),
            metadata={"id": "toolcall", "check_type": "tool_call"},
        )

    def format_request(self, instance: Instance) -> LMRequest:
        """Wrap the question in one user message and attach the weather tool."""
        user_turn = {"role": "user", "content": instance.question}
        return LMRequest(
            request_type=RequestType.COMPLETION,
            messages=(user_turn,),
            tools=(_WEATHER_TOOL,),
        )


# =============================================================================
# Reasoning Response Test
# =============================================================================


@register("response_reasoning")
class ResponseReasoning(Task):
    """Response test: does the model return reasoning content?

    Asks one simple question and scores whether the resulting output is
    flagged as containing reasoning, which verifies that a reasoning model's
    chain-of-thought is parsed and returned with the response.
    """

    sampling_params = SamplingParams(temperature=0.0)
    metrics = (AccuracyMetric(scorer=ReasoningResponseScorer),)

    @property
    def instances(self) -> Iterator[Instance]:
        """Yield the single question used by this check."""
        info = {"id": "reasoning", "check_type": "reasoning_present"}
        yield Instance(question="Who are you?", gold_answer="", metadata=info)

    def format_request(self, instance: Instance) -> LMRequest:
        """Wrap the question in a single user message."""
        user_turn = {"role": "user", "content": instance.question}
        return LMRequest(
            request_type=RequestType.COMPLETION,
            messages=(user_turn,),
        )
36 changes: 33 additions & 3 deletions src/olmo_eval/inference/providers/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from olmo_eval.common.debug import is_debug_provider
from olmo_eval.common.logging import get_logger
from olmo_eval.common.types import LMOutput, LMRequest, LogProbEntry, SamplingParams
from olmo_eval.common.types import LMOutput, LMRequest, LogProbEntry, SamplingParams, ToolCall
from olmo_eval.inference.base import InferenceProvider
from olmo_eval.inference.retry import retry_with_backoff
from olmo_eval.inference.utils import run_async
Expand Down Expand Up @@ -134,7 +134,11 @@ async def _generate_single_impl(
kwargs["temperature"] = params.temperature
if params.stop_sequences:
kwargs["stop"] = list(params.stop_sequences)[:_MAX_STOP_SEQUENCES]
# Always request logprobs for metrics computation

# Pass tools if provided in the request
if request.tools:
kwargs["tools"] = [tool.to_openai() for tool in request.tools]

kwargs["logprobs"] = True
kwargs["top_logprobs"] = (
1 # NOTE: workaround for litellm proxy issue https://github.com/BerriAI/litellm/issues/21932
Expand Down Expand Up @@ -168,7 +172,33 @@ async def _generate_single_impl(
"num_tokens_all": num_tokens,
}

outputs.append(LMOutput(text=text, logprobs=logprob_entries, metadata=metadata))
# Extract tool calls from response
tool_calls: list[ToolCall] | None = None
message_tool_calls = getattr(choice.message, "tool_calls", None)
if message_tool_calls:
tool_calls = [ToolCall.from_openai(tc.model_dump()) for tc in message_tool_calls]

# Check for reasoning content (for reasoning models)
has_reasoning = False
message_content = getattr(choice.message, "content", None)
if message_content is not None:
if getattr(message_content, "reasoning", None):
has_reasoning = True
if getattr(message_content, "reasoning_content", None):
has_reasoning = True
# Also check directly on message for reasoning_content (some APIs use this)
if not has_reasoning and getattr(choice.message, "reasoning_content", None):
has_reasoning = True

outputs.append(
LMOutput(
text=text,
logprobs=logprob_entries,
metadata=metadata,
tool_calls=tool_calls,
provider_extras={"has_reasoning": True} if has_reasoning else {},
)
)

return outputs

Expand Down
Loading