diff --git a/src/cleanlab_tlm/internal/constants.py b/src/cleanlab_tlm/internal/constants.py
index 9c751f0..06f4826 100644
--- a/src/cleanlab_tlm/internal/constants.py
+++ b/src/cleanlab_tlm/internal/constants.py
@@ -107,7 +107,12 @@
 _TLM_EVAL_QUERY_IDENTIFIER_KEY: str = "query_identifier"
 _TLM_EVAL_CONTEXT_IDENTIFIER_KEY: str = "context_identifier"
 _TLM_EVAL_RESPONSE_IDENTIFIER_KEY: str = "response_identifier"
+_TLM_EVAL_MODE_KEY: str = "mode"
 
 # Values that wont support logging explanation by default
 _REASONING_EFFORT_UNSUPPORTED_EXPLANATION_LOGGING: set[str] = {"none", "minimal"}
 _QUALITY_PRESETS_UNSUPPORTED_EXPLANATION_LOGGING: set[str] = {"low", "base"}  # For regular TLM not TrustworthyRAG
+
+# RAG Evals modes binary/continuous
+_CONTINUOUS_STR = "continuous"
+_BINARY_STR = "binary"
diff --git a/src/cleanlab_tlm/utils/rag.py b/src/cleanlab_tlm/utils/rag.py
index e3f2f28..2dfc1f6 100644
--- a/src/cleanlab_tlm/utils/rag.py
+++ b/src/cleanlab_tlm/utils/rag.py
@@ -11,13 +11,13 @@
 from __future__ import annotations
 
 import asyncio
+import warnings
 from collections.abc import Sequence
 from typing import (
-    # lazydocs: ignore
     TYPE_CHECKING,
     Any,
     Callable,
-    Optional,
+    Optional,  # lazydocs: ignore
     Union,
     cast,
 )
@@ -29,9 +29,12 @@
 from cleanlab_tlm.internal.api import api
 from cleanlab_tlm.internal.base import BaseTLM
 from cleanlab_tlm.internal.constants import (
+    _BINARY_STR,
+    _CONTINUOUS_STR,
     _DEFAULT_TLM_QUALITY_PRESET,
     _TLM_EVAL_CONTEXT_IDENTIFIER_KEY,
     _TLM_EVAL_CRITERIA_KEY,
+    _TLM_EVAL_MODE_KEY,
     _TLM_EVAL_NAME_KEY,
     _TLM_EVAL_QUERY_IDENTIFIER_KEY,
     _TLM_EVAL_RESPONSE_IDENTIFIER_KEY,
@@ -47,6 +50,7 @@
     validate_logging,
     validate_rag_inputs,
 )
+from cleanlab_tlm.tlm import TLM
 
 if TYPE_CHECKING:
     from collections.abc import Coroutine
@@ -124,6 +128,7 @@ def __init__(
                 query_identifier=eval_config.get(_TLM_EVAL_QUERY_IDENTIFIER_KEY),
                 context_identifier=eval_config.get(_TLM_EVAL_CONTEXT_IDENTIFIER_KEY),
                 response_identifier=eval_config.get(_TLM_EVAL_RESPONSE_IDENTIFIER_KEY),
+                mode=eval_config.get(_TLM_EVAL_MODE_KEY),
             )
             for eval_config in _DEFAULT_EVALS
         ]
@@ -863,6 +868,13 @@ class Eval:
         response_identifier (str, optional): The exact string used in your evaluation `criteria` to reference the RAG/LLM response.
            For example, specifying `response_identifier` as "AI Answer" means your `criteria` should refer to the response as "AI Answer".
            Leave this value as None (the default) if this Eval doesn't consider the response.
