From 552dd976876bf403ed8be3f152807fff1476292d Mon Sep 17 00:00:00 2001 From: 2002yy <15135142681@163.com> Date: Fri, 5 Jun 2026 23:50:03 +0800 Subject: [PATCH] Add agentic local knowledge and API service foundation --- README.md | 19 +- docs/INTERVIEW_NOTES.md | 2 +- docs/RAG.md | 32 +- docs/STUDY_AGENT_OPTIMIZATION_ROADMAP.md | 66 +++-- docs/TECH_STACK.md | 11 +- docs/TESTING.md | 7 +- src/api.py | 357 ++++++++++++++++++++++- src/tools/__init__.py | 1 + src/tools/local_knowledge.py | 292 ++++++++++++++++++ src/ui/chat_panel.py | 41 +-- src/ui/wechat_panel.py | 42 +-- tests/test_api.py | 157 ++++++++++ tests/test_local_knowledge_tool.py | 106 +++++++ 13 files changed, 1037 insertions(+), 96 deletions(-) create mode 100644 src/tools/__init__.py create mode 100644 src/tools/local_knowledge.py create mode 100644 tests/test_local_knowledge_tool.py diff --git a/README.md b/README.md index 9e3d7bf..36222cc 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@

CI Python 3.12 - 277 tests passed + 290 tests passed

