Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/kit/mcp/dev_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ class ExtractSymbolsParams(BaseModel):
repo_id: str
file_path: str
symbol_type: Optional[str] = None
include_code: bool = Field(
default=False,
description="Include full source code of each symbol. Default false to reduce context size.",
)


class FindSymbolUsagesParams(BaseModel):
Expand All @@ -117,6 +121,22 @@ class FindSymbolUsagesParams(BaseModel):

class GetFileTreeParams(BaseModel):
    # Arguments accepted by the get_file_tree tool.
    # (Documented with comments rather than a docstring so the generated
    # JSON schema is left exactly as it was.)

    # Identifier of the repository whose tree is requested.
    repo_id: str

    # On by default: the handler answers with one path per line rather than
    # the full JSON listing.
    compact: bool = Field(
        True,
        description="Return compact newline-separated paths instead of full JSON. Reduces context by ~75%.",
    )

    # Off by default: directory entries are left out of the compact listing.
    include_dirs: bool = Field(
        False,
        description="Include directory entries (only relevant when compact=true).",
    )


class GetSymbolCodeParams(BaseModel):
    """Get the source code of a specific symbol (lazy loading)."""

    # Repository to look in, and the file inside it that holds the symbol.
    repo_id: str
    file_path: str

    # Required field (no default): the symbol whose code should be returned.
    symbol_name: str = Field(..., description="Name of the symbol to get code for")


class GetCodeSummaryParams(BaseModel):
Expand Down Expand Up @@ -283,6 +303,29 @@ def find_symbol_usages(
# Repository.find_symbol_usages doesn't accept keyword arguments
return repo.find_symbol_usages(symbol_name)

def get_symbol_code(self, repo_id: str, file_path: str, symbol_name: str) -> Dict[str, Any]:
    """Return the source code and location of a single named symbol.

    The symbols of ``file_path`` are extracted and the first one whose
    ``"name"`` equals ``symbol_name`` is returned, so callers can fetch code
    for one symbol at a time instead of for a whole file.

    Args:
        repo_id: Repository identifier, resolved through ``self.get_repo``.
        file_path: File whose symbols are searched.
        symbol_name: Exact symbol name to look for.

    Returns:
        A dict with the keys ``name``, ``type``, ``file``, ``start_line``,
        ``end_line`` and ``code``. ``code`` falls back to ``""`` when the
        symbol carries no code; ``file`` echoes ``file_path``.

    Raises:
        MCPError: With ``INVALID_PARAMS`` when symbol extraction raises
            ``ValueError`` (reported as a path traversal attempt if the
            message mentions "outside repository bounds"), or when no
            symbol with the requested name exists in the file.
    """
    repo = self.get_repo(repo_id)

    # Only the extraction call is guarded; the lookup below cannot raise
    # ValueError, so it stays outside the try block.
    try:
        symbols = repo.extract_symbols(file_path)
    except ValueError as e:
        # Chain the original error so the root cause stays in the traceback.
        if "outside repository bounds" in str(e):
            raise MCPError(INVALID_PARAMS, f"Path traversal attempted: {e}") from e
        raise MCPError(INVALID_PARAMS, str(e)) from e

    match = next((s for s in symbols if s.get("name") == symbol_name), None)
    if match is None:
        # Symbol not found - report a bounded list of the names that exist.
        available = [s.get("name") for s in symbols]
        raise MCPError(INVALID_PARAMS, f"Symbol '{symbol_name}' not found. Available: {available[:20]}")

    return {
        "name": match.get("name"),
        "type": match.get("type"),
        "file": file_path,
        "start_line": match.get("start_line"),
        "end_line": match.get("end_line"),
        "code": match.get("code", ""),
    }

def get_code_summary(self, repo_id: str, file_path: str, symbol_name: Optional[str] = None) -> Dict[str, Any]:
"""Get code summary."""
repo = self.get_repo(repo_id)
Expand Down Expand Up @@ -497,6 +540,11 @@ def list_tools(self) -> List[Tool]:
description="Search code using AST patterns (semantic search)",
inputSchema=GrepASTParams.model_json_schema(),
),
Tool(
name="get_symbol_code",
description="Get source code of a specific symbol (lazy loading for context efficiency)",
inputSchema=GetSymbolCodeParams.model_json_schema(),
),
]


Expand Down Expand Up @@ -974,6 +1022,7 @@ async def call_tool(name: str, arguments: dict) -> List[TextContent]:
"get_git_info",
"review_diff",
"grep_ast",
"get_symbol_code",
]:
# Route to parent class method
if name == "open_repository":
Expand Down Expand Up @@ -1010,6 +1059,12 @@ async def call_tool(name: str, arguments: dict) -> List[TextContent]:
result = logic.extract_symbols(
symbol_params.repo_id, symbol_params.file_path, symbol_params.symbol_type
)
# Filter out code field unless explicitly requested (saves ~90% context)
if not symbol_params.include_code:
result = [
{k: v for k, v in symbol.items() if k != "code"}
for symbol in result
]
return [TextContent(type="text", text=json.dumps(result, indent=2))]
elif name == "find_symbol_usages":
usage_params = FindSymbolUsagesParams(**arguments)
Expand All @@ -1020,6 +1075,14 @@ async def call_tool(name: str, arguments: dict) -> List[TextContent]:
elif name == "get_file_tree":
tree_params = GetFileTreeParams(**arguments)
result = logic.get_file_tree(tree_params.repo_id)
# Compact mode: newline-separated paths (saves ~75% context)
if tree_params.compact:
paths = []
for item in result:
is_dir = item.get("is_dir", False)
if tree_params.include_dirs or not is_dir:
paths.append(item.get("path", ""))
return [TextContent(type="text", text="\n".join(paths))]
return [TextContent(type="text", text=json.dumps(result, indent=2))]
elif name == "get_code_summary":
summary_params = GetCodeSummaryParams(**arguments)
Expand Down Expand Up @@ -1051,6 +1114,14 @@ async def call_tool(name: str, arguments: dict) -> List[TextContent]:
ast_params.max_results,
)
return [TextContent(type="text", text=json.dumps(result, indent=2))]
elif name == "get_symbol_code":
symbol_code_params = GetSymbolCodeParams(**arguments)
result = logic.get_symbol_code(
symbol_code_params.repo_id,
symbol_code_params.file_path,
symbol_code_params.symbol_name,
)
return [TextContent(type="text", text=json.dumps(result, indent=2))]
else:
# Should not happen since we checked the name is in the list
return [TextContent(type="text", text=f"Tool {name} is recognized but not implemented")]
Expand Down
104 changes: 104 additions & 0 deletions tests/mcp/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,3 +522,107 @@ def test_mcp_tool_output_get_file_tree(logic: KitServerLogic, temp_git_repo):
result = logic.get_file_tree(repo_id)
assert isinstance(result, list)
assert len(result) > 0