+        mode (str, optional): What type of evaluation these `criteria` correspond to, either "continuous", "binary", or "auto" (default).
+            - "continuous": For `criteria` that define what is good/better vs. what is bad/worse, corresponding to evaluations of quality along a continuous spectrum (e.g., relevance, conciseness).
+            - "binary": For `criteria` written as Yes/No questions, corresponding to evaluations that most would consider either True or False rather than grading along a continuous spectrum (e.g., does Response mention ACME Inc., is Query asking about refund, ...).
+            - "auto": Automatically determines whether the criteria is binary or continuous based on the criteria text.
+            Both modes return scores in the 0-1 range.
+            For "continuous" evaluations, your `criteria` should define what good vs. bad looks like (cases deemed bad will return low evaluation scores).
+            For binary evaluations, your `criteria` should be a Yes/No question (cases answered "Yes" will return low evaluation scores, so phrase your question such that the likelihood of "Yes" matches the likelihood of the particular problem you wish to detect).
 
    Note on handling Tool Calls: By default, when a tool call response is detected,
    evaluations that analyze the response content (those with a `response_identifier`) are assigned `score=None`. You can override this behavior for specific evals via
@@ -876,6 +888,7 @@ def __init__(
         query_identifier: Optional[str] = None,
         context_identifier: Optional[str] = None,
         response_identifier: Optional[str] = None,
+        mode: Optional[str] = "auto",
     ):
         """
         lazydocs: ignore
@@ -892,6 +905,189 @@ def __init__(
         self.context_identifier = context_identifier
         self.response_identifier = response_identifier
 
+        # Compile and validate the eval mode
+        self.mode = self._compile_mode(mode, criteria, name)
+
+    def _compile_mode(self, mode: Optional[str], criteria: str, name: str) -> str:
+        """
+        Compile and validate the mode based on criteria.
+
+        Args:
+            mode: The specified mode ("binary", "continuous", or "auto")
+            criteria: The evaluation criteria text
+            name: The name of the evaluation
+
+        Returns:
+            str: The compiled mode ("binary" or "continuous")
+        """
+
+        # Check binary criteria once at the beginning
+        is_binary = self._check_binary_criteria(criteria)
+
+        # If mode is auto, determine it automatically
+        if mode == "auto":
+            compiled_mode = _BINARY_STR if is_binary else _CONTINUOUS_STR
+
+            # Warn if the criteria fit neither the binary nor the continuous style
+            if not is_binary:
+                has_good_bad = self._check_good_bad_specified(criteria)
+                has_numeric = self._check_numeric_scoring_scheme(criteria)
+
+                if not has_good_bad and not has_numeric:
+                    warning_msg = (
+                        f"Eval '{name}': Criteria does not appear to be a Yes/No question "
+                        "and does not clearly specify what is good/bad or desirable/undesirable. "
+                        "This may result in poor evaluation quality."
+                    )
+                    warnings.warn(warning_msg, UserWarning)
+
+            return compiled_mode
+
+        # Validation checks for explicit mode specification
+        if mode == _BINARY_STR:
+            if not is_binary:
+                warning_msg = (
+                    f"Eval '{name}': mode is set to '{_BINARY_STR}' but criteria does not appear "
+                    "to be a Yes/No question. Consider rephrasing as a Yes/No question or "
+                    f"changing mode to '{_CONTINUOUS_STR}'."
+                )
+                warnings.warn(warning_msg, UserWarning)
+
+        elif mode == _CONTINUOUS_STR:
+            # Check if it's actually a Yes/No question
+            if is_binary:
+                warning_msg = (
+                    f"Eval '{name}': mode is set to '{_CONTINUOUS_STR}' but criteria appears to be "
+                    f"a Yes/No question. Consider changing mode to '{_BINARY_STR}' for more appropriate scoring."
+                )
+                warnings.warn(warning_msg, UserWarning)
+
+            # Check if good/bad is specified
+            has_good_bad = self._check_good_bad_specified(criteria)
+            if not has_good_bad:
+                warning_msg = (
+                    f"Eval '{name}': mode is set to '{_CONTINUOUS_STR}' but criteria does not clearly "
+                    "specify what is good/desirable versus bad/undesirable. This may lead to "
+                    "inconsistent or unclear scoring."
+                )
+                warnings.warn(warning_msg, UserWarning)
+
+            # Check if it already has a numeric scoring scheme
+            has_numeric = self._check_numeric_scoring_scheme(criteria)
+            if has_numeric:
+                warning_msg = (
+                    f"Eval '{name}': Your `criteria` appears to specify "
+                    "a numeric scoring scheme. We recommend removing any "
+                    "specific numeric scoring scheme from your `criteria` and just specifying what is considered good/better vs. bad/worse."
+                )
+                warnings.warn(warning_msg, UserWarning)
+
+        # For explicit modes, return as-is (already validated above)
+        if mode in (_BINARY_STR, _CONTINUOUS_STR):
+            return mode
+
+        # Default to continuous for None or any other value
+        return _CONTINUOUS_STR
+
+    @staticmethod
+    def _check_binary_criteria(criteria: str) -> bool:
+        """
+        Check if criteria is a Yes/No question using TLM.
+
+        Args:
+            criteria: The evaluation criteria text
+
+        Returns:
+            True if criteria is a Yes/No question, False otherwise
+        """
+        tlm = TLM(quality_preset="base")
+
+        prompt = f"""Consider the following statement:
+
+
+        {criteria}
+
+
+        ## Instructions
+
+        Classify this statement into one of the following options:
+        A) This statement is essentially worded as a Yes/No question or implies a Yes/No question.
+        B) This statement is not a Yes/No question, since replying to it with either "Yes" or "No" would not be sensible.
+
+        Your output must be one choice from either A or B (output only a single letter, no other text)."""
+
+        response = tlm.prompt(prompt, constrain_outputs=["A", "B"])
+        if isinstance(response, list):
+            return False
+        response_text = response.get("response", "")
+        if response_text is None:
+            return False
+        return str(response_text).strip().upper() == "A"
+
+    @staticmethod
+    def _check_good_bad_specified(criteria: str) -> bool:
+        """
+        Check if criteria clearly specifies what is Good vs Bad or Desirable vs Undesirable.
+
+        Args:
+            criteria: The evaluation criteria text
+
+        Returns:
+            True if criteria clearly defines good/bad or desirable/undesirable, False otherwise
+        """
+        tlm = TLM(quality_preset="base")
+
+        prompt = f"""Analyze the following evaluation criteria and determine if it clearly specifies what is "good" versus "bad", "desirable" versus "undesirable", "better" versus "worse", or uses similar language to define quality distinctions.
+
+        The criteria should make it clear what characteristics or qualities are considered positive/desirable versus negative/undesirable.
+
+        Evaluation Criteria:
+        {criteria}
+
+        Does this criteria clearly specify what is good/desirable versus bad/undesirable? Answer only "Yes" or "No"."""
+
+        response = tlm.prompt(prompt, constrain_outputs=["Yes", "No"])
+        if isinstance(response, list):
+            return False
+        response_text = response.get("response", "")
+        if response_text is None:
+            return False
+        return str(response_text).strip().lower() == "yes"
+
+    @staticmethod
+    def _check_numeric_scoring_scheme(criteria: str) -> bool:
+        """
+        Check if criteria contains a specific numeric scoring scheme (e.g., "rate from 1-5", "score 0-100").
+
+        Args:
+            criteria: The evaluation criteria text
+
+        Returns:
+            True if criteria includes a numeric scoring scheme, False otherwise
+        """
+        tlm = TLM(quality_preset="base")
+
+        prompt = f"""Analyze the following evaluation criteria and determine if it contains a specific numeric scoring scheme.
+
+        Examples of numeric scoring schemes include:
+        - "Rate from 1 to 5"
+        - "Score between 0 and 100"
+        - "Assign a rating of 1-10"
+        - "Give a score from 0 to 1"
+
+        Evaluation Criteria:
+        {criteria}
+
+        Does this criteria specify a numeric scoring scheme? Answer only "Yes" or "No"."""
+
+        response = tlm.prompt(prompt, constrain_outputs=["Yes", "No"])
+        if isinstance(response, list):
+            return False
+        response_text = response.get("response", "")
+        if response_text is None:
+            return False
+        return str(response_text).strip().lower() == "yes"
+
     def __repr__(self) -> str:
         """
         Return a string representation of the Eval object in dictionary format.