A local AI learning assistant with long-term memory, role-based group chat, @@ -17,7 +17,7 @@ Study Agent 是一个本地优先的 AI 学习助手,重点不是简单调用 - **长期记忆**:Markdown memory + safe writer - **上下文分层**:fast / light / deep / archive - **联网搜索**:RSS / News fetch → article extraction → LLM digest → source tracing -- **RAG MVP**:本地 Markdown / TXT / DOCX / PDF 索引、关键词 / 本地向量原型 / hybrid / backend-vector 检索、可配置 embedding provider、可选 Chroma 持久化、引用上下文、来源块、Streamlit 检索/调试面板、聊天注入和 FastAPI RAG 接口 +- **RAG MVP**:本地 Markdown / TXT / DOCX / PDF 索引、关键词 / 本地向量原型 / hybrid / backend-vector 检索、可配置 embedding provider、可选 Chroma 持久化、受控本地知识检索工具、引用上下文、来源块、Streamlit 检索/调试面板、聊天注入和 FastAPI RAG / chat / memory 基础接口 - **工程安全**:SSRF protection、detect-secrets、配置模板 - **工程质量**:pytest 测试套件、Ruff、GitHub Actions CI、打包检查 @@ -27,11 +27,11 @@ Study Agent 是一个本地优先的 AI 学习助手,重点不是简单调用 - **Model routing** with fast / light / deep / archive context tiers - **Long-term memory** based on Markdown files and safe-writer persistence - **Web search pipeline**: feed registry → URL safety checks → article extraction → LLM digest → auditable source trace -- **RAG MVP**: local Markdown / TXT / DOCX / PDF indexing, lexical / local vector prototype / hybrid / backend-vector retrieval, configurable embedding providers, optional Chroma persistence, citation-first context formatting, source blocks, a Streamlit retrieval/debug panel, optional chat injection, and FastAPI RAG endpoints +- **RAG MVP**: local Markdown / TXT / DOCX / PDF indexing, lexical / local vector prototype / hybrid / backend-vector retrieval, configurable embedding providers, optional Chroma persistence, a controlled local-knowledge retrieval tool, citation-first context formatting, source blocks, a Streamlit retrieval/debug panel, optional chat injection, and FastAPI RAG / chat / memory foundation endpoints - **SSRF protection** for article fetching, **detect-secrets** in CI - **Batched session logging** and multi-layer caching for performance - **Performance budget**: mode-based `max_tokens` bounds on the main chat, WeChat, and news LLM paths -- **277 pytest tests**, Ruff clean, mypy clean, GitHub Actions CI workflow +- **290 pytest tests**, Ruff clean, mypy clean, GitHub Actions CI workflow For a detailed breakdown of the stack and engineering highlights, see [Technical Stack & Engineering Highlights](docs/TECH_STACK.md). @@ -109,7 +109,7 @@ Study Agent 的定位很明确:**一个运行在你本地的、有长期记忆 | **角色群聊** | 四位角色(三月七、刻晴、纳西妲、流萤)群聊讨论,各有独立人设 | | **联网搜索** | Google News + Bing News + RSSHub 多源聚合,页面正文三层提取 | | **来源追溯** | 搜索结果写入群聊记录,可回溯依据 | -| **RAG MVP** | 本地 Markdown / TXT / DOCX / PDF 文档索引,前端面板返回带文件路径、行号、分数、命中词和 score breakdown 的引用片段,并可注入单人聊天和微信群互动回复;FastAPI 提供 `/health`、`/rag`、`/rag/index`、`/rag/query` | +| **RAG MVP** | 本地 Markdown / TXT / DOCX / PDF 文档索引,前端面板返回带文件路径、行号、分数、命中词和 score breakdown 的引用片段,并可注入单人聊天和微信群互动回复;FastAPI 提供 `/health`、`/rag`、`/rag/index`、`/rag/query`、`/rag/status`、`/rag/upload`、`/rag/local-knowledge` | | **课后总结** | 学习完成后自动总结进展,用户确认后写入记忆 | | **长期记忆** | 学习者画像、进度追踪、项目上下文、当前焦点,多级记忆档案 | | **多 Provider** | 支持 OpenAI / DeepSeek / OpenRouter / SiliconFlow / 本地模型 | @@ -233,7 +233,7 @@ RAG_EMBEDDING_PROVIDER=local_hash │ ├── llm_router.py # 模型路由分发 │ ├── context_builder.py # 上下文构建 │ ├── mode_manager.py # 模式管理(版本/性能/氛围) -│ ├── api.py # FastAPI health / RAG endpoints +│ ├── api.py # FastAPI health / chat / memory / sessions / RAG endpoints │ ├── role_manager.py # 角色加载与管理 │ ├── performance_budget.py # 性能预算(max_tokens 分级) │ ├── memory.py # 记忆系统 @@ -250,6 +250,7 @@ RAG_EMBEDDING_PROVIDER=local_hash │ ├── router.py # 路由配置 │ ├── news/ # 新闻聚合链路 │ ├── rag/ # 本地 RAG MVP:加载、分块、索引、关键词/向量原型/embedding/可选后端检索 +│ ├── tools/ # 受控工具边界:本地知识检索等 │ └── ui/ # Streamlit UI 组件 ├── tests/ # pytest 测试套件 ├── docs/ # 设计文档与工程说明 @@ -270,7 +271,7 @@ RAG_EMBEDDING_PROVIDER=local_hash ## 测试 ```bash -pytest tests/ -v # current local baseline: 277 passed +pytest tests/ -v # current local baseline: 290 passed pytest tests/ --cov=src # 覆盖率 ruff check src/ tests/ # linting mypy --explicit-package-bases src/ # type check @@ -312,8 +313,8 @@ CI 通过 GitHub Actions 在 push / pull request 上运行,集成 `pytest`、` 求职导向的技术演进路线: -- [ ] FastAPI service layer (partial): `/health`, `/rag`, `/rag/index`, `/rag/query` implemented; `/chat` and `/memory` remain planned -- [x] RAG MVP: Markdown / TXT / DOCX / PDF loading, chunking, local keyword retrieval, local vector prototype, hybrid retrieval, backend-vector retrieval, configurable embedding provider, optional Chroma adapter, citation context, source blocks, Streamlit retrieval panel, optional single-chat and WeChat interactive injection +- [x] FastAPI service layer foundation: `/health`, `/chat`, `/memory/preview`, `/memory/commit`, `/sessions`, `/rag`, `/rag/index`, `/rag/query`, `/rag/status`, `/rag/upload` and `/rag/local-knowledge` implemented; streaming, auth and frontend-specific contracts remain planned +- [x] RAG MVP: Markdown / TXT / DOCX / PDF loading, chunking, local keyword retrieval, local vector prototype, hybrid retrieval, backend-vector retrieval, configurable embedding provider, optional Chroma adapter, controlled local-knowledge retrieval, citation context, source blocks, Streamlit retrieval panel, optional single-chat and WeChat interactive injection - [ ] RAG document QA (partial): PDF parsing has file-size, page-count, extracted-text and encrypted-file guards; production embedding requires explicit API/env configuration and Chroma remains optional - [ ] Vector store: Chroma optional adapter implemented; FAISS local prototype and pgvector engineering version remain planned - [ ] Web UI: TypeScript + Vue3 / React, streaming chat, source panel diff --git a/docs/INTERVIEW_NOTES.md b/docs/INTERVIEW_NOTES.md index 2fc3797..4b4e0d0 100644 --- a/docs/INTERVIEW_NOTES.md +++ b/docs/INTERVIEW_NOTES.md @@ -10,7 +10,7 @@ Study Agent 是一个本地优先的 AI 学习助手,重点在多 Provider 模 2. **长期记忆写入安全** — safe writer + preview/confirm 机制,防止不可逆的记忆污染 3. **联网搜索来源追溯** — Feed registry / RSS 多源聚合 → URL safety matrix → 文章正文三层提取 → LLM digest → pipeline trace 全过程来源可回溯 4. **Streamlit 重渲染性能优化** — 多层缓存策略、按模式批量落盘、主链路 token 预算控制 -5. **CI / Ruff / detect-secrets 工程检查** — 277 pytest tests、Ruff clean、mypy local clean、GitHub Actions workflow、detect-secrets 对未豁免发现硬阻断 +5. **CI / Ruff / detect-secrets 工程检查** — 290 pytest tests、Ruff clean、mypy local clean、GitHub Actions workflow、detect-secrets 对未豁免发现硬阻断 ## 可讲亮点 diff --git a/docs/RAG.md b/docs/RAG.md index b2e8c8d..c874677 100644 --- a/docs/RAG.md +++ b/docs/RAG.md @@ -19,10 +19,11 @@ Implemented: - Streamlit retrieval panel for uploads, local paths, indexing, querying and citation preview - Optional single-chat and WeChat interactive reply injection through the `用于聊天回答` toggle - UI source blocks for retrieved file paths, line ranges, scores and matched terms -- FastAPI endpoints: `GET /health`, `POST /rag`, `POST /rag/index`, `POST /rag/query` +- FastAPI endpoints: `GET /health`, `POST /rag`, `POST /rag/index`, `POST /rag/query`, `GET /rag/status`, `POST /rag/upload`, `POST /rag/local-knowledge` - Streamlit knowledge/debug panel with index summary, document rows, chunk preview and score breakdowns - Optional vector backend interface with local fallback and Chroma adapter - Configurable embedding providers: deterministic `local_hash` by default, OpenAI-compatible embeddings when explicitly configured +- Controlled local-knowledge retrieval tool with intent gating, deterministic query rewrite and explicit not-found behavior Not implemented yet: @@ -44,7 +45,8 @@ Not implemented yet: | `src/rag/eval.py` | LLM-free retrieval quality evaluation over gold query fixtures | | `src/rag/service.py` | Application-facing helpers for indexing, querying and context formatting | | `src/rag/schema.py` | Dataclasses for documents, chunks, indexes and search results | -| `src/api.py` | FastAPI health and RAG endpoints | +| `src/tools/local_knowledge.py` | Controlled retrieval boundary for agentic local knowledge use | +| `src/api.py` | FastAPI health, chat, memory, session, RAG and local-knowledge endpoints | ## Data Flow @@ -56,7 +58,9 @@ local files -> save_rag_index -> query_documents -> build_rag_context + -> optional controlled local-knowledge tool -> optional single-chat / WeChat interactive prompt injection or FastAPI response + -> optional frontend-facing chat / memory / session API flow ``` ## Retrieval Behavior @@ -111,8 +115,10 @@ Regression coverage lives in `tests/test_rag.py` and verifies: - Local hash-vector and hybrid retrieval behavior - Citation formatting and context budget behavior - Streamlit RAG panel helpers for uploaded filenames and local path parsing -- FastAPI `/health`, `/rag`, `/rag/index` and `/rag/query` +- FastAPI `/health`, `/rag`, `/rag/index`, `/rag/query`, `/rag/status`, `/rag/upload` and `/rag/local-knowledge` +- FastAPI `/chat`, `/memory/preview`, `/memory/commit`, `/sessions` and `/sessions/{session_id}/flush` - Prompt injection behavior for cited RAG context +- Controlled local-knowledge tool behavior for skip / found / not-found / rewrite `tests/test_rag_eval.py` adds a small gold fixture suite under `tests/fixtures/rag_eval/` and verifies: @@ -182,7 +188,19 @@ Goal: turn the Streamlit expander into a usable knowledge panel. Goal: let the model retrieve when it needs evidence instead of always pre-retrieving. -- Add a `retrieve_local_knowledge(query)` tool boundary. -- Route retrieval only for knowledge-grounded questions. -- Allow query rewrite and second-pass retrieval when first-pass evidence is weak. -- Require explicit "not found in local knowledge" behavior when no source is retrieved. +- [x] Add a `retrieve_local_knowledge(query)` tool boundary. +- [x] Route retrieval only for knowledge-grounded questions through deterministic intent gating. +- [x] Allow deterministic query rewrite and second-pass retrieval when first-pass evidence is weak. +- [x] Require explicit "not found in local knowledge" behavior when no source is retrieved. +- [x] Expose the same boundary through `POST /rag/local-knowledge` for future frontends. +- [ ] Add LLM tool-calling / function-calling integration; current implementation is controlled pre-generation retrieval, not free-form tool use. + +### P8: Service API Foundation + +Goal: expose the current local-first capabilities through stable API boundaries before building a separate web frontend. + +- [x] Add RAG status and upload endpoints for index inspection and rebuilds. +- [x] Add a non-streaming `/chat` endpoint that reuses model routing, role prompts, memory bundles, local-knowledge retrieval and session logging. +- [x] Add memory preview / commit endpoints with the same runtime write-mode guard as the Streamlit UI. +- [x] Add session listing and force-flush endpoints for local session inspection. +- [ ] Add streaming chat, auth, CORS policy and frontend-oriented error envelopes before public or LAN deployment. diff --git a/docs/STUDY_AGENT_OPTIMIZATION_ROADMAP.md b/docs/STUDY_AGENT_OPTIMIZATION_ROADMAP.md index 47fad03..714547d 100644 --- a/docs/STUDY_AGENT_OPTIMIZATION_ROADMAP.md +++ b/docs/STUDY_AGENT_OPTIMIZATION_ROADMAP.md @@ -286,7 +286,7 @@ Study Agent 后续的核心竞争力应该来自 RAG,而不是普通聊天。 不要让模型无限制自由调用工具,而是先用可控路由实现稳定 Agent 工作流。 -## 9. P1:FastAPI 服务化 +## 9. P8:FastAPI 服务化 不建议立刻推翻 Streamlit。推荐三步走: @@ -298,7 +298,7 @@ Streamlit UI → core/chat_engine.py ### 阶段 2:增加 FastAPI -最小接口: +当前基础接口已经落地: ```text GET /health @@ -307,12 +307,21 @@ POST /memory/preview POST /memory/commit POST /rag/upload POST /rag/query +GET /rag/status +POST /rag/local-knowledge GET /sessions +POST /sessions/{session_id}/flush ``` +仍需补齐:streaming chat、auth、CORS、统一错误响应、OpenAPI 示例和 Docker 部署配置。 + ### 阶段 3:补前端 -前端可用 Vue3 或 React。推荐先 Vue3,开发成本较低。 +前端建议进入 P9 后使用 React + Vite + TypeScript。理由是: + +- React 生态更适合后续做聊天流、引用面板、调试抽屉和状态组件拆分。 +- Vite 开发服务器启动快,生产构建输出静态 `dist`,可以独立部署,也可以由 FastAPI 挂载静态目录。 +- TypeScript 能把 API response、RAG source、memory preview、session row 等数据结构固定下来,减少前后端联调时的隐性字段漂移。 最低页面: @@ -368,7 +377,7 @@ GET /sessions | RAG 测试 | chunk、入库、检索、引用来源 | | Tool 测试 | 新闻检索、文件读取、摘要 | | ContextBuilder 测试 | 不同模式下上下文是否正确 | -| API 测试 | /chat、/health、/rag/query | +| API 测试 | /chat、/health、/rag/query、/rag/upload、/rag/status、/memory/preview、/memory/commit、/sessions | | UI smoke 测试 | 页面能打开、基本交互不崩 | 最关键的是 Mock Provider。真实模型用于演示和实际使用,Mock Provider 用于自动测试和 CI,避免测试依赖外部 API。 @@ -438,15 +447,15 @@ docs/ 任务: -1. 增加 FastAPI -2. 实现 /health -3. 实现 /chat -4. 实现 /rag/upload -5. 实现 /rag/query -6. 实现 /memory/preview -7. 实现 /memory/commit -8. 补 API 测试 -9. 补 Docker Compose +1. [x] 增加 FastAPI +2. [x] 实现 /health +3. [x] 实现 /chat(当前为非流式) +4. [x] 实现 /rag/upload +5. [x] 实现 /rag/query +6. [x] 实现 /memory/preview +7. [x] 实现 /memory/commit +8. [x] 补 API 测试 +9. [ ] 补 streaming chat / auth / CORS / Docker Compose ### v1.0:前端产品化版本 @@ -454,7 +463,7 @@ docs/ 任务: -1. Vue3 / React 前端 +1. React + Vite + TypeScript 前端 2. 聊天页 3. 文件上传页 4. 知识库列表页 @@ -479,28 +488,27 @@ docs/ ## 15. 当前最建议执行的下一步 -第一步先画清主流程并拆模块: +当前主流程已经可以按 FastAPI 边界继续收口: ```text 用户输入 -→ UI 接收 +→ Streamlit 或 Web UI 接收 +→ FastAPI /chat → memory 读取 → context 构建 -→ tool 判断 +→ local knowledge tool 判断 → provider 调用 -→ stream 输出 +→ response 输出 → session 记录 → memory 写回确认 ``` -推荐重构顺序: - -1. Provider 抽象稳定 -2. MemoryManager 稳定 -3. ContextBuilder 稳定 -4. SessionLogger 批量写入 -5. ToolRouter 初步成型 -6. Streamlit 只保留 UI -7. 再加 FastAPI -8. 再加 RAG -9. 最后做前端 +推荐推进顺序: + +1. [x] Provider 抽象稳定 +2. [x] Memory / ContextBuilder 基础稳定 +3. [x] SessionLogger 批量写入 +4. [x] RAG MVP 与 local knowledge tool +5. [x] FastAPI 基础服务层 +6. [ ] streaming chat / auth / CORS / Docker +7. [ ] React + Vite + TypeScript 前端 diff --git a/docs/TECH_STACK.md b/docs/TECH_STACK.md index 38a92a9..ec39854 100644 --- a/docs/TECH_STACK.md +++ b/docs/TECH_STACK.md @@ -35,7 +35,7 @@ Study Agent 是一个本地运行的 AI 学习助理系统,面向个人学习 | Long-term Memory | Markdown files | 用 `summary.md`、`current_focus.md`、`learner_profile.md` 等文件保存长期记忆 | | Context Control | fast / light / deep / archive tiers | 按性能模式选择不同记忆文件组,控制 token 成本 | | Routing | Rule-based router + optional LLM router | 根据任务类型、用户选择和性能模式决定角色、学习模式和模型档位 | -| RAG MVP | `src/rag/*`, `src/ui/rag_panel.py`, `src/api.py`, JSON index | 本地 Markdown / TXT / DOCX / PDF 加载、分块、关键词 / 本地向量原型 / hybrid / backend-vector 检索、可配置 embedding provider、可选 Chroma adapter、引用上下文拼装、来源块、Streamlit 检索/调试面板、聊天注入和 FastAPI RAG endpoints | +| RAG MVP | `src/rag/*`, `src/tools/local_knowledge.py`, `src/ui/rag_panel.py`, `src/api.py`, JSON index | 本地 Markdown / TXT / DOCX / PDF 加载、分块、关键词 / 本地向量原型 / hybrid / backend-vector 检索、可配置 embedding provider、可选 Chroma adapter、受控本地知识检索工具、引用上下文拼装、来源块、Streamlit 检索/调试面板、聊天注入和 FastAPI RAG / chat / memory 基础服务 endpoints | | News Search | Feed registry / RSS / Google News / Bing News / RSSHub-style sources | 多源新闻聚合、源健康记录、去重、排序、来源追溯 | | Article Extraction | `trafilatura`, `readability-lxml`, `lxml` | 新闻网页正文读取与降级解析 | | Security | URL safety matrix, SSRF validation, redirect checks, secret scanning | 防止读取本地/内网资源,降低密钥误提交风险 | @@ -280,12 +280,13 @@ User query - Streamlit `本地资料检索` 面板支持上传资料、输入本地路径、建立索引、检索和查看引用上下文 - Streamlit 面板显示当前索引、文档列表、chunk preview、检索参数和 score breakdown - 单人聊天和微信群互动回复可通过 `用于聊天回答` 开关把检索结果注入 system prompt,并显示 RAG 引用来源块 -- FastAPI `GET /health`、`POST /rag`、`POST /rag/index`、`POST /rag/query` +- `retrieve_local_knowledge()` 作为受控工具边界,支持 skip / found / not-found / rewrite 状态 +- FastAPI `GET /health`、`POST /chat`、`POST /memory/preview`、`POST /memory/commit`、`GET /sessions`、`POST /sessions/{session_id}/flush`、`POST /rag`、`POST /rag/index`、`POST /rag/query`、`GET /rag/status`、`POST /rag/upload`、`POST /rag/local-knowledge` 未实现边界: - 默认仍是 local-first;生产 embedding 需要显式 API/env 配置,Chroma 需要额外安装 `chromadb`;FAISS、pgvector 或其他生产向量库仍未接入 -- FastAPI 目前覆盖 health 和 RAG;`/chat`、`/memory` 仍是后续服务化任务 +- FastAPI 当前覆盖 health、非流式 chat、memory preview/commit、session list/flush 和 RAG 基础接口;streaming、auth、CORS 策略和更稳定的前端错误协议仍是后续服务化任务 - 尚未自动注入所有生成路径;当前覆盖单人聊天和微信群互动回复,不覆盖新闻讨论或课后反馈 价值: @@ -329,7 +330,7 @@ User query - 设计 **OpenAI-compatible 多 Provider LLM 接入层**,支持 OpenAI、DeepSeek、OpenRouter、SiliconFlow 与本地模型,通过 `.env` 管理 base_url、模型名、超时、重试和任务级 token 预算。 - 实现 **规则路由 + 性能模式** 的模型选择机制,根据用户输入、任务类型、手动配置和 fast / standard / deep 模式动态选择角色、学习模式与 flash / pro 模型。 - 设计基于 **Markdown 文件的长期记忆系统**,按 fast / light / deep / archive 分层读取 `summary`、`current_focus`、`learner_profile`、`project_context` 等上下文,降低 token 消耗。 -- 实现 **本地 RAG MVP**,支持 Markdown / TXT / DOCX / PDF 加载、来源行号分块、关键词检索、本地向量原型、hybrid / backend-vector 检索、可配置 embedding provider、可选 Chroma adapter、引用上下文拼装、来源块、Streamlit 检索面板、单人聊天和微信群互动回复注入,并提供 FastAPI health / RAG endpoints;FAISS、pgvector 和更完整的来源面板仍作为后续演进。 +- 实现 **本地 RAG MVP**,支持 Markdown / TXT / DOCX / PDF 加载、来源行号分块、关键词检索、本地向量原型、hybrid / backend-vector 检索、可配置 embedding provider、可选 Chroma adapter、引用上下文拼装、来源块、Streamlit 检索面板、单人聊天和微信群互动回复注入,并提供 FastAPI health / chat / memory / session / RAG 基础 endpoints;FAISS、pgvector、streaming 和更完整的来源面板仍作为后续演进。 - 实现 **联网新闻检索与群聊讨论链路**,支持 RSS 聚合、链接解析、正文抽取、LLM 摘要、角色讨论和来源追溯。 - 在网页正文抓取模块加入 **SSRF 防护**,校验 HTTP scheme、DNS 解析结果、私网 IP、loopback、reserved 地址和重定向目标,提高本地联网模块安全性。 - 封装 **安全文件写入工具**,通过临时文件、原子替换、覆盖前备份和 PermissionError 重试保障记忆与日志写入可靠性。 @@ -343,7 +344,7 @@ User query - Implemented a unified **OpenAI-compatible multi-provider LLM layer** supporting OpenAI, DeepSeek, OpenRouter, SiliconFlow and local models, with environment-based model, timeout, retry and token-budget configuration. - Designed a **rule-based model routing system** with fast / standard / deep performance modes to dynamically select role, learning mode and flash / pro model tier based on task type and user configuration. - Built a **Markdown-based long-term memory system** with fast / light / deep / archive context tiers to balance personalization and token cost. -- Implemented a **local RAG MVP** for Markdown / TXT / DOCX / PDF loading, source-line chunking, lexical retrieval, a local vector prototype, hybrid / backend-vector retrieval, configurable embedding providers, an optional Chroma adapter, citation-first context formatting, source blocks, a Streamlit retrieval panel, optional single-chat / WeChat interactive injection, and FastAPI health / RAG endpoints; FAISS, pgvector and a richer source panel remain planned. +- Implemented a **local RAG MVP** for Markdown / TXT / DOCX / PDF loading, source-line chunking, lexical retrieval, a local vector prototype, hybrid / backend-vector retrieval, configurable embedding providers, an optional Chroma adapter, a controlled local-knowledge retrieval tool, citation-first context formatting, source blocks, a Streamlit retrieval panel, optional single-chat / WeChat interactive injection, and FastAPI health / chat / memory / session / RAG foundation endpoints; FAISS, pgvector, streaming and a richer source panel remain planned. - Implemented a **source-traced news pipeline** covering RSS aggregation, link resolution, article extraction, LLM digest generation and role-based group discussion. - Added **SSRF protection** to the article-fetching module by validating URL scheme, DNS resolution results, private IP ranges and redirect targets. - Encapsulated **safe file persistence** with temporary writes, atomic replacement, automatic backup and retry on permission errors. diff --git a/docs/TESTING.md b/docs/TESTING.md index 53138ba..79e614b 100644 --- a/docs/TESTING.md +++ b/docs/TESTING.md @@ -6,7 +6,7 @@ Current verified baseline: | Check | Status | Evidence | |---|---|---| -| pytest | Passed | `277 passed` locally on 2026-06-05 | +| pytest | Passed | `290 passed` locally on 2026-06-05 | | Ruff | Passed | `python -m ruff check .` clean locally on 2026-06-05 | | Package helper | Passed | `python tools/package_project_helper.py . NUL 0` locally on 2026-06-05 | | mypy | Passed locally; CI soft check | `python -m mypy --explicit-package-bases src` clean locally on 2026-06-05 | @@ -27,7 +27,8 @@ Current verified baseline: | **RAG MVP** | `test_rag.py` | 24 | | **RAG evaluation** | `test_rag_eval.py` | 5 | | **RAG vector backends** | `test_rag_backends.py` | 10 | -| **FastAPI RAG endpoints** | `test_api.py` | 6 | +| **Controlled local knowledge tool** | `test_local_knowledge_tool.py` | 7 | +| **FastAPI service endpoints** | `test_api.py` | 12 | | **Architecture flows** | `test_architecture_flows.py` | 12 | | **WeChat decoupling** | `test_wechat_decoupling.py` | 4 | | **Sidebar rerun** | `test_sidebar_global_rerun.py` | 12 | @@ -77,7 +78,7 @@ def test_flush_uses_safe_writer(): ## Running Tests ```bash -python -m pytest # current baseline: 277 passed +python -m pytest # current baseline: 290 passed pytest tests/ -v # Verbose pytest tests/ --cov=src # Coverage python -m ruff check . # Linting diff --git a/src/api.py b/src/api.py index 7364729..201495f 100644 --- a/src/api.py +++ b/src/api.py @@ -3,13 +3,29 @@ from pathlib import Path from typing import Any -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, File, HTTPException, UploadFile from pydantic import BaseModel, Field +from src import memory_writer +from src.context_builder import build_messages +from src.llm_client import chat +from src.memory import read_memory_bundle +from src.mode_manager import is_memory_write_allowed, load_runtime_modes +from src.performance_budget import chat_max_tokens from src.rag import build_rag_context, format_rag_sources, index_documents +from src.rag.backends import get_vector_backend_from_env from src.rag.eval import RagEvalCase, evaluate_case from src.rag.index import DEFAULT_RAG_INDEX_PATH, load_rag_index from src.rag.service import build_rag_debug, search_documents +from src.role_manager import load_role +from src.router import route_request +from src.session_logger import flush_current_session, get_or_create_session, init_session, log +from src.tools.local_knowledge import retrieve_local_knowledge + +ROOT = Path(__file__).resolve().parent.parent +RAG_UPLOAD_DIR = ROOT / "logs" / "rag_uploads" +SESSION_DIR = ROOT / "logs" / "sessions" +CURRENT_SESSION_DIR = ROOT / "logs" / "current" class HealthResponse(BaseModel): @@ -31,6 +47,14 @@ class RagIndexResponse(BaseModel): index_path: str +class RagStatusResponse(BaseModel): + index_path: str + index_exists: bool + documents: int = 0 + chunks: int = 0 + vector_backend: dict[str, Any] + + class RagQueryRequest(BaseModel): query: str = Field(min_length=1) index_path: str | None = None @@ -53,6 +77,94 @@ class RagQueryResponse(BaseModel): evaluation: dict[str, Any] | None = None +class LocalKnowledgeRequest(BaseModel): + query: str = Field(min_length=1) + enabled: bool = True + force: bool = False + index_path: str | None = None + top_k: int = Field(default=3, gt=0, le=20) + min_score: float = Field(default=0.01, ge=0) + retrieval_mode: str = Field(default="hybrid") + context_max_chars: int = Field(default=3000, gt=0, le=20_000) + allow_rewrite: bool = True + weak_score_threshold: float = Field(default=0.05, ge=0) + + +class LocalKnowledgeResponse(BaseModel): + status: str + query: str + retrieval_mode: str + reason: str + context: str + sources: str + result_count: int + results: list[dict[str, Any]] + debug: dict[str, Any] + attempts: list[dict[str, Any]] + rewritten_query: str + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatRequest(BaseModel): + user_input: str = Field(min_length=1) + selected_role: str = "auto" + selected_mode: str = "auto" + selected_model: str = "auto" + relationship_mode: str = "standard" + context_mode: str | None = None + chat_history: list[ChatMessage] = Field(default_factory=list) + session_id: str | None = None + rag_enabled: bool = False + rag_top_k: int = Field(default=3, gt=0, le=20) + rag_retrieval_mode: str = "hybrid" + + +class ChatResponse(BaseModel): + reply: str + session_id: str + route: dict[str, Any] + rag: dict[str, Any] + + +class MemoryUpdate(BaseModel): + target: str + content: str = Field(min_length=1) + append: bool = True + learner_pending: bool = False + + +class MemoryPreviewRequest(BaseModel): + updates: list[MemoryUpdate] = Field(min_length=1) + + +class MemoryPreviewItem(BaseModel): + target: str + path: str + action: str + allowed: bool + preview: str + + +class MemoryPreviewResponse(BaseModel): + writable: bool + memory_mode: str + safe_mode: bool + updates: list[MemoryPreviewItem] + + +class MemoryCommitResponse(BaseModel): + writable: bool + results: list[dict[str, str]] + + +class SessionListResponse(BaseModel): + sessions: list[dict[str, Any]] + + app = FastAPI(title="Study Agent API", version="0.1.0") @@ -60,6 +172,32 @@ def _index_path(value: str | None) -> Path: return Path(value) if value else DEFAULT_RAG_INDEX_PATH +def _memory_target_path(target: str) -> Path: + path = memory_writer.MEMORY_TARGETS.get(target) + if path is None: + raise HTTPException(status_code=400, detail=f"Unknown memory target: {target}") + return path + + +def _session_file_rows(directory: Path, kind: str, limit: int) -> list[dict[str, Any]]: + if not directory.is_dir(): + return [] + rows: list[dict[str, Any]] = [] + for path in directory.glob("*.md"): + stat = path.stat() + rows.append( + { + "kind": kind, + "name": path.name, + "path": str(path), + "size_bytes": stat.st_size, + "mtime_ns": stat.st_mtime_ns, + } + ) + rows.sort(key=lambda row: int(row["mtime_ns"]), reverse=True) + return rows[:limit] + + @app.get("/health", response_model=HealthResponse) def health() -> HealthResponse: return HealthResponse( @@ -69,6 +207,32 @@ def health() -> HealthResponse: ) +@app.get("/rag/status", response_model=RagStatusResponse) +def rag_status(index_path: str | None = None) -> RagStatusResponse: + target = _index_path(index_path) + documents = 0 + chunks = 0 + if target.exists(): + index = load_rag_index(target) + documents = len(index.documents) + chunks = len(index.chunks) + try: + backend_status = get_vector_backend_from_env().status().to_dict() + except Exception as exc: + backend_status = { + "name": "unknown", + "available": False, + "detail": str(exc), + } + return RagStatusResponse( + index_path=str(target), + index_exists=target.exists(), + documents=documents, + chunks=chunks, + vector_backend=backend_status, + ) + + @app.post("/rag/index", response_model=RagIndexResponse) def build_rag_index_endpoint(request: RagIndexRequest) -> RagIndexResponse: try: @@ -91,6 +255,46 @@ def build_rag_index_endpoint(request: RagIndexRequest) -> RagIndexResponse: ) +@app.post("/rag/upload", response_model=RagIndexResponse) +async def upload_rag_documents( + files: list[UploadFile] = File(...), + index_path: str | None = None, + max_chars: int = 900, + overlap_chars: int = 120, +) -> RagIndexResponse: + if not files: + raise HTTPException(status_code=400, detail="No files uploaded") + if max_chars <= 0 or max_chars > 10_000: + raise HTTPException(status_code=400, detail="max_chars out of range") + if overlap_chars < 0 or overlap_chars > 5_000: + raise HTTPException(status_code=400, detail="overlap_chars out of range") + + RAG_UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + saved_paths = [] + for uploaded in files: + filename = Path(uploaded.filename or "document").name + target = RAG_UPLOAD_DIR / filename + target.write_bytes(await uploaded.read()) + saved_paths.append(target) + + target_index = _index_path(index_path) + try: + index = index_documents( + saved_paths, + index_path=target_index, + max_chars=max_chars, + overlap_chars=overlap_chars, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + return RagIndexResponse( + documents=len(index.documents), + chunks=len(index.chunks), + index_path=str(target_index), + ) + + @app.post("/rag/query", response_model=RagQueryResponse) def query_rag_endpoint(request: RagQueryRequest) -> RagQueryResponse: try: @@ -143,3 +347,154 @@ def query_rag_endpoint(request: RagQueryRequest) -> RagQueryResponse: @app.post("/rag", response_model=RagQueryResponse) def query_rag_alias(request: RagQueryRequest) -> RagQueryResponse: return query_rag_endpoint(request) + + +@app.post("/rag/local-knowledge", response_model=LocalKnowledgeResponse) +def local_knowledge_endpoint(request: LocalKnowledgeRequest) -> LocalKnowledgeResponse: + result = retrieve_local_knowledge( + request.query, + enabled=request.enabled, + force=request.force, + index_path=_index_path(request.index_path), + top_k=request.top_k, + min_score=request.min_score, + retrieval_mode=request.retrieval_mode, + context_max_chars=request.context_max_chars, + allow_rewrite=request.allow_rewrite, + weak_score_threshold=request.weak_score_threshold, + ) + return LocalKnowledgeResponse(**result.to_dict()) + + +@app.post("/chat", response_model=ChatResponse) +def chat_endpoint(request: ChatRequest) -> ChatResponse: + runtime_modes = load_runtime_modes() + context_mode = request.context_mode or runtime_modes.context_mode + route = route_request( + user_input=request.user_input, + selected_role=request.selected_role, + selected_mode=request.selected_mode, + selected_model=request.selected_model, + runtime_modes=runtime_modes, + ) + role_prompt = load_role(route["role"]) + memory_bundle = read_memory_bundle(context_mode) + rag_result = retrieve_local_knowledge( + request.user_input, + enabled=request.rag_enabled, + top_k=request.rag_top_k, + retrieval_mode=request.rag_retrieval_mode, + ) + messages = build_messages( + user_input=request.user_input, + role_prompt=role_prompt, + mode=route["mode"], + memory_bundle=memory_bundle, + chat_history=[message.model_dump() for message in request.chat_history], + relationship_mode=request.relationship_mode, + runtime_modes=runtime_modes, + context_mode=context_mode, + rag_context=rag_result.context, + ) + reply = chat( + messages, + model_profile=route["model_profile"], + max_tokens=chat_max_tokens(runtime_modes.performance_mode), + task_name="single_chat", + ) + session_id = request.session_id or init_session() + log( + session_id=session_id, + role=route["role"], + mode=route["mode"], + model=route["model_profile"], + user_input=request.user_input, + agent_reply=reply, + memory_enabled=bool(memory_bundle), + route_info={**route, "rag_status": rag_result.status}, + ) + flush_current_session( + session_id, + performance_mode=runtime_modes.performance_mode, + debug_mode=runtime_modes.debug_mode, + ) + return ChatResponse( + reply=reply, + session_id=session_id, + route=route, + rag=rag_result.to_dict(), + ) + + +@app.post("/memory/preview", response_model=MemoryPreviewResponse) +def preview_memory_updates(request: MemoryPreviewRequest) -> MemoryPreviewResponse: + runtime_modes = load_runtime_modes() + writable = is_memory_write_allowed(runtime_modes) + items = [] + for update in request.updates: + target = _memory_target_path(update.target) + action = "append" if update.append else "replace" + prefix = "### 待确认观察\n\n" if update.learner_pending else "" + preview = f"{prefix}{update.content.strip()}\n" + items.append( + MemoryPreviewItem( + target=update.target, + path=str(target), + action=action, + allowed=writable, + preview=preview, + ) + ) + return MemoryPreviewResponse( + writable=writable, + memory_mode=runtime_modes.memory_mode, + safe_mode=runtime_modes.safe_mode, + updates=items, + ) + + +@app.post("/memory/commit", response_model=MemoryCommitResponse) +def commit_memory_updates(request: MemoryPreviewRequest) -> MemoryCommitResponse: + runtime_modes = load_runtime_modes() + writable = is_memory_write_allowed(runtime_modes) + if not writable: + raise HTTPException( + status_code=403, + detail={ + "memory_mode": runtime_modes.memory_mode, + "safe_mode": runtime_modes.safe_mode, + "reason": runtime_modes.profile.memory_write_reason, + }, + ) + results = [] + for update in request.updates: + _memory_target_path(update.target) + if update.target == "current_focus" and not update.append: + path = memory_writer.write_current_focus(update.content.strip()) + action = "replace" + else: + path = memory_writer.append_memory( + update.target, + update.content.strip(), + learner_pending=update.learner_pending, + ) + action = "append" + results.append({"target": update.target, "action": action, "path": path}) + return MemoryCommitResponse(writable=writable, results=results) + + +@app.get("/sessions", response_model=SessionListResponse) +def list_sessions(limit: int = 20) -> SessionListResponse: + safe_limit = max(1, min(limit, 100)) + current = _session_file_rows(CURRENT_SESSION_DIR, "current", safe_limit) + archived = _session_file_rows(SESSION_DIR, "archived", safe_limit) + sessions = [*current, *archived] + sessions.sort(key=lambda row: int(row["mtime_ns"]), reverse=True) + return SessionListResponse(sessions=sessions[:safe_limit]) + + +@app.post("/sessions/{session_id}/flush") +def flush_session(session_id: str) -> dict[str, Any]: + get_or_create_session(session_id) + flushed = flush_current_session(session_id, force=True) + return {"session_id": session_id, "flushed": flushed} diff --git a/src/tools/__init__.py b/src/tools/__init__.py new file mode 100644 index 0000000..e350faf --- /dev/null +++ b/src/tools/__init__.py @@ -0,0 +1 @@ +"""Controlled tool boundaries for Study Agent workflows.""" diff --git a/src/tools/local_knowledge.py b/src/tools/local_knowledge.py new file mode 100644 index 0000000..3d61693 --- /dev/null +++ b/src/tools/local_knowledge.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal + +from src.rag import build_rag_context, format_rag_sources +from src.rag.index import DEFAULT_RAG_INDEX_PATH, load_rag_index +from src.rag.schema import RagSearchResult +from src.rag.service import build_rag_debug, search_documents + +LocalKnowledgeStatus = Literal[ + "skipped", + "found", + "not_found", + "index_missing", + "error", +] + +_SKIP_PATTERNS = ( + r"^\s*(hi|hello|hey|thanks|thank you|你好|您好|谢谢|早上好|晚上好)[!!。.\s]*$", + r"(讲个笑话|随便聊|你是谁|自我介绍|打招呼)", +) + +_RETRIEVAL_HINTS = ( + "根据", + "基于", + "资料", + "文档", + "知识库", + "本地", + "引用", + "来源", + "笔记", + "论文", + "文件", + "readme", + "docs", + "document", + "source", + "citation", + "knowledge base", + "local knowledge", +) + +_STUDY_HINTS = ( + "解释", + "说明", + "总结", + "对比", + "怎么", + "如何", + "为什么", + "机制", + "架构", + "代码", + "rag", + "api", + "fastapi", + "react", + "embedding", + "chroma", + "explain", + "summarize", + "compare", + "how", + "why", +) + +_REWRITE_PATTERNS = ( + r"请?(根据|基于|参考|从|在)(本地)?(知识库|资料|文档|笔记|来源|引用|docs|readme)?(中|里)?", + r"(回答|说明|解释|总结|一下|一下这个|这个问题)", + r"(according to|based on|from|in)\s+(the\s+)?(local\s+)?(knowledge base|documents|docs|notes|sources)", + r"(please|can you|could you|explain|summarize|answer)", +) + +NOT_FOUND_CONTEXT = ( + "Local knowledge retrieval was attempted, but no relevant local documents were found. " + "When answering, explicitly state that the local knowledge base did not contain supporting evidence." +) + + +@dataclass(frozen=True) +class RetrievalAttempt: + query: str + result_count: int + top_score: float = 0.0 + + def to_dict(self) -> dict[str, Any]: + return { + "query": self.query, + "result_count": self.result_count, + "top_score": self.top_score, + } + + +@dataclass(frozen=True) +class LocalKnowledgeResult: + status: LocalKnowledgeStatus + query: str + retrieval_mode: str + reason: str + context: str = "" + sources: str = "" + results: list[RagSearchResult] = field(default_factory=list) + debug: dict[str, Any] = field(default_factory=dict) + attempts: tuple[RetrievalAttempt, ...] = () + rewritten_query: str = "" + + @property + def retrieved(self) -> bool: + return self.status == "found" + + @property + def attempted(self) -> bool: + return self.status in {"found", "not_found", "index_missing", "error"} + + def to_dict(self) -> dict[str, Any]: + return { + "status": self.status, + "query": self.query, + "retrieval_mode": self.retrieval_mode, + "reason": self.reason, + "context": self.context, + "sources": self.sources, + "result_count": len(self.results), + "results": [result.to_dict() for result in self.results], + "debug": self.debug, + "attempts": [attempt.to_dict() for attempt in self.attempts], + "rewritten_query": self.rewritten_query, + } + + +def should_retrieve_local_knowledge(query: str) -> tuple[bool, str]: + normalized = " ".join((query or "").strip().lower().split()) + if not normalized: + return False, "empty_query" + if any(re.search(pattern, normalized, re.I) for pattern in _SKIP_PATTERNS): + return False, "conversational_query" + if any(hint in normalized for hint in _RETRIEVAL_HINTS): + return True, "explicit_local_knowledge_hint" + if any(hint in normalized for hint in _STUDY_HINTS) and len(normalized) >= 8: + return True, "study_question_hint" + return False, "no_retrieval_signal" + + +def rewrite_local_knowledge_query(query: str) -> str: + rewritten = query or "" + for pattern in _REWRITE_PATTERNS: + rewritten = re.sub(pattern, " ", rewritten, flags=re.I) + rewritten = re.sub(r"[??。!!,,::;;]+", " ", rewritten) + rewritten = " ".join(rewritten.split()) + return rewritten or query.strip() + + +def _attempt_from_results(query: str, results: list[RagSearchResult]) -> RetrievalAttempt: + return RetrievalAttempt( + query=query, + result_count=len(results), + top_score=results[0].score if results else 0.0, + ) + + +def _is_weak_result(results: list[RagSearchResult], weak_score_threshold: float) -> bool: + if not results: + return True + return results[0].score < weak_score_threshold + + +def retrieve_local_knowledge( + query: str, + *, + enabled: bool = True, + force: bool = False, + index_path: str | Path = DEFAULT_RAG_INDEX_PATH, + top_k: int = 3, + min_score: float = 0.01, + retrieval_mode: str = "hybrid", + context_max_chars: int = 3000, + allow_rewrite: bool = True, + weak_score_threshold: float = 0.05, +) -> LocalKnowledgeResult: + if not enabled: + return LocalKnowledgeResult( + status="skipped", + query=query, + retrieval_mode=retrieval_mode, + reason="disabled", + ) + + should_retrieve, reason = should_retrieve_local_knowledge(query) + if not force and not should_retrieve: + return LocalKnowledgeResult( + status="skipped", + query=query, + retrieval_mode=retrieval_mode, + reason=reason, + ) + + try: + index = load_rag_index(index_path) + except FileNotFoundError: + return LocalKnowledgeResult( + status="index_missing", + query=query, + retrieval_mode=retrieval_mode, + reason="index_missing", + ) + except Exception as exc: + return LocalKnowledgeResult( + status="error", + query=query, + retrieval_mode=retrieval_mode, + reason=f"index_load_failed: {exc}", + ) + + attempts: list[RetrievalAttempt] = [] + debug: dict[str, Any] = {} + try: + results = search_documents( + index, + query, + top_k=top_k, + min_score=min_score, + retrieval_mode=retrieval_mode, + ) + attempts.append(_attempt_from_results(query, results)) + debug = build_rag_debug( + index, + query, + results, + retrieval_mode=retrieval_mode, + top_k=top_k, + min_score=min_score, + ) + + rewritten_query = "" + if allow_rewrite and _is_weak_result(results, weak_score_threshold): + candidate = rewrite_local_knowledge_query(query) + if candidate and candidate != query.strip(): + rewritten_results = search_documents( + index, + candidate, + top_k=top_k, + min_score=min_score, + retrieval_mode=retrieval_mode, + ) + attempts.append(_attempt_from_results(candidate, rewritten_results)) + if rewritten_results: + results = rewritten_results + rewritten_query = candidate + debug = build_rag_debug( + index, + candidate, + results, + retrieval_mode=retrieval_mode, + top_k=top_k, + min_score=min_score, + ) + + if not results: + return LocalKnowledgeResult( + status="not_found", + query=query, + retrieval_mode=retrieval_mode, + reason="no_relevant_local_documents", + context=NOT_FOUND_CONTEXT, + debug=debug, + attempts=tuple(attempts), + ) + + return LocalKnowledgeResult( + status="found", + query=query, + retrieval_mode=retrieval_mode, + reason=reason, + context=build_rag_context(results, max_chars=context_max_chars), + sources=format_rag_sources(results), + results=results, + debug=debug, + attempts=tuple(attempts), + rewritten_query=rewritten_query, + ) + except Exception as exc: + return LocalKnowledgeResult( + status="error", + query=query, + retrieval_mode=retrieval_mode, + reason=f"retrieval_failed: {exc}", + attempts=tuple(attempts), + ) diff --git a/src/ui/chat_panel.py b/src/ui/chat_panel.py index c3852c1..2f36f37 100644 --- a/src/ui/chat_panel.py +++ b/src/ui/chat_panel.py @@ -12,10 +12,10 @@ from src.mode_manager import load_runtime_modes from src.model_stats import estimate_tokens, record_call, record_perf from src.perf import PerfTracker, write_perf_log -from src.rag import build_rag_context, format_rag_sources, query_documents from src.role_manager import load_role from src.router import route_request from src.session_logger import flush_current_session, log +from src.tools.local_knowledge import retrieve_local_knowledge from src.ui.avatar import get_chat_avatar, get_html_avatar_uri, get_user_avatar from src.constants import ROLE_LABELS @@ -112,25 +112,26 @@ def _build_chat_rag_context(user_input: str) -> tuple[str, int, str]: if not st.session_state.get("rag_chat_enabled"): return "", 0, "" - try: - top_k = int(st.session_state.get("rag_chat_top_k", 3)) - results = query_documents( - user_input, - top_k=top_k, - retrieval_mode=st.session_state.get("rag_retrieval_mode", "hybrid"), - ) - except FileNotFoundError: - return "", 0, "RAG index missing" - except Exception as exc: - return "", 0, f"RAG retrieval failed: {exc}" - - context = build_rag_context(results) - if not results: - return "", 0, "No relevant local documents" - st.session_state.rag_results = results - st.session_state.rag_context = context - st.session_state.rag_source_block = format_rag_sources(results) - return context, len(results), "" + top_k = int(st.session_state.get("rag_chat_top_k", 3)) + result = retrieve_local_knowledge( + user_input, + top_k=top_k, + retrieval_mode=st.session_state.get("rag_retrieval_mode", "hybrid"), + ) + st.session_state.rag_debug = result.debug + if result.status == "skipped": + return "", 0, "" + if result.status == "found": + st.session_state.rag_results = result.results + st.session_state.rag_context = result.context + st.session_state.rag_source_block = result.sources + return result.context, len(result.results), "" + if result.status == "not_found": + st.session_state.rag_results = [] + st.session_state.rag_context = result.context + st.session_state.rag_source_block = "" + return result.context, 0, "No relevant local documents" + return "", 0, result.reason def _queue_input(prompt: str): diff --git a/src/ui/wechat_panel.py b/src/ui/wechat_panel.py index 659c5c3..cf90167 100644 --- a/src/ui/wechat_panel.py +++ b/src/ui/wechat_panel.py @@ -13,7 +13,7 @@ update_interaction_mode, update_wechat_join_state, ) -from src.rag import build_rag_context, format_rag_sources, query_documents +from src.tools.local_knowledge import retrieve_local_knowledge from src.session_logger import ( set_wechat_interactive, set_wechat_memory_candidates, @@ -113,26 +113,26 @@ def _build_wechat_rag_context(user_text: str) -> tuple[str, int, str]: if not st.session_state.get("rag_chat_enabled"): return "", 0, "" - try: - top_k = int(st.session_state.get("rag_chat_top_k", 3)) - results = query_documents( - user_text, - top_k=top_k, - retrieval_mode=st.session_state.get("rag_retrieval_mode", "hybrid"), - ) - except FileNotFoundError: - return "", 0, "RAG index missing" - except Exception as exc: - return "", 0, f"RAG retrieval failed: {exc}" - - context = build_rag_context(results) - if not results: - return "", 0, "No relevant local documents" - - st.session_state.rag_results = results - st.session_state.rag_context = context - st.session_state.rag_source_block = format_rag_sources(results) - return context, len(results), "" + top_k = int(st.session_state.get("rag_chat_top_k", 3)) + result = retrieve_local_knowledge( + user_text, + top_k=top_k, + retrieval_mode=st.session_state.get("rag_retrieval_mode", "hybrid"), + ) + st.session_state.rag_debug = result.debug + if result.status == "skipped": + return "", 0, "" + if result.status == "found": + st.session_state.rag_results = result.results + st.session_state.rag_context = result.context + st.session_state.rag_source_block = result.sources + return result.context, len(result.results), "" + if result.status == "not_found": + st.session_state.rag_results = [] + st.session_state.rag_context = result.context + st.session_state.rag_source_block = "" + return result.context, 0, "No relevant local documents" + return "", 0, result.reason def _active_wechat_content() -> tuple[str, str]: diff --git a/tests/test_api.py b/tests/test_api.py index ba72989..bb9c624 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -2,6 +2,7 @@ from fastapi.testclient import TestClient +from src.mode_manager import RuntimeModes from src.api import app @@ -118,3 +119,159 @@ def test_rag_index_endpoint_reports_missing_files(tmp_path): response = client.post("/rag/index", json={"paths": [str(tmp_path / "missing.md")]}) assert response.status_code == 404 + + +def test_local_knowledge_endpoint_applies_agentic_retrieval(tmp_path): + client = TestClient(app) + document = tmp_path / "agentic.md" + index_path = tmp_path / "rag_index.json" + document.write_text("Agentic RAG retrieves local evidence only when useful.", encoding="utf-8") + client.post("/rag/index", json={"paths": [str(document)], "index_path": str(index_path)}) + + skipped = client.post( + "/rag/local-knowledge", + json={"query": "你好", "index_path": str(index_path)}, + ) + found = client.post( + "/rag/local-knowledge", + json={ + "query": "请根据本地资料解释 Agentic RAG evidence", + "index_path": str(index_path), + "top_k": 1, + }, + ) + + assert skipped.status_code == 200 + assert skipped.json()["status"] == "skipped" + assert found.status_code == 200 + data = found.json() + assert data["status"] == "found" + assert data["result_count"] == 1 + assert "agentic.md" in data["sources"] + + +def test_rag_status_and_upload_endpoints(tmp_path): + client = TestClient(app) + index_path = tmp_path / "uploaded_index.json" + + upload_response = client.post( + "/rag/upload", + params={"index_path": str(index_path), "max_chars": 200, "overlap_chars": 0}, + files={"files": ("upload.md", b"Uploaded RAG files become indexed chunks.", "text/markdown")}, + ) + status_response = client.get("/rag/status", params={"index_path": str(index_path)}) + + assert upload_response.status_code == 200 + assert upload_response.json()["documents"] == 1 + assert status_response.status_code == 200 + data = status_response.json() + assert data["index_exists"] is True + assert data["documents"] == 1 + assert data["chunks"] == 1 + assert data["vector_backend"]["name"] == "local" + + +def test_chat_endpoint_builds_reply_and_logs_session(monkeypatch): + from src import api + + captured = {} + + def fake_chat(messages, **kwargs): + captured["messages"] = messages + captured["kwargs"] = kwargs + return "API reply" + + monkeypatch.setattr(api, "chat", fake_chat) + monkeypatch.setattr(api, "load_role", lambda role: f"role prompt for {role}") + monkeypatch.setattr(api, "read_memory_bundle", lambda context_mode: {}) + monkeypatch.setattr( + api, + "load_runtime_modes", + lambda: RuntimeModes(memory_mode="preview", performance_mode="fast"), + ) + client = TestClient(app) + + response = client.post( + "/chat", + json={ + "user_input": "hello api", + "selected_role": "march7", + "selected_mode": "普通", + "selected_model": "flash", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["reply"] == "API reply" + assert data["route"]["role"] == "march7" + assert data["rag"]["status"] == "skipped" + assert captured["kwargs"]["task_name"] == "single_chat" + assert captured["messages"][-1]["content"] == "hello api" + + +def test_memory_preview_and_commit_endpoints(monkeypatch, tmp_path): + from src import api, memory_writer + + target = tmp_path / "progress.md" + monkeypatch.setitem(memory_writer.MEMORY_TARGETS, "progress", target) + monkeypatch.setattr( + api, + "load_runtime_modes", + lambda: RuntimeModes(memory_mode="confirm_write", safe_mode=False), + ) + monkeypatch.setattr(memory_writer, "load_runtime_modes", api.load_runtime_modes) + client = TestClient(app) + payload = {"updates": [{"target": "progress", "content": "API memory update"}]} + + preview = client.post("/memory/preview", json=payload) + commit = client.post("/memory/commit", json=payload) + + assert preview.status_code == 200 + assert preview.json()["writable"] is True + assert preview.json()["updates"][0]["path"] == str(target) + assert commit.status_code == 200 + assert commit.json()["results"][0]["target"] == "progress" + assert "API memory update" in target.read_text(encoding="utf-8") + + +def test_memory_commit_rejects_when_runtime_is_not_writable(monkeypatch, tmp_path): + from src import api, memory_writer + + target = tmp_path / "progress.md" + monkeypatch.setitem(memory_writer.MEMORY_TARGETS, "progress", target) + monkeypatch.setattr( + api, + "load_runtime_modes", + lambda: RuntimeModes(memory_mode="preview", safe_mode=False), + ) + client = TestClient(app) + + response = client.post( + "/memory/commit", + json={"updates": [{"target": "progress", "content": "should not write"}]}, + ) + + assert response.status_code == 403 + assert response.json()["detail"]["reason"] == "preview" + assert not target.exists() + + +def test_sessions_endpoint_lists_current_and_archived_files(monkeypatch, tmp_path): + from src import api + + current_dir = tmp_path / "current" + archived_dir = tmp_path / "sessions" + current_dir.mkdir() + archived_dir.mkdir() + (current_dir / "active.md").write_text("active session", encoding="utf-8") + (archived_dir / "old.md").write_text("old session", encoding="utf-8") + monkeypatch.setattr(api, "CURRENT_SESSION_DIR", current_dir) + monkeypatch.setattr(api, "SESSION_DIR", archived_dir) + client = TestClient(app) + + response = client.get("/sessions") + + assert response.status_code == 200 + names = {item["name"] for item in response.json()["sessions"]} + assert names == {"active.md", "old.md"} diff --git a/tests/test_local_knowledge_tool.py b/tests/test_local_knowledge_tool.py new file mode 100644 index 0000000..05ea6f2 --- /dev/null +++ b/tests/test_local_knowledge_tool.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from src.rag import index_documents +from src.tools.local_knowledge import ( + NOT_FOUND_CONTEXT, + retrieve_local_knowledge, + rewrite_local_knowledge_query, + should_retrieve_local_knowledge, +) + + +def test_local_knowledge_router_skips_conversation(): + should_retrieve, reason = should_retrieve_local_knowledge("你好") + + assert should_retrieve is False + assert reason == "conversational_query" + + +def test_local_knowledge_router_detects_explicit_document_questions(): + should_retrieve, reason = should_retrieve_local_knowledge("请根据本地资料解释 RAG 架构") + + assert should_retrieve is True + assert reason == "explicit_local_knowledge_hint" + + +def test_local_knowledge_rewrite_removes_instruction_boilerplate(): + rewritten = rewrite_local_knowledge_query("请根据本地资料回答:requests session reuse 是什么?") + + assert rewritten == "requests session reuse 是什么" + + +def test_retrieve_local_knowledge_skips_without_retrieval_signal(tmp_path): + document = tmp_path / "notes.md" + index_path = tmp_path / "rag_index.json" + document.write_text("Local knowledge about retrieval.", encoding="utf-8") + index_documents([document], index_path=index_path, max_chars=200, overlap_chars=0) + + result = retrieve_local_knowledge("hello", index_path=index_path) + + assert result.status == "skipped" + assert result.attempted is False + assert result.results == [] + + +def test_retrieve_local_knowledge_finds_sources_with_explicit_hint(tmp_path): + document = tmp_path / "requests.md" + index_path = tmp_path / "rag_index.json" + document.write_text( + "Python requests Session reuses HTTP connections for efficiency.", + encoding="utf-8", + ) + index_documents([document], index_path=index_path, max_chars=200, overlap_chars=0) + + result = retrieve_local_knowledge( + "请根据本地资料解释 requests Session connections", + index_path=index_path, + top_k=1, + ) + + assert result.status == "found" + assert result.retrieved is True + assert len(result.results) == 1 + assert "requests.md" in result.sources + assert "[1] requests" in result.context + + +def test_retrieve_local_knowledge_rewrites_weak_queries(tmp_path): + document = tmp_path / "requests.md" + index_path = tmp_path / "rag_index.json" + document.write_text( + "Python requests Session reuse keeps HTTP connections warm.", + encoding="utf-8", + ) + index_documents([document], index_path=index_path, max_chars=200, overlap_chars=0) + + result = retrieve_local_knowledge( + "请根据本地资料回答:requests Session reuse 是什么?", + index_path=index_path, + top_k=1, + weak_score_threshold=999, + ) + + assert result.status == "found" + assert result.rewritten_query == "requests Session reuse 是什么" + assert [attempt.query for attempt in result.attempts] == [ + "请根据本地资料回答:requests Session reuse 是什么?", + "requests Session reuse 是什么", + ] + + +def test_retrieve_local_knowledge_not_found_returns_contract(tmp_path): + document = tmp_path / "notes.md" + index_path = tmp_path / "rag_index.json" + document.write_text("Only contains local retrieval notes.", encoding="utf-8") + index_documents([document], index_path=index_path, max_chars=200, overlap_chars=0) + + result = retrieve_local_knowledge( + "请根据本地资料解释 quantum banana", + index_path=index_path, + top_k=1, + ) + + assert result.status == "not_found" + assert result.context == NOT_FOUND_CONTEXT + assert result.sources == "" + assert result.attempted is True