diff --git a/compass/llm/config.py b/compass/llm/config.py index 487f14a5e..2acf5e3e7 100644 --- a/compass/llm/config.py +++ b/compass/llm/config.py @@ -200,6 +200,14 @@ def client_kwargs(self): for key, env_var in arg_env_pairs: if self._client_kwargs.get(key) is None: self._client_kwargs[key] = os.environ.get(env_var) + elif self.client_type == "openai": + arg_env_pairs = [ + ("api_key", "OPENAI_API_KEY"), + ("base_url", "OPENAI_BASE_URL"), + ] + for key, env_var in arg_env_pairs: + if self._client_kwargs.get(key) is None: + self._client_kwargs[key] = os.environ.get(env_var) return self._client_kwargs diff --git a/compass/utilities/parsing.py b/compass/utilities/parsing.py index ce7e79f4b..37fb6e52e 100644 --- a/compass/utilities/parsing.py +++ b/compass/utilities/parsing.py @@ -39,16 +39,18 @@ def llm_response_as_json(content): Returns ------- - dict + object Parsed JSON structure. When parsing fails, the function returns an empty dictionary. Notes ----- The parser strips Markdown code fences, coerces Python-style - booleans to lowercase JSON literals, and logs the raw response on - decode failure. The logging includes guidance for increasing token - limits or updating prompts. + booleans to lowercase JSON literals, and first attempts strict JSON + decoding. If strict decoding fails, the parser attempts to recover + the first valid JSON object or array embedded in the response. If + no payload is recovered or it is not a JSON object, the raw response + is logged with guidance for prompt/token adjustments. """ content = clean_backticks_from_llm_response(content) content = content.removeprefix("json").lstrip("\n") @@ -56,6 +58,10 @@ def llm_response_as_json(content): try: content = json.loads(content) except json.decoder.JSONDecodeError: + parsed_content = _parse_first_json_payload(content) + if isinstance(parsed_content, dict): + return parsed_content + logger.exception( "LLM returned improperly formatted JSON. " "This is likely due to the completion running out of tokens. 
" @@ -68,6 +74,22 @@ def llm_response_as_json(content): return content +def _parse_first_json_payload(content): + """Parse first valid JSON payload embedded in text""" + decoder = json.JSONDecoder() + for start_ind, start_char in enumerate(content): + if start_char not in {"{", "["}: + continue + try: + parsed_content, __ = decoder.raw_decode(content, start_ind) + except json.decoder.JSONDecodeError: + continue + else: + return parsed_content + + return None + + def merge_overlapping_texts(text_chunks, n=300): """Merge text chunks while trimming overlapping boundaries diff --git a/tests/python/unit/llm/test_config.py b/tests/python/unit/llm/test_config.py new file mode 100644 index 000000000..210b426ab --- /dev/null +++ b/tests/python/unit/llm/test_config.py @@ -0,0 +1,54 @@ +"""Tests for LLM configuration helpers""" + +from pathlib import Path + +import pytest + +from compass.llm.config import OpenAIConfig + + +def test_openai_client_kwargs_loaded_from_env(monkeypatch): + """OpenAI kwargs can be populated from OPENAI_* env vars""" + monkeypatch.setenv("OPENAI_API_KEY", "test-openai-key") + monkeypatch.setenv("OPENAI_BASE_URL", "https://litellm.example.gov") + + config = OpenAIConfig(name="gpt-4o-mini", client_type="openai") + + assert config.client_kwargs["api_key"] == "test-openai-key" + assert config.client_kwargs["base_url"] == "https://litellm.example.gov" + + +def test_openai_client_kwargs_user_values_take_precedence(monkeypatch): + """Explicit client kwargs should not be replaced by env vars""" + monkeypatch.setenv("OPENAI_API_KEY", "env-key") + monkeypatch.setenv("OPENAI_BASE_URL", "https://env.example") + + config = OpenAIConfig( + name="gpt-4o-mini", + client_type="openai", + client_kwargs={ + "api_key": "user-key", + "base_url": "https://user.example", + }, + ) + + assert config.client_kwargs["api_key"] == "user-key" + assert config.client_kwargs["base_url"] == "https://user.example" + + +def test_azure_client_kwargs_unchanged(monkeypatch): + """Azure 
env var mapping remains unchanged""" + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "azure-key") + monkeypatch.setenv("AZURE_OPENAI_VERSION", "2024-02-15-preview") + monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://azure.example") + + config = OpenAIConfig(name="gpt-4o-mini", client_type="azure") + + assert config.client_kwargs["api_key"] == "azure-key" + assert config.client_kwargs["api_version"] == "2024-02-15-preview" + assert config.client_kwargs["azure_endpoint"] == "https://azure.example" + assert "base_url" not in config.client_kwargs + + +if __name__ == "__main__": + pytest.main(["-q", "--show-capture=all", Path(__file__), "-rapP"]) diff --git a/tests/python/unit/utilities/test_utilities_parsing.py b/tests/python/unit/utilities/test_utilities_parsing.py index 36afd4b0c..8cc1f0aed 100644 --- a/tests/python/unit/utilities/test_utilities_parsing.py +++ b/tests/python/unit/utilities/test_utilities_parsing.py @@ -43,6 +43,14 @@ def test_clean_backticks_from_llm_response(in_str, expected): ('{"a": True', {}), ('json\n{"key": "value"}', {"key": "value"}), ('{"a": True, "b": False}', {"a": True, "b": False}), + ( + ( + "I can extract date information from the URL provided. " + "However, the URL does not contain date information.\n\n" + '{"year": null, "month": null, "day": null}' + ), + {"year": None, "month": None, "day": None}, + ), ], ) def test_llm_response_as_json(in_str, expected):