@@ -906,6 +1102,7 @@ def __repr__(self) -> str:
             f"    'query_identifier': {self.query_identifier!r},\n"
             f"    'context_identifier': {self.context_identifier!r},\n"
-            f"    'response_identifier': {self.response_identifier!r}\n"
+            f"    'response_identifier': {self.response_identifier!r},\n"
+            f"    'mode': {self.mode!r}\n"
             f"}}"
         )
 
@@ -917,6 +1114,7 @@ def __repr__(self) -> str:
         "query_identifier": "Question",
         "context_identifier": "Document",
         "response_identifier": None,
+        "mode": _BINARY_STR,
     },
     {
         "name": "response_groundedness",
@@ -924,6 +1122,7 @@ def __repr__(self) -> str:
         "query_identifier": "Query",
         "context_identifier": "Context",
         "response_identifier": "Response",
+        "mode": _CONTINUOUS_STR,
     },
     {
         "name": "response_helpfulness",
@@ -933,6 +1132,7 @@ def __repr__(self) -> str:
         "query_identifier": "User Query",
         "context_identifier": None,
         "response_identifier": "AI Assistant Response",
+        "mode": _CONTINUOUS_STR,
     },
     {
         "name": "query_ease",
@@ -944,6 +1144,7 @@ def __repr__(self) -> str:
         "query_identifier": "User Request",
         "context_identifier": None,
         "response_identifier": None,
+        "mode": _CONTINUOUS_STR,
     },
 ]
 
