Skip to content
45 changes: 19 additions & 26 deletions src/galileo/handlers/langchain/async_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from galileo.handlers.base_async_handler import GalileoAsyncBaseHandler
from galileo.handlers.langchain.handler import GalileoCallback
from galileo.handlers.langchain.utils import get_agent_name, is_agent_node, update_root_to_agent
from galileo.handlers.langchain.utils import get_agent_name, is_agent_node, parse_llm_result, update_root_to_agent
from galileo.logger import GalileoLogger
from galileo.schema.handlers import NODE_TYPE
from galileo.schema.trace import TracesIngestRequest
Expand Down Expand Up @@ -102,13 +102,13 @@ async def on_chain_end(
# The input is sent via kwargs in on_chain_end in async streaming mode
if "inputs" in kwargs:
kwargs["input"] = serialize_to_str(kwargs["inputs"])
await self._handler.async_end_node(run_id, output=serialize_to_str(outputs), **kwargs)
await self._handler.async_end_node(run_id, output=serialize_to_str(outputs), status_code=200, **kwargs)

async def on_agent_finish(self, finish: AgentFinish, *, run_id: UUID, **kwargs: Any) -> Any:
    """Langchain callback when an agent finishes.

    Ends the node for this run with the serialized agent result and a
    success status code.
    """
    # Convert UUID7 to UUID4 if needed
    run_id = convert_uuid_if_uuid7(run_id) or run_id
    await self._handler.async_end_node(run_id, output=serialize_to_str(finish), status_code=200)

async def on_llm_start(
self,
Expand Down Expand Up @@ -216,21 +216,14 @@ async def on_llm_end(
# Convert UUID7 to UUID4 if needed
run_id = convert_uuid_if_uuid7(run_id) or run_id

token_usage = response.llm_output.get("token_usage", {}) if response.llm_output else {}

try:
flattened_messages = [message for batch in response.generations for message in batch]
output = json.loads(json.dumps(flattened_messages[0], cls=EventSerializer))
except Exception as e:
_logger.warning(f"Failed to serialize LLM output: {e}")
output = str(response.generations)

result = parse_llm_result(response)
await self._handler.async_end_node(
run_id,
output=output,
num_input_tokens=token_usage.get("prompt_tokens"),
num_output_tokens=token_usage.get("completion_tokens"),
total_tokens=token_usage.get("total_tokens"),
output=result.output,
num_input_tokens=result.num_input_tokens,
num_output_tokens=result.num_output_tokens,
total_tokens=result.total_tokens,
status_code=200,
)

async def on_tool_start(
Expand Down Expand Up @@ -270,7 +263,7 @@ async def on_tool_end(
# Convert UUID7 to UUID4 if needed
run_id = convert_uuid_if_uuid7(run_id) or run_id

end_node_kwargs = {}
end_node_kwargs: dict[str, Any] = {"status_code": 200}
if (tool_message := GalileoCallback._find_tool_message(output)) is not None:
end_node_kwargs["output"] = tool_message.content
end_node_kwargs["tool_call_id"] = tool_message.tool_call_id
Expand Down Expand Up @@ -330,36 +323,36 @@ async def on_retriever_end(
_logger.warning(f"Failed to serialize retriever output: {e}")
serialized_response = str(documents)

await self._handler.async_end_node(run_id, output=serialized_response)
await self._handler.async_end_node(run_id, output=serialized_response, status_code=200)

async def on_chain_error(
    self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
) -> Any:
    """Langchain callback when a chain errors.

    Ends the node for this run with the stringified error and a 400
    status code so the failure is visible in the trace.
    """
    # Convert UUID7 to UUID4 if needed
    run_id = convert_uuid_if_uuid7(run_id) or run_id
    await self._handler.async_end_node(run_id, output=f"Error: {error!s}", status_code=400)

async def on_llm_error(
    self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
) -> Any:
    """Langchain callback when an LLM errors.

    Ends the node for this run with the stringified error and a 400
    status code so the failure is visible in the trace.
    """
    # Convert UUID7 to UUID4 if needed
    run_id = convert_uuid_if_uuid7(run_id) or run_id
    await self._handler.async_end_node(run_id, output=f"Error: {error!s}", status_code=400)

async def on_tool_error(
    self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
) -> Any:
    """Langchain callback when a tool errors.

    Ends the node for this run with the stringified error and a 400
    status code so the failure is visible in the trace.
    """
    # Convert UUID7 to UUID4 if needed
    run_id = convert_uuid_if_uuid7(run_id) or run_id
    await self._handler.async_end_node(run_id, output=f"Error: {error!s}", status_code=400)

async def on_retriever_error(
    self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
) -> Any:
    """Langchain callback when a retriever errors.

    Ends the node for this run with the stringified error and a 400
    status code so the failure is visible in the trace.
    """
    # Convert UUID7 to UUID4 if needed
    run_id = convert_uuid_if_uuid7(run_id) or run_id
    await self._handler.async_end_node(run_id, output=f"Error: {error!s}", status_code=400)
61 changes: 43 additions & 18 deletions src/galileo/handlers/langchain/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from uuid import UUID

from galileo.handlers.base_handler import GalileoBaseHandler
from galileo.handlers.langchain.utils import get_agent_name, is_agent_node, update_root_to_agent
from galileo.handlers.langchain.utils import get_agent_name, is_agent_node, parse_llm_result, update_root_to_agent
from galileo.logger import GalileoLogger
from galileo.schema.handlers import LANGCHAIN_NODE_TYPE, NODE_TYPE
from galileo.schema.trace import TracesIngestRequest
Expand Down Expand Up @@ -123,13 +123,13 @@ def on_chain_end(
"""Langchain callback when a chain ends."""
# Convert UUID7 to UUID4 if needed
run_id = convert_uuid_if_uuid7(run_id) or run_id
self._handler.end_node(run_id, output=serialize_to_str(outputs))
self._handler.end_node(run_id, output=serialize_to_str(outputs), status_code=200)

def on_agent_finish(self, finish: AgentFinish, *, run_id: UUID, **kwargs: Any) -> Any:
    """Langchain callback when an agent finishes.

    Ends the node for this run with the serialized agent result and a
    success status code.
    """
    # Convert UUID7 to UUID4 if needed
    run_id = convert_uuid_if_uuid7(run_id) or run_id
    self._handler.end_node(run_id, output=serialize_to_str(finish), status_code=200)

def on_llm_start(
self,
Expand Down Expand Up @@ -237,21 +237,14 @@ def on_llm_end(
# Convert UUID7 to UUID4 if needed
run_id = convert_uuid_if_uuid7(run_id) or run_id

token_usage = response.llm_output.get("token_usage", {}) if response.llm_output else {}

try:
flattened_messages = [message for batch in response.generations for message in batch]
output = json.loads(json.dumps(flattened_messages[0], cls=EventSerializer))
except Exception as e:
_logger.warning(f"Failed to serialize LLM output: {e}")
output = str(response.generations)

result = parse_llm_result(response)
self._handler.end_node(
run_id,
output=output,
num_input_tokens=token_usage.get("prompt_tokens"),
num_output_tokens=token_usage.get("completion_tokens"),
total_tokens=token_usage.get("total_tokens"),
output=result.output,
num_input_tokens=result.num_input_tokens,
num_output_tokens=result.num_output_tokens,
total_tokens=result.total_tokens,
status_code=200,
)

def on_tool_start(
Expand Down Expand Up @@ -320,7 +313,7 @@ def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: Optional[UUID
# Convert UUID7 to UUID4 if needed
run_id = convert_uuid_if_uuid7(run_id) or run_id

end_node_kwargs = {}
end_node_kwargs: dict[str, Any] = {"status_code": 200}
if (tool_message := self._find_tool_message(output)) is not None:
end_node_kwargs["output"] = tool_message.content
end_node_kwargs["tool_call_id"] = tool_message.tool_call_id
Expand Down Expand Up @@ -380,7 +373,7 @@ def on_retriever_end(
_logger.warning(f"Failed to serialize retriever output: {e}")
serialized_response = str(documents)

self._handler.end_node(run_id, output=serialized_response)
self._handler.end_node(run_id, output=serialized_response, status_code=200)

def _get_agent_name(self, parent_run_id: Optional[UUID], node_name: str) -> str:
if parent_run_id is not None:
Expand All @@ -390,3 +383,35 @@ def _get_agent_name(self, parent_run_id: Optional[UUID], node_name: str) -> str:
if parent:
return parent.span_params["name"] + ":" + node_name
return node_name

def on_chain_error(
    self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
) -> Any:
    """Langchain callback when a chain errors."""
    # Normalize UUID7 identifiers to UUID4 before closing out the node.
    normalized_id = convert_uuid_if_uuid7(run_id) or run_id
    self._handler.end_node(normalized_id, output=f"Error: {error!s}", status_code=400)

def on_llm_error(
    self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
) -> Any:
    """Langchain callback when an LLM errors."""
    # Normalize UUID7 identifiers to UUID4 before closing out the node.
    normalized_id = convert_uuid_if_uuid7(run_id) or run_id
    self._handler.end_node(normalized_id, output=f"Error: {error!s}", status_code=400)

def on_tool_error(
    self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
) -> Any:
    """Langchain callback when a tool errors."""
    # Normalize UUID7 identifiers to UUID4 before closing out the node.
    normalized_id = convert_uuid_if_uuid7(run_id) or run_id
    self._handler.end_node(normalized_id, output=f"Error: {error!s}", status_code=400)

def on_retriever_error(
    self, error: Exception, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
) -> Any:
    """Langchain callback when a retriever errors."""
    # Normalize UUID7 identifiers to UUID4 before closing out the node.
    normalized_id = convert_uuid_if_uuid7(run_id) or run_id
    self._handler.end_node(normalized_id, output=f"Error: {error!s}", status_code=400)
61 changes: 58 additions & 3 deletions src/galileo/handlers/langchain/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from typing import Any, Optional
from __future__ import annotations

import json
import logging
from dataclasses import dataclass
from typing import Any
from uuid import UUID

from galileo.schema.handlers import Node
from galileo.utils.serialization import EventSerializer
from galileo.utils.uuid_utils import convert_uuid_if_uuid7

_logger = logging.getLogger(__name__)


def get_agent_name(parent_run_id: Optional[UUID], node_name: str, nodes: dict[str, Node]) -> str:
def get_agent_name(parent_run_id: UUID | None, node_name: str, nodes: dict[str, Node]) -> str:
if parent_run_id is not None:
# Convert UUID7 to UUID4 if needed
parent_run_id = convert_uuid_if_uuid7(parent_run_id) or parent_run_id
Expand All @@ -19,7 +27,7 @@ def is_agent_node(node_name: str) -> bool:
return node_name.lower() in ["langgraph", "agent"]


def update_root_to_agent(parent_run_id: Optional[UUID], metadata: dict[str, Any], parent_node: Optional[Node]) -> None:
def update_root_to_agent(parent_run_id: UUID | None, metadata: dict[str, Any], parent_node: Node | None) -> None:
"""Update the parent node to be an agent if it is a root-level chain and has LangGraph metadata in the children.

Parameters
Expand All @@ -34,3 +42,50 @@ def update_root_to_agent(parent_run_id: Optional[UUID], metadata: dict[str, Any]
# Only convert to agent if parent is a root-level chain (no parent of its own)
if parent_node and parent_node.node_type == "chain" and parent_node.parent_run_id is None:
parent_node.node_type = "agent"


@dataclass
class LLMEndResult:
    """Parsed LLM output and token-usage metrics, as returned by ``parse_llm_result``."""

    # Serialized first generation, or a stringified fallback when
    # serialization fails or the generations list is empty.
    output: Any
    # Prompt-side token count; None when the provider reported no usage.
    num_input_tokens: int | None
    # Completion-side token count; None when the provider reported no usage.
    num_output_tokens: int | None
    # Total token count; None when the provider reported no usage.
    total_tokens: int | None


def parse_llm_result(response: Any) -> LLMEndResult:
    """Extract serialized output and token metrics from a LangChain LLMResult.

    Handles three token-usage sources (checked in order):
    1. ``response.llm_output["token_usage"]`` with OpenAI keys (``prompt_tokens`` / ``completion_tokens``).
    2. Same dict with GCP Vertex AI keys (``input_tokens`` / ``output_tokens``).
    3. ``ChatGeneration.message.usage_metadata`` when ``llm_output`` carries no usage.

    Parameters
    ----------
    response
        A ``langchain_core.outputs.LLMResult`` (typed as ``Any`` to avoid import).

    Returns
    -------
    LLMEndResult
        Serialized first generation (or a stringified fallback) plus any
        token counts that could be extracted.
    """
    token_usage: dict[str, Any] = response.llm_output.get("token_usage", {}) if response.llm_output else {}

    try:
        flattened_messages = [message for batch in response.generations for message in batch]
        if not flattened_messages:
            # Empty generations - fall back to stringified representation
            output = str(response.generations)
        else:
            first_message = flattened_messages[0]
            output = json.loads(json.dumps(first_message, cls=EventSerializer))
            if not token_usage:
                # llm_output carried no usage; fall back to the per-message
                # usage_metadata attached by chat models. getattr with a None
                # default covers non-chat generations that have no .message.
                message_token_usage = getattr(getattr(first_message, "message", None), "usage_metadata", None)
                if message_token_usage:
                    # token_usage is empty here, so a plain copy suffices.
                    token_usage = dict(message_token_usage)
    except Exception as e:
        _logger.warning(f"Failed to serialize LLM output: {e}")
        output = str(response.generations)

    return LLMEndResult(
        output=output,
        num_input_tokens=token_usage.get("prompt_tokens") or token_usage.get("input_tokens"),
        num_output_tokens=token_usage.get("completion_tokens") or token_usage.get("output_tokens"),
        total_tokens=token_usage.get("total_tokens"),
    )
1 change: 1 addition & 0 deletions src/galileo/utils/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ def is_dependency_available(name: str) -> bool:

is_langchain_available = is_dependency_available("langchain_core")
is_langgraph_available = is_dependency_available("langgraph")
is_proto_plus_available = is_dependency_available("proto")
18 changes: 15 additions & 3 deletions src/galileo/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from pydantic import BaseModel

from galileo.utils.dependencies import is_langchain_available, is_langgraph_available
from galileo.utils.dependencies import is_langchain_available, is_langgraph_available, is_proto_plus_available

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -237,11 +237,23 @@ def default(self, obj: Any) -> Any:
if isinstance(obj, list):
return [self.default(item) for item in obj]

# Important: this needs to be always checked after str and bytes types
# Useful for serializing protobuf messages
# Important: this needs to be always checked after str and bytes types.
# Bare protobuf messages (google.protobuf.message.Message) implement
# Sequence, so this branch handles them as iterables. Proto-plus
# messages (proto.Message) are handled separately below.
if isinstance(obj, Sequence):
return [self.default(item) for item in obj]

# Handle proto-plus messages (e.g. google.cloud.aiplatform types).
# Their __dict__ only contains private attrs, so the generic
# __dict__ serialization below would produce empty objects.
if is_proto_plus_available:
# Lazy import: proto-plus is an optional dependency used by GCP tool classes (e.g. google.cloud.aiplatform)
import proto

if isinstance(obj, proto.Message):
return self.default(proto.Message.to_dict(obj, use_integers_for_enums=False))

if hasattr(obj, "__slots__") and len(obj.__slots__) > 0:
return self.default({slot: getattr(obj, slot, None) for slot in obj.__slots__})

Expand Down
Loading
Loading