183 changes: 183 additions & 0 deletions demohouse/computer_use/computer_use_mcp_client.py
@@ -0,0 +1,183 @@
import asyncio
import datetime
import json
from typing import Optional
from contextlib import AsyncExitStack
from mcp.client.sse import sse_client

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

from arkitect.core.component.llm import BaseChatLanguageModel
from arkitect.core.component.llm.model import ArkChatParameters, ArkMessage
from volcenginesdkarkruntime import AsyncArk
from arkitect.core.component.llm.utils import convert_response_message
from converter import create_chat_completion_tool, create_tool_response
from utils import pretty_print_message
from config import LLM_ENDPOINT, ARK_API_KEY


modelClient = AsyncArk(
    api_key=ARK_API_KEY,  # doubao 1.5
)


class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()

async def connect_to_stdio_server(self, server_script_path: str):
"""Connect to an MCP server

Args:
server_script_path: Path to the server script (.py or .js)
"""
is_python = server_script_path.endswith(".py")
is_js = server_script_path.endswith(".js")
if not (is_python or is_js):
raise ValueError("Server script must be a .py or .js file")

command = "python" if is_python else "node"
server_params = StdioServerParameters(
command=command, args=[server_script_path], env=None
)

stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
self.stdio, self.write = stdio_transport
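        # The read timeout below bounds how long the client waits for any single
        # MCP response; slow-running tools may need a larger value.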
self.session = await self.exit_stack.enter_async_context(
ClientSession(self.stdio, self.write, read_timeout_seconds=datetime.timedelta(seconds=10))
)

await self.session.initialize()

# List available tools
response = await self.session.list_tools()
tools = response.tools
print("\nConnected to server with tools:", [tool.name for tool in tools])

async def connect_to_sse_server(self, server_url: str):
"""Connect to an MCP server running with SSE transport"""
# Store the context managers so they stay alive
self._streams_context = sse_client(url=server_url)
streams = await self._streams_context.__aenter__()

self._session_context = ClientSession(*streams)
self.session: ClientSession = await self._session_context.__aenter__()

# Initialize
await self.session.initialize()

# List available tools to verify connection
print("Initialized SSE client...")
print("Listing tools...")
response = await self.session.list_tools()
tools = response.tools
print("\nConnected to server with tools:", [tool.name for tool in tools])

    async def process_query(self, query: str) -> str:
        """Process a query with the LLM and the available MCP tools"""
        sp = (
            "You are an assistant that can use code tools and the Bash command line; "
            "you must complete the user's tasks by using these tools. If a task is "
            "ambiguous, confirm the user's intent first; once you judge it executable, "
            "produce the corresponding tool-use instructions. When multiple tool calls "
            "are needed, remember the plan you made and the tools you have already used. "
            "If the plan changes, describe the new plan before continuing to execute "
            "the task with tools."
        )
messages = [
ArkMessage(
role="system",
content=sp,
),
ArkMessage(
role="user",
content=query,
)
]
for message in messages:
pretty_print_message(message)

response = await self.session.list_tools()
available_tools = [create_chat_completion_tool(mcp_tool=tool) for tool in response.tools]
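        # Advertise every MCP tool to the model as an OpenAI-style function tool
        # (see converter.create_chat_completion_tool).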
parameters = ArkChatParameters()
        parameters.tools = available_tools

# Initialize LLM
llm = BaseChatLanguageModel(
endpoint_id=LLM_ENDPOINT,
messages=messages,
parameters=parameters,
client=modelClient,
)
response = await llm.arun()

# Process response and handle tool calls
tool_results = []
final_text = []

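        # Agent loop: the model either answers in plain text or requests tool
        # calls; each requested tool is executed through the MCP session, its
        # result is appended as a role="tool" message, and the model is invoked
        # again until it replies without any tool calls.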
while response:
msg = response.choices[0].message
new_ark_message = convert_response_message(msg)
pretty_print_message(new_ark_message)
messages.append(new_ark_message)
if msg.content:
final_text.append(msg.content)

if not msg.tool_calls:
# no more tool calls, break
break
tool_calls = msg.tool_calls
for tool_call in tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
# Execute tool call
result = await self.session.call_tool(tool_name, tool_args)
tool_results.append({"call": tool_name, "result": result})
final_text.append(f"[Calling tool {tool_name} with args {tool_args}]")

tool_response = ArkMessage(
role="tool",
tool_call_id=tool_call.id,
content=create_tool_response(result)
)
pretty_print_message(tool_response)
messages.append(tool_response)
llm.messages = messages
response = await llm.arun()
        return "\n".join(final_text)

async def chat_loop(self):
"""Run an interactive chat loop"""
print("\nMCP Client Started!")
print("Type your queries or 'quit' to exit.")

while True:
try:
query = input("\nQuery: ").strip()

if query.lower() == "quit":
break

await self.process_query(query)

except Exception as e:
print(f"\nError: {str(e)}")

    async def cleanup(self):
        """Clean up resources"""
        # Close the SSE contexts, which are opened outside the exit stack
        if getattr(self, "_session_context", None) is not None:
            await self._session_context.__aexit__(None, None, None)
        if getattr(self, "_streams_context", None) is not None:
            await self._streams_context.__aexit__(None, None, None)
        await self.exit_stack.aclose()


async def main():
server_url = "http://0.0.0.0:8000/sse"
client = MCPClient()
try:
await client.connect_to_sse_server(server_url)
await client.chat_loop()
finally:
await client.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
36 changes: 36 additions & 0 deletions demohouse/computer_use/computer_use_mcp_server.py
@@ -0,0 +1,36 @@
from typing import Optional

from local_bash_executor import LocalBashExecutor
from mcp.server.fastmcp import FastMCP


mcp = FastMCP("computer use")

bash_executor = LocalBashExecutor()


@mcp.tool()
async def run_command(command: str, work_dir: Optional[str] = None) -> str:
    """Run a bash command.

    Args:
        command: the bash command to run
        work_dir: optional; working directory in which the command runs
    """
console_output = await bash_executor.run_command(command, work_dir)
return console_output.text


@mcp.tool()
async def create_file(file_name: str, content: str) -> str:
"""这个工具可以帮你创建文件, 如果该文件已经存在,则会覆盖之前的文件内容,如果文件不存在,则会自动创建这个文件

Args:
file_name: 文件名称,需包含文件名和相对路径名。示例:folder/text.txt
content: 文件内容
"""
console_output = await bash_executor.create_file(file_name, content)
return console_output.text


if __name__ == "__main__":
# Initialize and run the server
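    # With the pinned mcp 1.2.1 defaults (an assumption worth verifying), the
    # SSE transport listens on 0.0.0.0:8000 and serves the stream at /sse,
    # matching the server_url hard-coded in computer_use_mcp_client.py.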
mcp.run(transport="sse")
7 changes: 7 additions & 0 deletions demohouse/computer_use/config.py
@@ -0,0 +1,7 @@
# All the files you would like agent to process should be put under this folder
# All agents work will be saved in this folder
AI_WORKSPACE = "/YOUR/PATH/TO/ai_workspace"


LLM_ENDPOINT="YOUR_LLM_ENDPOINT"
ARK_API_KEY="YOUR_API_KEY"
36 changes: 36 additions & 0 deletions demohouse/computer_use/converter.py
@@ -0,0 +1,36 @@

from typing import Any, Dict, Optional
import mcp.types as mcp_types

from arkitect.core.component.llm.model import ChatCompletionTool, FunctionDefinition


def convert_schema(input_schema: Dict[str, Any], param_descriptions: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
    # Ark/OpenAI-style function tools expect a description for each parameter;
    # fill in any missing ones from param_descriptions (or an empty string).
    param_descriptions = param_descriptions or {}
    properties = input_schema["properties"]
    for key, val in properties.items():
        if "description" not in val:
            val["description"] = param_descriptions.get(key, "")
    return input_schema


def create_chat_completion_tool(mcp_tool: mcp_types.Tool, param_descriptions: Optional[Dict[str, str]] = None) -> ChatCompletionTool:
    return ChatCompletionTool(
        type="function",
        function=FunctionDefinition(
            name=mcp_tool.name,
            description=mcp_tool.description,
            parameters=convert_schema(mcp_tool.inputSchema, param_descriptions),
        ),
    )


def create_tool_response(mcp_tool_result: mcp_types.CallToolResult) -> str:
    message_parts = []
    for part in mcp_tool_result.content:
        if isinstance(part, mcp_types.TextContent):
message_parts.append(part.text)
else:
raise NotImplementedError("Non-text tool response not supported")
return "\n".join(message_parts)
62 changes: 62 additions & 0 deletions demohouse/computer_use/local_bash_executor.py
@@ -0,0 +1,62 @@
from typing import Optional
from llm_sandbox.base import ConsoleOutput
from llm_sandbox.docker import SandboxDockerSession
from llm_sandbox import SandboxSession
import tempfile

from config import AI_WORKSPACE


class LocalBashExecutor:
def __init__(self):
self.session: Optional[SandboxDockerSession] = None
self.work_dir = "/ai_workspace"
self.mount_folder = AI_WORKSPACE

async def ensure_init(self):
if self.session is None:
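            # Lazily create a Docker sandbox session (llm_sandbox) that
            # bind-mounts the host AI_WORKSPACE folder at /ai_workspace, so
            # every file the agent creates or edits stays under that host folder.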
self.session = SandboxSession(
image="local_ci:latest",
keep_template=True,
lang="python",
mounts=[
{
"target": "/ai_workspace", # Path inside the container
"source": self.mount_folder, # Path on the host
"type": "bind", # Use bind mount
}
],
verbose=True,
)
self.session.open()

async def run_command(self, command, work_dir=None) -> ConsoleOutput:
await self.ensure_init()
resolved_work_dir = self.work_dir
if work_dir is not None:
resolved_work_dir = resolved_work_dir + "/" + work_dir
return self.session.execute_command(command, resolved_work_dir)

async def create_file(self, file_name, content) -> ConsoleOutput:
await self.ensure_init()
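        # A host-side temp file carries the content into the container: it is
        # copied into the workspace, where it lands under its temp basename,
        # and then renamed to the requested file_name via mv.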
with tempfile.NamedTemporaryFile("w") as tmp:
tmp.write(content)
tmp.flush()
self.session.copy_to_runtime(tmp.name, self.work_dir + "/" + file_name)
            tmp_file_name = tmp.name.split("/")[-1]
            self.session.execute_command(f"mv {tmp_file_name} {file_name}", self.work_dir)
existing_files = self.session.execute_command("ls", self.work_dir)
return ConsoleOutput(
f"File successfully created. Now we have the following files:\n{existing_files.text}"
)

def close(self):
if self.session:
self.session.close()


if __name__ == "__main__":
import asyncio
executor = LocalBashExecutor()
asyncio.run(executor.create_file("aa.py", ""))
2 changes: 2 additions & 0 deletions demohouse/computer_use/requirements.txt
@@ -0,0 +1,2 @@
mcp==1.2.1
llm_sandbox==0.2.1
13 changes: 13 additions & 0 deletions demohouse/computer_use/utils.py
@@ -0,0 +1,13 @@


from arkitect.core.component.llm.model import ArkMessage


def pretty_print_message(message: ArkMessage):
formatted_message = f"{message.role}:\n{message.content}"
if message.tool_calls and len(message.tool_calls) > 0:
formatted_message += "\nTool Calls:\n"
for tool_call in message.tool_calls:
formatted_message += f"{tool_call.function.name}: {tool_call.function.arguments}\n\n"

print(formatted_message + "\n" + "-" * 50 + "\n")
2 changes: 1 addition & 1 deletion poetry.lock
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -11,7 +11,7 @@ python = ">=3.8.1,<3.12.0"
langchain_core = "0.1.52"
langchain = ">=0.1.0,<=0.2.0"
fastapi = ">=0.100.0,<1.0.0"
uvicorn = ">=0.22.0,<0.30.0"
uvicorn = ">=0.22.0"
opentelemetry-api = ">=1.22.0,<2.0.0"
pydantic = ">=2.0.0,<3.0.0"
opentelemetry-exporter-otlp = ">=1.22.0,<2.0.0"