@@ -976,6 +1177,7 @@ def get_default_evals() -> list[Eval]:
             query_identifier=eval_config.get("query_identifier"),
             context_identifier=eval_config.get("context_identifier"),
             response_identifier=eval_config.get("response_identifier"),
+            mode=eval_config.get("mode"),
         )
         for eval_config in _DEFAULT_EVALS
     ]
diff --git a/tests/test_tlm_rag.py b/tests/test_tlm_rag.py
index 2232fa5..37a2f91 100644
--- a/tests/test_tlm_rag.py
+++ b/tests/test_tlm_rag.py
@@ -1,6 +1,6 @@
 import os
 import re
-from collections.abc import Generator
+from collections.abc import Generator, Mapping, Sequence
 from typing import Any, cast
 from unittest import mock
 
@@ -10,6 +10,7 @@
 from cleanlab_tlm.internal.api import api
 from cleanlab_tlm.internal.constants import (
     _TLM_DEFAULT_MODEL,
+    _TLM_EVAL_MODE_KEY,
     _VALID_TLM_QUALITY_PRESETS,
 )
 from cleanlab_tlm.tlm import TLMOptions
@@ -1084,3 +1085,158 @@ def test_tool_call_override_invalid_name_raises(trustworthy_rag: TrustworthyRAG)
         ),
     ):
         trustworthy_rag._configure_tool_call_eval_overrides(exclude_names=[existing_eval_name, "not_a_real_eval"])
