Skip to content

Commit 13ca3a3

Browse files
committed
Merge branch 'develop_data_process_tool' of https://github.com/ModelEngine-Group/nexent into develop_multi_agent
# Conflicts: # sdk/nexent/core/utils/prompt_template_utils.py # sdk/nexent/core/utils/tools_common_message.py # test/backend/agents/test_create_agent_info.py # test/backend/services/test_tool_configuration_service.py # test/common/__init__.py # test/common/env_test_utils.py # test/sdk/core/utils/test_prompt_template_utils.py
2 parents 7766713 + bb8bbdb commit 13ca3a3

File tree

12 files changed

+823
-101
lines changed

12 files changed

+823
-101
lines changed

sdk/nexent/core/tools/analyze_text_file_tool.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,16 @@ class AnalyzeTextFileTool(Tool):
3737
inputs = {
3838
"file_url_list": {
3939
"type": "array",
40-
"description": "List of file URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs. Can also accept a single file URL which will be treated as a list with one element."
40+
"description": "List of file URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs."
4141
},
4242
"query": {
4343
"type": "string",
4444
"description": "User's question to guide the analysis"
4545
}
4646
}
47-
output_type = "string"
48-
category = ToolCategory.FILE.value
49-
tool_sign = ToolSign.FILE_OPERATION.value
47+
output_type = "array"
48+
category = ToolCategory.MULTIMODAL.value
49+
tool_sign = ToolSign.MULTIMODAL_OPERATION.value
5050

