diff --git a/tests/agent/guardrails/actions/test_block_action.py b/tests/agent/guardrails/actions/test_block_action.py index d6b3f1e2..988fc964 100644 --- a/tests/agent/guardrails/actions/test_block_action.py +++ b/tests/agent/guardrails/actions/test_block_action.py @@ -35,3 +35,69 @@ async def test_node_name_and_exception_pre_llm(self): # The exception string is the provided reason assert str(excinfo.value) == "Sensitive data detected" + + @pytest.mark.asyncio + async def test_node_name_and_exception_post_llm(self): + """PostExecution + LLM: name is sanitized and node raises correct exception.""" + action = BlockAction(reason="Invalid output detected") + guardrail = MagicMock() + guardrail.name = "Output Guardrail v2" + + node_name, node = action.action_node( + guardrail=guardrail, + scope=GuardrailScope.LLM, + execution_stage=ExecutionStage.POST_EXECUTION, + guarded_component_name="guarded_node_name", + ) + + assert node_name == "llm_post_execution_output_guardrail_v2_block" + + with pytest.raises(AgentTerminationException) as excinfo: + await node(AgentGuardrailsGraphState(messages=[])) + + # The exception string is the provided reason + assert str(excinfo.value) == "Invalid output detected" + + @pytest.mark.asyncio + async def test_node_name_and_exception_pre_tool(self): + """PreExecution + TOOL: name is sanitized and node raises correct exception.""" + action = BlockAction(reason="Tool input validation failed") + guardrail = MagicMock() + guardrail.name = "Tool-Safety@Check" + + node_name, node = action.action_node( + guardrail=guardrail, + scope=GuardrailScope.TOOL, + execution_stage=ExecutionStage.PRE_EXECUTION, + guarded_component_name="test_tool", + ) + + assert node_name == "tool_pre_execution_tool_safety_check_block" + + with pytest.raises(AgentTerminationException) as excinfo: + await node(AgentGuardrailsGraphState(messages=[])) + + # The exception string is the provided reason + assert str(excinfo.value) == "Tool input validation failed" + + @pytest.mark.asyncio + async def test_node_name_and_exception_post_tool(self): + """PostExecution + TOOL: name is sanitized and node raises correct exception.""" + action = BlockAction(reason="Tool output validation failed") + guardrail = MagicMock() + guardrail.name = "Output-Validator#2024" + + node_name, node = action.action_node( + guardrail=guardrail, + scope=GuardrailScope.TOOL, + execution_stage=ExecutionStage.POST_EXECUTION, + guarded_component_name="test_tool", + ) + + assert node_name == "tool_post_execution_output_validator_2024_block" + + with pytest.raises(AgentTerminationException) as excinfo: + await node(AgentGuardrailsGraphState(messages=[])) + + # The exception string is the provided reason + assert str(excinfo.value) == "Tool output validation failed" diff --git a/tests/agent/guardrails/actions/test_log_action.py b/tests/agent/guardrails/actions/test_log_action.py index cf172ff3..48117a69 100644 --- a/tests/agent/guardrails/actions/test_log_action.py +++ b/tests/agent/guardrails/actions/test_log_action.py @@ -78,3 +78,188 @@ async def test_default_message_includes_context( == "Guardrail [My Guardrail] validation failed for [TOOL] [POST_EXECUTION] with the following reason: bad input" for rec in caplog.records ) + + @pytest.mark.asyncio + async def test_node_name_and_exception_post_llm( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """PostExecution + LLM: name is sanitized and default message is logged.""" + action = LogAction(message=None, level=logging.INFO) + guardrail = MagicMock() + guardrail.name = "Test Guardrail" + + node_name, node = action.action_node( + guardrail=guardrail, + scope=GuardrailScope.LLM, + execution_stage=ExecutionStage.POST_EXECUTION, + guarded_component_name="guarded_node_name", + ) + + # Verify node name format + assert node_name == "llm_post_execution_test_guardrail_log" + assert isinstance(node_name, str) + assert node_name.endswith("_log") + assert "llm" in node_name + assert "post_execution" in node_name + + # Verify node is callable + assert callable(node) + + # Verify node returns empty dict + with caplog.at_level(logging.INFO): + await node( + AgentGuardrailsGraphState( + messages=[], guardrail_validation_result="validation error" + ) + ) + + # Verify log record properties + log_record = caplog.records[0] + assert log_record.levelno == logging.INFO + + # Verify default message includes all context + assert ( + "Guardrail [Test Guardrail] validation failed for [LLM] [POST_EXECUTION]" + in log_record.message + ) + assert "validation error" in log_record.message + assert ( + log_record.message + == "Guardrail [Test Guardrail] validation failed for [LLM] [POST_EXECUTION] with the following reason: validation error" + ) + + @pytest.mark.asyncio + async def test_node_name_and_exception_pre_tool( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """PreExecution + TOOL: name is sanitized and default message is logged.""" + action = LogAction(message=None, level=logging.WARNING) + guardrail = MagicMock() + guardrail.name = "Tool Guardrail v2" + + node_name, node = action.action_node( + guardrail=guardrail, + scope=GuardrailScope.TOOL, + execution_stage=ExecutionStage.PRE_EXECUTION, + guarded_component_name="test_tool", + ) + + # Verify node name format + assert node_name == "tool_pre_execution_tool_guardrail_v2_log" + assert isinstance(node_name, str) + assert node_name.endswith("_log") + assert "tool" in node_name + assert "pre_execution" in node_name + + # Verify node returns empty dict + with caplog.at_level(logging.WARNING): + await node( + AgentGuardrailsGraphState( + messages=[], guardrail_validation_result="invalid tool args" + ) + ) + + # Verify log record properties + log_record = caplog.records[0] + assert log_record.levelno == logging.WARNING + + # Verify default message includes all context + assert ( + "Guardrail [Tool Guardrail v2] validation failed for [TOOL] [PRE_EXECUTION]" + in log_record.message + ) + assert "invalid tool args" in log_record.message + assert ( + log_record.message + == "Guardrail [Tool Guardrail v2] validation failed for [TOOL] [PRE_EXECUTION] with the following reason: invalid tool args" + ) + + @pytest.mark.asyncio + async def test_node_name_and_exception_post_tool( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """PostExecution + TOOL: name is sanitized and custom message is logged.""" + action = LogAction(message="Tool execution failed", level=logging.ERROR) + guardrail = MagicMock() + guardrail.name = "Special-Tool@Guardrail" + + node_name, node = action.action_node( + guardrail=guardrail, + scope=GuardrailScope.TOOL, + execution_stage=ExecutionStage.POST_EXECUTION, + guarded_component_name="test_tool", + ) + + # Verify node name format (special characters are sanitized) + assert node_name == "tool_post_execution_special_tool_guardrail_log" + assert isinstance(node_name, str) + assert node_name.endswith("_log") + assert "tool" in node_name + assert "post_execution" in node_name + + # Verify node returns empty dict + with caplog.at_level(logging.ERROR): + await node( + AgentGuardrailsGraphState( + messages=[], guardrail_validation_result="tool error" + ) + ) + + # Verify log record properties + log_record = caplog.records[0] + assert log_record.levelno == logging.ERROR + + # Verify custom message was logged (not default message) + assert log_record.message == "Tool execution failed" + assert ( + "Guardrail" not in log_record.message + ) # Custom message doesn't include guardrail context + assert ( + "validation failed" not in log_record.message + ) # Custom message doesn't include default format + + @pytest.mark.asyncio + async def test_node_name_and_exception_post_tool_warning( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """PostExecution + TOOL: name is sanitized and default message is logged at WARNING level.""" + action = LogAction(message=None, level=logging.WARNING) + guardrail = MagicMock() + guardrail.name = "Post Tool Guardrail" + + node_name, node = action.action_node( + guardrail=guardrail, + scope=GuardrailScope.TOOL, + execution_stage=ExecutionStage.POST_EXECUTION, + guarded_component_name="test_tool", + ) + + # Verify node name format + assert node_name == "tool_post_execution_post_tool_guardrail_log" + assert isinstance(node_name, str) + assert node_name.endswith("_log") + assert "tool" in node_name + assert "post_execution" in node_name + + # Verify node returns empty dict + with caplog.at_level(logging.WARNING): + await node( + AgentGuardrailsGraphState( + messages=[], guardrail_validation_result="post execution error" + ) + ) + + # Verify log record properties + log_record = caplog.records[0] + assert log_record.levelno == logging.WARNING + + # Verify default message includes all context + assert ( + "Guardrail [Post Tool Guardrail] validation failed for [TOOL] [POST_EXECUTION]" + in log_record.message + ) + assert "post execution error" in log_record.message + assert ( + log_record.message + == "Guardrail [Post Tool Guardrail] validation failed for [TOOL] [POST_EXECUTION] with the following reason: post execution error" + )