+
+
+def test_eval_mode_defaults_to_continuous() -> None:
+    e = Eval(
+        name="helpfulness",
+        criteria="Rate if the AI Answer is helpful to the User Question using the Retrieved Context.",
+        query_identifier="User Question",
+        context_identifier="Retrieved Context",
+        response_identifier="AI Answer",
+    )
+    # default should be continuous
+    assert e.mode in (
+        None,
+        "continuous",
+    ), "Eval.mode should default to 'continuous' (or None treated as continuous)"
+
+
+def test_eval_mode_binary_set_and_persisted() -> None:
+    e = Eval(
+        name="mentions_company",
+        criteria="Does the AI Answer mention any company names? Answer Yes/No.",
+        query_identifier="User Question",
+        response_identifier="AI Answer",
+        mode="binary",
+    )
+    assert e.mode == "binary"
+    assert e.response_identifier == "AI Answer"
+    assert e.query_identifier == "User Question"
+
+
+@pytest.mark.asyncio
+async def test_api_binary_and_continuous_mix_roundtrip_payload() -> None:
+    """Mix of modes should be preserved per-eval in payload."""
+    evals = [
+        Eval(
+            name="response_helpfulness",
+            criteria="Rate helpfulness from 0-1.",
+            query_identifier="Question",
+            context_identifier="Context",
+            response_identifier="Answer",
+            mode="continuous",
+        ),
+        Eval(
+            name="mentions_company",
+            criteria="Does the Answer mention any company? Yes/No.",
Yes/No.", + response_identifier="Answer", + mode="binary", + ), + ] + + mock_resp_json = { + "trustworthiness": {"score": 0.9}, + "response_helpfulness": {"score": 0.8}, + "mentions_company": {"score": 0.0}, + } + + mock_response = mock.MagicMock() + mock_response.status = 200 + mock_response.json = mock.AsyncMock(return_value=mock_resp_json) + + mock_session = mock.MagicMock() + mock_session.post = mock.AsyncMock(return_value=mock_response) + mock_session.close = mock.AsyncMock() + + mock_rate_handler = mock.MagicMock() + mock_rate_handler.__aenter__ = mock.AsyncMock() + mock_rate_handler.__aexit__ = mock.AsyncMock() + + with mock.patch("aiohttp.ClientSession", return_value=mock_session): + result = await api.tlm_rag_score( + api_key="k", + response={"response": "A"}, + prompt=None, + query="Q?", + context="C", + evals=evals, + quality_preset="medium", + options=None, + rate_handler=mock_rate_handler, + ) + + assert set(result.keys()) >= { + "trustworthiness", + "response_helpfulness", + "mentions_company", + } + + call_args = mock_session.post.call_args + payload = call_args[1]["json"] + sent_evals = {e["name"]: e for e in payload.get("evals", [])} + assert sent_evals["response_helpfulness"].get( + "mode", sent_evals["response_helpfulness"].get(_TLM_EVAL_MODE_KEY) + ) in ( + None, + "continuous", + ) + assert ( + sent_evals["mentions_company"].get("mode", sent_evals["mentions_company"].get(_TLM_EVAL_MODE_KEY)) == "binary" + ) + + +def test_score_modes_explicit(trustworthy_rag_api_key: str) -> None: + """Ensure both continuous and binary evals are accepted and scored (0..1).""" + evals = [ + Eval( + name="response_helpfulness", + criteria="Rate helpfulness from 0 to 1.", + query_identifier="Question", + context_identifier="Context", + response_identifier="Answer", + mode="continuous", + ), + Eval( + name="mentions_company", + criteria="Does the Answer mention a company? Yes/No.", + response_identifier="Answer", + mode="binary", + ), + ] + + rag = TrustworthyRAG(api_key=trustworthy_rag_api_key, evals=evals) + raw_score = rag.score(query=test_query, context=test_context, response=test_response) + + assert is_trustworthy_rag_score(raw_score) + + # --- Normalize to: Dict[str, Mapping[str, Any]] --- + scores_by_name: dict[str, Mapping[str, Any]] = {} + + if isinstance(raw_score, list): + # e.g. [{"name": "response_helpfulness", "score": 0.7}, ...] + for e in raw_score: + if isinstance(e, dict): + name = e.get("name") + if isinstance(name, str): + scores_by_name[name] = cast(Mapping[str, Any], e) + elif isinstance(raw_score, dict): + # Could be dict[str, ...] OR dict[EvalMetric, ...] + # Case A: string keys + all_str_keys = all(isinstance(k, str) for k in raw_score) + if all_str_keys: + for k, v in raw_score.items(): + scores_by_name[k] = cast(Mapping[str, Any], v) + else: + # Case B: enum/non-str keys → align by order with our 'evals' list + # Dicts preserve insertion order; assume provider returns in same order as 'evals' + values_in_order: Sequence[Any] = list(raw_score.values()) + for ev, v in zip(evals, values_in_order): + if isinstance(v, dict): + scores_by_name[ev.name] = cast(Mapping[str, Any], v) + + # --- Validate both evals exist and have score ∈ [0,1] or None --- + for expected in ("response_helpfulness", "mentions_company"): + assert expected in scores_by_name, f"{expected} missing in result" + s = scores_by_name[expected].get("score") + assert (s is None) or (isinstance(s, float) and 0.0 <= s <= 1.0), f"Invalid score for {expected}: {s}"