5151
def __init__(
5252
self,
@@ -76,30 +76,29 @@ def __init__(
7676
self.data_process_service_url = data_process_service_url
7777
self.mm = LoadSaveObjectManager(storage_client=self.storage_client)
7878

79-
self.running_prompt_zh = "正在分析文本文件..."
80-
self.running_prompt_en = "Analyzing text file..."
79+
self.running_prompt_zh = "正在分析文件..."
80+
self.running_prompt_en = "Analyzing file..."
8181
# Dynamically apply the load_object decorator to forward method
8282
self.forward = self.mm.load_object(input_names=["file_url_list"])(self._forward_impl)
8383

8484
def _forward_impl(
8585
self,
86-
file_url_list: Union[bytes, List[bytes]],
86+
file_url_list: List[bytes],
8787
query: str,
88-
) -> Union[str, List[str]]:
88+
) -> List[str]:
8989
"""
9090
Analyze text file content using a large language model.
9191
9292
Note: This method is wrapped by load_object decorator which downloads
9393
the image from S3 URL, HTTP URL, or HTTPS URL and passes bytes to this method.
9494
9595
Args:
96-
file_url_list: File bytes or a sequence of file bytes (converted from URLs by the decorator).
97-
The load_object decorator converts URLs to bytes before calling this method.
96+
file_url_list: List of file bytes converted from URLs by the decorator.
97+
The load_object decorator converts URLs to bytes before calling this method.
9898
query: User's question to guide the analysis
9999
100100
Returns:
101-
Union[str, List[str]]: Single analysis string for one file or a list
102-
of analysis strings that align with the order of the provided files.
101+
List[str]: One analysis string per file that aligns with the order
103102
"""
104103
# Send tool run message
105104
if self.observer:
@@ -109,19 +108,15 @@ def _forward_impl(
109108
self.observer.add_message("", ProcessType.CARD, json.dumps(card_content, ensure_ascii=False))
110109

111110
if file_url_list is None:
112-
raise ValueError("file_url_list must contain at least one file")
111+
raise ValueError("file_url_list cannot be None")
113112

114-
if isinstance(file_url_list, (list, tuple)):
115-
file_inputs: List[bytes] = list(file_url_list)
116-
elif isinstance(file_url_list, bytes):
117-
file_inputs = [file_url_list]
118-
else:
119-
raise ValueError("file_url_list must be bytes or a list/tuple of bytes")
113+
if not isinstance(file_url_list, list):
114+
raise ValueError("file_url_list must be a list of bytes")
120115

121116
try:
122117
analysis_results: List[str] = []
123118

124-
for index, single_file in enumerate(file_inputs, start=1):
119+
for index, single_file in enumerate(file_url_list, start=1):
125120
logger.info(f"Extracting text content from file #{index}, query: {query}")
126121
filename = f"file_{index}.txt"
127122

@@ -143,8 +138,6 @@ def _forward_impl(
143138
logger.error(f"Failed to analyze file #{index}: {analysis_error}")
144139
analysis_results.append(str(analysis_error))
145140

146-
if len(analysis_results) == 1:
147-
return analysis_results[0]
148141
return analysis_results
149142

150143
except Exception as e:

sdk/nexent/core/utils/prompt_template_utils.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44

55
import yaml
66

7-
8-
logger = logging.getLogger("prompt_template_utils")
9-
107
LANGUAGE = {
118
"ZH": "zh",
129
"EN": "en"
1310
}
1411

12+
logger = logging.getLogger("prompt_template_utils")
13+
1514
# Define template path mapping
1615
template_paths = {
1716
'analyze_image': {
@@ -27,14 +26,12 @@
2726
def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kwargs) -> Dict[str, Any]:
2827
"""
2928
Get prompt template
30-
3129
Args:
3230
template_type: Template type, supports the following values:
3331
- 'analyze_image': Analyze image template
3432
- 'analyze_file': Analyze file template (for text files)
3533
language: Language code ('zh' or 'en')
3634
**kwargs: Additional parameters, for agent type need to pass is_manager parameter
37-
3835
Returns:
3936
dict: Loaded prompt template
4037
"""
@@ -55,4 +52,4 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw
5552

5653
# Read and return template content
5754
with open(absolute_template_path, 'r', encoding='utf-8') as f:
58-
return yaml.safe_load(f)
55+
return yaml.safe_load(f)

sdk/nexent/core/utils/tools_common_message.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class ToolSign(Enum):
1111
TAVILY_SEARCH = "d" # Tavily search tool identifier
1212
FILE_OPERATION = "f" # File operation tool identifier
1313
TERMINAL_OPERATION = "t" # Terminal operation tool identifier
14-
MULTIMODAL_OPERATION = "m" # Multimodal operation tool identifier
14+
MULTIMODAL_OPERATION = "m" # Multimodal operation tool identifier
1515

1616

1717
# Tool sign mapping for backward compatibility

test/backend/agents/test_create_agent_info.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,22 @@
11
import pytest
22
import sys
33
import types
4+
import importlib.util
5+
from pathlib import Path
46
from unittest.mock import AsyncMock, MagicMock, patch, Mock, PropertyMock
57

8+
TEST_ROOT = Path(__file__).resolve().parents[2]
9+
PROJECT_ROOT = TEST_ROOT.parent
10+
11+
# Ensure project backend package is found before test/backend
12+
for _path in (str(PROJECT_ROOT), str(TEST_ROOT)):
13+
if _path not in sys.path:
14+
sys.path.insert(0, _path)
15+
from test.common.env_test_utils import bootstrap_env
16+
17+
env_state = bootstrap_env()
18+
consts_const = env_state["mock_const"]
19+
620
from test.common.env_test_utils import bootstrap_env
721

822
env_state = bootstrap_env()
@@ -17,6 +31,31 @@ def _create_stub_module(name: str, **attrs):
1731
return module
1832

1933

34+
# Configure required constants via shared bootstrap env
35+
consts_const.MINIO_ENDPOINT = "http://localhost:9000"
36+
consts_const.MINIO_ACCESS_KEY = "test_access_key"
37+
consts_const.MINIO_SECRET_KEY = "test_secret_key"
38+
consts_const.MINIO_REGION = "us-east-1"
39+
consts_const.MINIO_DEFAULT_BUCKET = "test-bucket"
40+
consts_const.POSTGRES_HOST = "localhost"
41+
consts_const.POSTGRES_USER = "test_user"
42+
consts_const.NEXENT_POSTGRES_PASSWORD = "test_password"
43+
consts_const.POSTGRES_DB = "test_db"
44+
consts_const.POSTGRES_PORT = 5432
45+
consts_const.DEFAULT_TENANT_ID = "default_tenant"
46+
consts_const.LOCAL_MCP_SERVER = "http://localhost:5011"
47+
consts_const.MODEL_CONFIG_MAPPING = {"llm": "llm_config"}
48+
consts_const.LANGUAGE = {"ZH": "zh"}
49+
consts_const.DATA_PROCESS_SERVICE = "https://example.com/data-process"
50+
# Utilities ---------------------------------------------------------------
51+
def _create_stub_module(name: str, **attrs):
52+
"""Return a lightweight module stub with the provided attributes."""
53+
module = types.ModuleType(name)
54+
for attr_name, attr_value in attrs.items():
55+
setattr(module, attr_name, attr_value)
56+
return module
57+
58+
2059
# Configure required constants via shared bootstrap env
2160
consts_const.MINIO_ENDPOINT = "http://localhost:9000"
2261
consts_const.MINIO_ACCESS_KEY = "test_access_key"
@@ -46,6 +85,7 @@ def _create_stub_module(name: str, **attrs):
4685
# if the testing environment does not have it available.
4786
boto3_mock = MagicMock()
4887
sys.modules['boto3'] = boto3_mock
88+
sys.modules['dotenv'] = MagicMock(load_dotenv=MagicMock())
4989

5090
# Mock the entire client module
5191
client_mock = MagicMock()
@@ -92,6 +132,14 @@ def _create_stub_module(name: str, **attrs):
92132
sys.modules['services.image_service'] = _create_stub_module(
93133
"services.image_service", get_vlm_model=MagicMock(return_value="stub_vlm")
94134
)
135+
sys.modules['services.file_management_service'] = _create_stub_module(
136+
"services.file_management_service",
137+
get_llm_model=MagicMock(return_value="stub_llm_model"),
138+
)
139+
sys.modules['services.tool_configuration_service'] = _create_stub_module(
140+
"services.tool_configuration_service",
141+
initialize_tools_on_startup=AsyncMock(),
142+
)
95143
# Build top-level nexent module to avoid importing the real package
96144
nexent_module = _create_stub_module(
97145
"nexent",
@@ -127,7 +175,31 @@ def _create_stub_module(name: str, **attrs):
127175
sys.modules['smolagents'] = smolagents_module
128176
sys.modules['smolagents.tools'] = smolagents_tools_module
129177

130-
# Now import the module under test
178+
# Ensure real backend.agents.create_agent_info is available and uses our stubs
179+
backend_pkg = sys.modules.get("backend")
180+
if backend_pkg is None:
181+
backend_pkg = types.ModuleType("backend")
182+
backend_pkg.__path__ = [str((TEST_ROOT.parent) / "backend")]
183+
sys.modules["backend"] = backend_pkg
184+
185+
agents_pkg = sys.modules.get("backend.agents")
186+
if agents_pkg is None:
187+
agents_pkg = types.ModuleType("backend.agents")
188+
agents_pkg.__path__ = [str((TEST_ROOT.parent) / "backend" / "agents")]
189+
sys.modules["backend.agents"] = agents_pkg
190+
setattr(backend_pkg, "agents", agents_pkg)
191+
192+
create_agent_info_path = (TEST_ROOT.parent / "backend" / "agents" / "create_agent_info.py")
193+
spec = importlib.util.spec_from_file_location(
194+
"backend.agents.create_agent_info", create_agent_info_path
195+
)
196+
create_agent_info_module = importlib.util.module_from_spec(spec)
197+
sys.modules["backend.agents.create_agent_info"] = create_agent_info_module
198+
assert spec.loader is not None
199+
spec.loader.exec_module(create_agent_info_module)
200+
setattr(agents_pkg, "create_agent_info", create_agent_info_module)
201+
202+
# Now import the symbols under test
131203
from backend.agents.create_agent_info import (
132204
discover_langchain_tools,
133205
create_tool_config_list,
@@ -324,6 +396,43 @@ async def test_create_tool_config_list_with_knowledge_base_tool(self):
324396
last_call = mock_tool_config.call_args_list[-1]
325397
assert last_call[1]['class_name'] == "KnowledgeBaseSearchTool"
326398

399+
@pytest.mark.asyncio
400+
async def test_create_tool_config_list_with_analyze_text_file_tool(self):
401+
"""Ensure AnalyzeTextFileTool receives text-specific metadata."""
402+
mock_tool_instance = MagicMock()
403+
mock_tool_instance.class_name = "AnalyzeTextFileTool"
404+
mock_tool_config.return_value = mock_tool_instance
405+
406+
with patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \
407+
patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \
408+
patch('backend.agents.create_agent_info.get_llm_model') as mock_get_llm_model, \
409+
patch('backend.agents.create_agent_info.minio_client', new_callable=MagicMock) as mock_minio_client:
410+
411+
mock_search_tools.return_value = [
412+
{
413+
"class_name": "AnalyzeTextFileTool",
414+
"name": "analyze_text_file",
415+
"description": "Analyze text file tool",
416+
"inputs": "string",
417+
"output_type": "array",
418+
"params": [{"name": "prompt", "default": "describe"}],
419+
"source": "local",
420+
"usage": None
421+
}
422+
]
423+
mock_get_llm_model.return_value = "mock_llm_model"
424+
425+
result = await create_tool_config_list("agent_1", "tenant_1", "user_1")
426+
427+
assert len(result) == 1
428+
assert result[0] is mock_tool_instance
429+
mock_get_llm_model.assert_called_once_with(tenant_id="tenant_1")
430+
assert mock_tool_instance.metadata == {
431+
"llm_model": "mock_llm_model",
432+
"storage_client": mock_minio_client,
433+
"data_process_service_url": consts_const.DATA_PROCESS_SERVICE,
434+
}
435+
327436
@pytest.mark.asyncio
328437
async def test_create_tool_config_list_with_analyze_image_tool(self):
329438
"""Ensure AnalyzeImageTool receives VLM model metadata."""

0 commit comments

Comments
 (0)