diff --git a/src/galileo/handlers/langchain/async_handler.py b/src/galileo/handlers/langchain/async_handler.py index 6e7e753a..db38b0f6 100644 --- a/src/galileo/handlers/langchain/async_handler.py +++ b/src/galileo/handlers/langchain/async_handler.py @@ -6,7 +6,7 @@ from galileo.handlers.base_async_handler import GalileoAsyncBaseHandler from galileo.handlers.langchain.handler import GalileoCallback -from galileo.handlers.langchain.utils import get_agent_name, is_agent_node, update_root_to_agent +from galileo.handlers.langchain.utils import get_agent_name, is_agent_node, parse_llm_result, update_root_to_agent from galileo.logger import GalileoLogger from galileo.schema.handlers import NODE_TYPE from galileo.schema.trace import TracesIngestRequest @@ -102,13 +102,13 @@ async def on_chain_end( # The input is sent via kwargs in on_chain_end in async streaming mode if "inputs" in kwargs: kwargs["input"] = serialize_to_str(kwargs["inputs"]) - await self._handler.async_end_node(run_id, output=serialize_to_str(outputs), **kwargs) + await self._handler.async_end_node(run_id, output=serialize_to_str(outputs), status_code=200, **kwargs) async def on_agent_finish(self, finish: AgentFinish, *, run_id: UUID, **kwargs: Any) -> Any: """Langchain callback when an agent finishes.""" # Convert UUID7 to UUID4 if needed run_id = convert_uuid_if_uuid7(run_id) or run_id - await self._handler.async_end_node(run_id, output=serialize_to_str(finish)) + await self._handler.async_end_node(run_id, output=serialize_to_str(finish), status_code=200) async def on_llm_start( self, @@ -216,21 +216,14 @@ async def on_llm_end( # Convert UUID7 to UUID4 if needed run_id = convert_uuid_if_uuid7(run_id) or run_id - token_usage = response.llm_output.get("token_usage", {}) if response.llm_output else {} - - try: - flattened_messages = [message for batch in response.generations for message in batch] - output = json.loads(json.dumps(flattened_messages[0], cls=EventSerializer)) - except Exception as e: - _logger.warning(f"Failed to serialize LLM output: {e}") - output = str(response.generations) - + result = parse_llm_result(response) await self._handler.async_end_node( run_id, - output=output, - num_input_tokens=token_usage.get("prompt_tokens"), - num_output_tokens=token_usage.get("completion_tokens"), - total_tokens=token_usage.get("total_tokens"), + output=result.output, + num_input_tokens=result.num_input_tokens, + num_output_tokens=result.num_output_tokens, + total_tokens=result.total_tokens, + status_code=200, ) async def on_tool_start( @@ -270,7 +263,7 @@ async def on_tool_end( # Convert UUID7 to UUID4 if needed run_id = convert_uuid_if_uuid7(run_id) or run_id - end_node_kwargs = {} + end_node_kwargs: dict[str, Any] = {"status_code": 200} if (tool_message := GalileoCallback._find_tool_message(output)) is not None: end_node_kwargs["output"] = tool_message.content end_node_kwargs["tool_call_id"] = tool_message.tool_call_id @@ -330,36 +323,36 @@ async def on_retriever_end( _logger.warning(f"Failed to serialize retriever output: {e}") serialized_response = str(documents) - await self._handler.async_end_node(run_id, output=serialized_response) + await self._handler.async_end_node(run_id, output=serialized_response, status_code=200) async def on_chain_error( self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any ) -> Any: - """Called when a chain errors.""" + """Langchain callback when a chain errors.""" # Convert UUID7 to UUID4 if needed run_id = convert_uuid_if_uuid7(run_id) or run_id - await self._handler.async_end_node(run_id, output=f"Error: {error!s}") + await self._handler.async_end_node(run_id, output=f"Error: {error!s}", status_code=400) async def on_llm_error( self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any ) -> Any: - """Called when an LLM errors.""" + """Langchain callback when an LLM errors.""" # Convert UUID7 to UUID4 if needed run_id = convert_uuid_if_uuid7(run_id) or run_id - await self._handler.async_end_node(run_id, output=f"Error: {error!s}") + await self._handler.async_end_node(run_id, output=f"Error: {error!s}", status_code=400) async def on_tool_error( self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any ) -> Any: - """Called when a tool errors.""" + """Langchain callback when a tool errors.""" # Convert UUID7 to UUID4 if needed run_id = convert_uuid_if_uuid7(run_id) or run_id - await self._handler.async_end_node(run_id, output=f"Error: {error!s}") + await self._handler.async_end_node(run_id, output=f"Error: {error!s}", status_code=400) async def on_retriever_error( self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any ) -> Any: - """Called when a retriever errors.""" + """Langchain callback when a retriever errors.""" # Convert UUID7 to UUID4 if needed run_id = convert_uuid_if_uuid7(run_id) or run_id - await self._handler.async_end_node(run_id, output=f"Error: {error!s}") + await self._handler.async_end_node(run_id, output=f"Error: {error!s}", status_code=400) diff --git a/src/galileo/handlers/langchain/handler.py b/src/galileo/handlers/langchain/handler.py index 959ac6dc..5da79b81 100644 --- a/src/galileo/handlers/langchain/handler.py +++ b/src/galileo/handlers/langchain/handler.py @@ -5,7 +5,7 @@ from uuid import UUID from galileo.handlers.base_handler import GalileoBaseHandler -from galileo.handlers.langchain.utils import get_agent_name, is_agent_node, update_root_to_agent +from galileo.handlers.langchain.utils import get_agent_name, is_agent_node, parse_llm_result, update_root_to_agent from galileo.logger import GalileoLogger from galileo.schema.handlers import LANGCHAIN_NODE_TYPE, NODE_TYPE from galileo.schema.trace import TracesIngestRequest @@ -123,13 +123,13 @@ def on_chain_end( """Langchain callback when a chain ends.""" # Convert UUID7 to UUID4 if needed run_id = convert_uuid_if_uuid7(run_id) or run_id - self._handler.end_node(run_id, output=serialize_to_str(outputs)) + self._handler.end_node(run_id, output=serialize_to_str(outputs), status_code=200) def on_agent_finish(self, finish: AgentFinish, *, run_id: UUID, **kwargs: Any) -> Any: """Langchain callback when an agent finishes.""" # Convert UUID7 to UUID4 if needed run_id = convert_uuid_if_uuid7(run_id) or run_id - self._handler.end_node(run_id, output=serialize_to_str(finish)) + self._handler.end_node(run_id, output=serialize_to_str(finish), status_code=200) def on_llm_start( self, @@ -237,21 +237,14 @@ def on_llm_end( # Convert UUID7 to UUID4 if needed run_id = convert_uuid_if_uuid7(run_id) or run_id - token_usage = response.llm_output.get("token_usage", {}) if response.llm_output else {} - - try: - flattened_messages = [message for batch in response.generations for message in batch] - output = json.loads(json.dumps(flattened_messages[0], cls=EventSerializer)) - except Exception as e: - _logger.warning(f"Failed to serialize LLM output: {e}") - output = str(response.generations) - + result = parse_llm_result(response) self._handler.end_node( run_id, - output=output, - num_input_tokens=token_usage.get("prompt_tokens"), - num_output_tokens=token_usage.get("completion_tokens"), - total_tokens=token_usage.get("total_tokens"), + output=result.output, + num_input_tokens=result.num_input_tokens, + num_output_tokens=result.num_output_tokens, + total_tokens=result.total_tokens, + status_code=200, ) def on_tool_start( @@ -320,7 +313,7 @@ def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: Optional[UUID # Convert UUID7 to UUID4 if needed run_id = convert_uuid_if_uuid7(run_id) or run_id - end_node_kwargs = {} + end_node_kwargs: dict[str, Any] = {"status_code": 200} if (tool_message := self._find_tool_message(output)) is not None: end_node_kwargs["output"] = tool_message.content end_node_kwargs["tool_call_id"] = tool_message.tool_call_id @@ -380,7 +373,7 @@ def on_retriever_end( _logger.warning(f"Failed to serialize retriever output: {e}") serialized_response = str(documents) - self._handler.end_node(run_id, output=serialized_response) + self._handler.end_node(run_id, output=serialized_response, status_code=200) def _get_agent_name(self, parent_run_id: Optional[UUID], node_name: str) -> str: if parent_run_id is not None: @@ -390,3 +383,35 @@ def _get_agent_name(self, parent_run_id: Optional[UUID], node_name: str) -> str: if parent: return parent.span_params["name"] + ":" + node_name return node_name + + def on_chain_error( + self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any + ) -> Any: + """Langchain callback when a chain errors.""" + # Convert UUID7 to UUID4 if needed + run_id = convert_uuid_if_uuid7(run_id) or run_id + self._handler.end_node(run_id, output=f"Error: {error!s}", status_code=400) + + def on_llm_error( + self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any + ) -> Any: + """Langchain callback when an LLM errors.""" + # Convert UUID7 to UUID4 if needed + run_id = convert_uuid_if_uuid7(run_id) or run_id + self._handler.end_node(run_id, output=f"Error: {error!s}", status_code=400) + + def on_tool_error( + self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any + ) -> Any: + """Langchain callback when a tool errors.""" + # Convert UUID7 to UUID4 if needed + run_id = convert_uuid_if_uuid7(run_id) or run_id + self._handler.end_node(run_id, output=f"Error: {error!s}", status_code=400) + + def on_retriever_error( + self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any + ) -> Any: + """Langchain callback when a retriever errors.""" + # Convert UUID7 to UUID4 if needed + run_id = convert_uuid_if_uuid7(run_id) or run_id + self._handler.end_node(run_id, output=f"Error: {error!s}", status_code=400) diff --git a/src/galileo/handlers/langchain/utils.py b/src/galileo/handlers/langchain/utils.py index 1d20c02b..810d76e8 100644 --- a/src/galileo/handlers/langchain/utils.py +++ b/src/galileo/handlers/langchain/utils.py @@ -1,11 +1,19 @@ -from typing import Any, Optional +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from typing import Any from uuid import UUID from galileo.schema.handlers import Node +from galileo.utils.serialization import EventSerializer from galileo.utils.uuid_utils import convert_uuid_if_uuid7 +_logger = logging.getLogger(__name__) + -def get_agent_name(parent_run_id: Optional[UUID], node_name: str, nodes: dict[str, Node]) -> str: +def get_agent_name(parent_run_id: UUID | None, node_name: str, nodes: dict[str, Node]) -> str: if parent_run_id is not None: # Convert UUID7 to UUID4 if needed parent_run_id = convert_uuid_if_uuid7(parent_run_id) or parent_run_id @@ -19,7 +27,7 @@ def is_agent_node(node_name: str) -> bool: return node_name.lower() in ["langgraph", "agent"] -def update_root_to_agent(parent_run_id: Optional[UUID], metadata: dict[str, Any], parent_node: Optional[Node]) -> None: +def update_root_to_agent(parent_run_id: UUID | None, metadata: dict[str, Any], parent_node: Node | None) -> None: """Update the parent node to be an agent if it is a root-level chain and has LangGraph metadata in the children. Parameters @@ -34,3 +42,50 @@ def update_root_to_agent(parent_run_id: Optional[UUID], metadata: dict[str, Any] # Only convert to agent if parent is a root-level chain (no parent of its own) if parent_node and parent_node.node_type == "chain" and parent_node.parent_run_id is None: parent_node.node_type = "agent" + + +@dataclass +class LLMEndResult: + output: Any + num_input_tokens: int | None + num_output_tokens: int | None + total_tokens: int | None + + +def parse_llm_result(response: Any) -> LLMEndResult: + """Extract serialized output and token metrics from a LangChain LLMResult. + + Handles three token-usage sources (checked in order): + 1. ``response.llm_output["token_usage"]`` with OpenAI keys (``prompt_tokens`` / ``completion_tokens``). + 2. Same dict with GCP Vertex AI keys (``input_tokens`` / ``output_tokens``). + 3. ``ChatGeneration.message.usage_metadata`` when ``llm_output`` carries no usage. + + Parameters + ---------- + response + A ``langchain_core.outputs.LLMResult`` (typed as ``Any`` to avoid import). + """ + token_usage: dict[str, Any] = response.llm_output.get("token_usage", {}) if response.llm_output else {} + + try: + flattened_messages = [message for batch in response.generations for message in batch] + first_message = flattened_messages[0] if flattened_messages else None + if first_message is None: + # Empty generations - fall back to stringified representation + output = str(response.generations) + else: + output = json.loads(json.dumps(first_message, cls=EventSerializer)) + if not token_usage and hasattr(first_message, "message"): + message_token_usage = getattr(getattr(first_message, "message", {}), "usage_metadata", None) + if message_token_usage: + token_usage = {**token_usage, **message_token_usage} + except Exception as e: + _logger.warning(f"Failed to serialize LLM output: {e}") + output = str(response.generations) + + return LLMEndResult( + output=output, + num_input_tokens=token_usage.get("prompt_tokens") or token_usage.get("input_tokens"), + num_output_tokens=token_usage.get("completion_tokens") or token_usage.get("output_tokens"), + total_tokens=token_usage.get("total_tokens"), + ) diff --git a/src/galileo/utils/dependencies.py b/src/galileo/utils/dependencies.py index bbc79970..36f178b8 100644 --- a/src/galileo/utils/dependencies.py +++ b/src/galileo/utils/dependencies.py @@ -7,3 +7,4 @@ def is_dependency_available(name: str) -> bool: is_langchain_available = is_dependency_available("langchain_core") is_langgraph_available = is_dependency_available("langgraph") +is_proto_plus_available = is_dependency_available("proto") diff --git a/src/galileo/utils/serialization.py b/src/galileo/utils/serialization.py index e6eef962..6828fc76 100644 --- a/src/galileo/utils/serialization.py +++ b/src/galileo/utils/serialization.py @@ -13,7 +13,7 @@ from pydantic import BaseModel -from galileo.utils.dependencies import is_langchain_available, is_langgraph_available +from galileo.utils.dependencies import is_langchain_available, is_langgraph_available, is_proto_plus_available _logger = logging.getLogger(__name__) @@ -237,11 +237,23 @@ def default(self, obj: Any) -> Any: if isinstance(obj, list): return [self.default(item) for item in obj] - # Important: this needs to be always checked after str and bytes types - # Useful for serializing protobuf messages + # Important: this needs to be always checked after str and bytes types. + # Bare protobuf messages (google.protobuf.message.Message) implement + # Sequence, so this branch handles them as iterables. Proto-plus + # messages (proto.Message) are handled separately below. if isinstance(obj, Sequence): return [self.default(item) for item in obj] + # Handle proto-plus messages (e.g. google.cloud.aiplatform types). + # Their __dict__ only contains private attrs, so the generic + # __dict__ serialization below would produce empty objects. + if is_proto_plus_available: + # Lazy import: proto-plus is an optional dependency used by GCP tool classes (e.g. google.cloud.aiplatform) + import proto + + if isinstance(obj, proto.Message): + return self.default(proto.Message.to_dict(obj, use_integers_for_enums=False)) + if hasattr(obj, "__slots__") and len(obj.__slots__) > 0: return self.default({slot: getattr(obj, slot, None) for slot in obj.__slots__}) diff --git a/tests/test_langchain.py b/tests/test_langchain.py index 54955084..31ebac2f 100644 --- a/tests/test_langchain.py +++ b/tests/test_langchain.py @@ -12,7 +12,7 @@ from galileo import Message, MessageRole, galileo_context from galileo.handlers.langchain import GalileoCallback -from galileo.handlers.langchain.utils import update_root_to_agent +from galileo.handlers.langchain.utils import parse_llm_result, update_root_to_agent from galileo.logger.logger import GalileoLogger from galileo.schema.handlers import Node from galileo.utils.uuid_utils import uuid7_to_uuid4 @@ -1160,6 +1160,142 @@ def test_updates_with_single_langgraph_key(self) -> None: assert parent_node.node_type == "agent" +class TestParseLlmResult: + """Tests for the parse_llm_result utility (GCP Vertex AI token metrics support).""" + + def test_openai_style_token_usage(self) -> None: + """Test standard OpenAI token keys in llm_output.""" + # Given: an LLMResult with OpenAI-style token_usage + response = LLMResult( + generations=[[ChatGeneration(message=AIMessage(content="hello"))]], + llm_output={"token_usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}}, + ) + + # When: parsing the result + result = parse_llm_result(response) + + # Then: standard keys are used + assert result.num_input_tokens == 10 + assert result.num_output_tokens == 20 + assert result.total_tokens == 30 + + def test_gcp_style_token_usage(self) -> None: + """Test GCP Vertex AI token keys (input_tokens/output_tokens) in llm_output.""" + # Given: an LLMResult with GCP-style token keys + response = LLMResult( + generations=[[ChatGeneration(message=AIMessage(content="hello"))]], + llm_output={"token_usage": {"input_tokens": 15, "output_tokens": 25, "total_tokens": 40}}, + ) + + # When: parsing the result + result = parse_llm_result(response) + + # Then: GCP keys are used as fallback + assert result.num_input_tokens == 15 + assert result.num_output_tokens == 25 + assert result.total_tokens == 40 + + def test_openai_keys_take_precedence_over_gcp_keys(self) -> None: + """Test that prompt_tokens/completion_tokens take precedence when both key styles exist.""" + # Given: an LLMResult with both OpenAI and GCP token keys + response = LLMResult( + generations=[[ChatGeneration(message=AIMessage(content="hello"))]], + llm_output={ + "token_usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "input_tokens": 99, + "output_tokens": 99, + "total_tokens": 30, + } + }, + ) + + # When: parsing the result + result = parse_llm_result(response) + + # Then: OpenAI keys take precedence + assert result.num_input_tokens == 10 + assert result.num_output_tokens == 20 + + def test_token_usage_from_message_usage_metadata(self) -> None: + """Test fallback to message.usage_metadata when llm_output has no token_usage.""" + # Given: an LLMResult with no llm_output but usage_metadata on the message + ai_message = AIMessage(content="hello") + ai_message.usage_metadata = {"input_tokens": 5, "output_tokens": 12, "total_tokens": 17} + response = LLMResult(generations=[[ChatGeneration(message=ai_message)]], llm_output=None) + + # When: parsing the result + + result = parse_llm_result(response) + + # Then: usage_metadata values are extracted + assert result.num_input_tokens == 5 + assert result.num_output_tokens == 12 + assert result.total_tokens == 17 + + def test_no_token_usage_anywhere(self) -> None: + """Test that None is returned when no token usage is available.""" + # Given: an LLMResult with no token information + response = LLMResult(generations=[[ChatGeneration(message=AIMessage(content="hello"))]], llm_output=None) + + # When: parsing the result + result = parse_llm_result(response) + + # Then: all token fields are None + assert result.num_input_tokens is None + assert result.num_output_tokens is None + assert result.total_tokens is None + assert result.output is not None + + def test_empty_generations(self) -> None: + """Test graceful handling when generations list is empty.""" + # Given: an LLMResult with empty generations + response = LLMResult( + generations=[[]], + llm_output={"token_usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}}, + ) + + # When: parsing the result + result = parse_llm_result(response) + + # Then: output is stringified fallback (never None), tokens are still extracted + assert result.output == "[[]]" + assert result.num_input_tokens == 10 + assert result.num_output_tokens == 20 + + def test_completely_empty_generations(self) -> None: + """Test graceful handling when generations is completely empty (no batches).""" + # Given: an LLMResult with no generation batches at all + response = LLMResult( + generations=[], + llm_output={"token_usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}}, + ) + + # When: parsing the result + result = parse_llm_result(response) + + # Then: output is stringified fallback (never None), tokens are still extracted + assert result.output == "[]" + assert result.num_input_tokens == 5 + assert result.num_output_tokens == 10 + + def test_llm_output_empty_token_usage_with_usage_metadata(self) -> None: + """Test fallback to usage_metadata when llm_output exists but token_usage is empty.""" + # Given: an LLMResult with empty token_usage in llm_output + ai_message = AIMessage(content="response") + ai_message.usage_metadata = {"input_tokens": 8, "output_tokens": 16, "total_tokens": 24} + response = LLMResult(generations=[[ChatGeneration(message=ai_message)]], llm_output={"token_usage": {}}) + + # When: parsing the result + result = parse_llm_result(response) + + # Then: usage_metadata values are extracted as fallback + assert result.num_input_tokens == 8 + assert result.num_output_tokens == 16 + assert result.total_tokens == 24 + + class TestGalileoCallbackIngestionHookWithoutCredentials: """SC-54690: GalileoCallback/GalileoAsyncCallback with ingestion_hook should work without API credentials. diff --git a/tests/utils/test_serialization.py b/tests/utils/test_serialization.py index b9b19705..21e75465 100644 --- a/tests/utils/test_serialization.py +++ b/tests/utils/test_serialization.py @@ -723,3 +723,63 @@ class ModelWithoutSchema(BaseModel): # Should return class name when schema method is not callable assert result == "" + + +try: + import proto as _proto + + class _ProtoAuthor(_proto.Message): + name = _proto.Field(_proto.STRING, number=1) + + class _ProtoBook(_proto.Message): + title = _proto.Field(_proto.STRING, number=1) + page_count = _proto.Field(_proto.INT32, number=2) + + class _ProtoBookWithAuthor(_proto.Message): + title = _proto.Field(_proto.STRING, number=1) + author = _proto.Field(_ProtoAuthor, number=2) + + class _ProtoEmpty(_proto.Message): + name = _proto.Field(_proto.STRING, number=1) + + _has_proto = True +except ImportError: + _has_proto = False + + +@pytest.mark.skipif(not _has_proto, reason="proto-plus not installed") +class TestProtoMessageSerialization: + """Test serialization of proto-plus messages (e.g. Google Cloud AI Platform types).""" + + def test_proto_plus_message_serialization(self) -> None: + """Test that proto-plus messages are serialized via proto.Message.to_dict.""" + # Given: a proto-plus message with fields set + book = _ProtoBook(title="Great Expectations", page_count=432) + + # When: serializing the proto-plus message + result = json.loads(json.dumps(book, cls=EventSerializer)) + + # Then: the message fields are properly serialized + assert result == {"title": "Great Expectations", "page_count": 432} + + def test_proto_plus_empty_message_serialization(self) -> None: + """Test that proto-plus messages with unset fields produce protobuf defaults, not empty dicts.""" + # Given: a proto-plus message with no fields set + msg = _ProtoEmpty() + + # When: serializing the empty message + result = json.loads(json.dumps(msg, cls=EventSerializer)) + + # Then: unset fields use their protobuf defaults (empty string for STRING) + assert result == {"name": ""} + + def test_proto_plus_nested_message_serialization(self) -> None: + """Test that nested proto-plus messages are serialized recursively.""" + # Given: a proto-plus message with a nested message + book = _ProtoBookWithAuthor(title="Great Expectations", author=_ProtoAuthor(name="Dickens")) + + # When: serializing the nested message + result = json.loads(json.dumps(book, cls=EventSerializer)) + + # Then: nested message is properly serialized + assert result == {"title": "Great Expectations", "author": {"name": "Dickens"}}