class TestMCPContextOptimization:
"""Test MCP server context optimization features."""

# NOTE(review): despite its name, this test never checks that the "code" key
# is absent; it only asserts that each symbol returned by the logic layer
# has a "name".
def test_extract_symbols_excludes_code_by_default(self, logic, temp_git_repo):
"""Test that extract_symbols excludes code by default (90% context savings)."""
# Create a file with symbols
test_file = Path(temp_git_repo) / "symbols.py"
test_file.write_text("""
def hello():
'''A greeting function'''
print("Hello, world!")
return True

class MyClass:
'''A test class'''
def method(self):
pass
""")
# Commit the new file to the fixture repository before opening it.
subprocess.run(["git", "add", "."], cwd=temp_git_repo, check=True, capture_output=True)
subprocess.run(["git", "commit", "-m", "Add symbols"], cwd=temp_git_repo, check=True, capture_output=True)

repo_id = logic.open_repository(temp_git_repo)
result = logic.extract_symbols(repo_id, "symbols.py")

# Verify symbols are extracted
assert isinstance(result, list)
assert len(result) > 0

# Each symbol should have code (we'll filter it in the handler test)
for symbol in result:
assert "name" in symbol
# Code is still present in the logic result, filtering happens in handler

def test_get_symbol_code_returns_specific_symbol(self, logic, temp_git_repo):
"""Test get_symbol_code lazy loading."""
# Create a file with symbols
test_file = Path(temp_git_repo) / "code.py"
test_file.write_text("""
def target_function():
'''This is the target function'''
x = 1 + 2
return x * 3

def other_function():
'''Another function'''
pass
""")
# Commit the new file to the fixture repository before opening it.
subprocess.run(["git", "add", "."], cwd=temp_git_repo, check=True, capture_output=True)
subprocess.run(["git", "commit", "-m", "Add code"], cwd=temp_git_repo, check=True, capture_output=True)

repo_id = logic.open_repository(temp_git_repo)
result = logic.get_symbol_code(repo_id, "code.py", "target_function")

# The result must describe the requested symbol and carry its source text.
assert isinstance(result, dict)
assert result["name"] == "target_function"
assert result["type"] == "function"
assert "code" in result
assert "target_function" in result["code"]

def test_get_symbol_code_not_found(self, logic, temp_git_repo):
"""Test get_symbol_code with non-existent symbol."""
repo_id = logic.open_repository(temp_git_repo)

# NOTE(review): presumably the temp_git_repo fixture already contains
# test.py; this test does not create it — confirm against the fixture.
with pytest.raises(MCPError) as exc:
logic.get_symbol_code(repo_id, "test.py", "nonexistent_symbol")
assert exc.value.code == INVALID_PARAMS
assert "not found" in exc.value.message

def test_get_symbol_code_path_traversal(self, logic, temp_git_repo):
"""Test path traversal protection in get_symbol_code."""
repo_id = logic.open_repository(temp_git_repo)

# A path climbing above the repository root must surface as INVALID_PARAMS
# with a "Path traversal" message.
with pytest.raises(MCPError) as exc:
logic.get_symbol_code(repo_id, "../../../etc/passwd", "symbol")
assert exc.value.code == INVALID_PARAMS
assert "Path traversal" in exc.value.message

def test_list_tools_includes_get_symbol_code(self, logic):
"""Test that tools list includes the new get_symbol_code tool."""
tools = logic.list_tools()
tool_names = [tool.name for tool in tools]
assert "get_symbol_code" in tool_names

def test_extract_symbols_params_schema(self, logic):
"""Test that extract_symbols has include_code in its schema."""
tools = logic.list_tools()
# next() without a default raises StopIteration if the tool is not listed.
extract_tool = next(t for t in tools if t.name == "extract_symbols")
schema = extract_tool.inputSchema
assert "include_code" in schema.get("properties", {})
# Verify default is False
assert schema["properties"]["include_code"].get("default") is False

def test_get_file_tree_params_schema(self, logic):
"""Test that get_file_tree has compact mode in its schema."""
tools = logic.list_tools()
# next() without a default raises StopIteration if the tool is not listed.
tree_tool = next(t for t in tools if t.name == "get_file_tree")
schema = tree_tool.inputSchema
assert "compact" in schema.get("properties", {})
assert "include_dirs" in schema.get("properties", {})
# Verify defaults
assert schema["properties"]["compact"].get("default") is True
assert schema["properties"]["include_dirs"].get("default") is False
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.