diff --git a/ai/.env.example b/ai/.env.example index d1e783bc..c6e33414 100644 --- a/ai/.env.example +++ b/ai/.env.example @@ -1,8 +1,80 @@ -HTTPS_PROXY=http://127.0.0.1:7890 -HTTP_PROXY=http://127.0.0.1:7890 - -GEMINI_API_KEY=YOUR_API_KEY -SILICONFLOW_API_KEY=YOUR_API_KEY -DASHSCOPE_API_KEY=YOUR_API_KEY -PINECONE_API_KEY=YOUR_API_KEY -COHERE_API_KEY=YOUR_API_KEY \ No newline at end of file +# ============================================================================== +# AI 服务环境变量配置 +# ============================================================================== +# +# 本文件用于存储 AI 服务所需的敏感信息(如 API 密钥) +# +# 使用方法: +# 1. 复制本文件为 .env:cp .env.example .env +# 2. 修改 .env 文件中的值,填入您的实际 API 密钥 +# 3. 运行程序时,godotenv 会自动加载 .env 文件到环境变量 +# 4. 配置文件中的 ${VAR_NAME} 会被自动替换为对应的环境变量值 +# +# 示例: +# config/models/models.yaml 中使用: +# dashscope: +# api_key: "${DASHSCOPE_API_KEY}" +# +# 程序启动时会自动将 ${DASHSCOPE_API_KEY} 替换为环境变量的值 +# +# ============================================================================== + +# HTTPS Proxy Configuration (可选) +# 如果需要通过代理访问 API 服务,请取消注释并配置 +# HTTPS_PROXY=http://127.0.0.1:7890 +# HTTP_PROXY=http://127.0.0.1:7890 + +# ============================================================================== +# Schema Configuration (必填) +# ============================================================================== + +# Schema Directory (必填) +# 指定 JSON Schema 文件的目录路径 +# +# 重要: 此配置必须通过环境变量设置,不能在 YAML 配置文件中设置 +# 原因: 避免 YAML 解析和 Schema 验证的循环依赖问题 +# +# 支持相对路径(相对于 config.yaml)和绝对路径 +# +# 默认值: schema/json (相对于项目根目录) +SCHEMA_DIR=schema/json + +# ============================================================================== +# AI Model API Keys +# ============================================================================== +# 模型提供商的 API 密钥 +# 至少需要配置一个 LLM 提供商的 API 密钥才能使用 AI 服务 + +# 通义千问 (Dashscope) +# 获取地址: https://dashscope.console.aliyun.com/ +# 支持的模型: qwen-max, qwen-plus, qwen-flash, qwen3-coder-plus +DASHSCOPE_API_KEY=your_dashscope_api_key_here + +# Google Gemini +# 
获取地址: https://makersuite.google.com/app/apikey +# 支持的模型: gemini-pro, gemini-pro-vision +GEMINI_API_KEY=your_gemini_api_key_here + +# SiliconFlow +# 获取地址: https://cloud.siliconflow.cn/account/ak +# 支持的模型: gpt-3.5-turbo, gpt-4 +SILICONFLOW_API_KEY=your_siliconflow_api_key_here + +# ============================================================================== +# Vector Database API Keys (可选) +# ============================================================================== + +# Pinecone Vector Database (可选) +# 获取地址: https://www.pinecone.io/ +# 用于 RAG 系统的向量存储 +PINECONE_API_KEY=your_pinecone_api_key_here + +# ============================================================================== +# Rerank Service API Keys (可选) +# ============================================================================== + +# Cohere Rerank (可选) +# 获取地址: https://dashboard.cohere.com/ +# 用于 RAG 系统的结果重排序 +COHERE_API_KEY=your_cohere_api_key_here + diff --git a/ai/TEST_CHECKLIST_V7.md b/ai/TEST_CHECKLIST_V7.md new file mode 100644 index 00000000..be04a21f --- /dev/null +++ b/ai/TEST_CHECKLIST_V7.md @@ -0,0 +1,250 @@ +# AI 模块单元测试 TODO LIST + +> 版本:v7(三层校验重构统一版:Parser / Loader+Schema / Runtime / Component) + +## 快速开始 + +```bash +go test -race -v ./config/... ./runtime/... ./component/... ./test/... ./cmd/... 
+``` + +--- + +## Tier 1:Parser(语法层,4个) + +### Config Parser + +**文件**: `config/loader_test.go` + +| # | 测试名称 | Mock | 说明 | 输入 | 预期输出 | +|---|---------|------|------|------|---------| +| 1 | `TestLoader_MainConfig_ParseError` | 无 | 主配置 YAML 语法错误 | 非法 `config.yaml`(如缺失冒号) | 返回 `parse error`,不进入结构校验 | +| 2 | `TestLoader_Component_ParseError` | 无 | 组件配置 YAML 语法错误 | 非法组件 YAML | 返回 `parse error`,不进入结构校验 | +| 3 | `TestLoader_MainConfig_ParseError_Priority` | 无 | 语法错误优先级 | 同时包含结构问题与语法错误的主配置 | 优先返回 `parse error` | +| 4 | `TestLoader_Component_ParseError_Priority` | 无 | 语法错误优先级 | 同时包含结构问题与语法错误的组件配置 | 优先返回 `parse error` | + +--- + +## Tier 2:Loader + Schema(结构层,12个) + +### Main Config Structural + +**文件**: `config/loader_test.go` + +| # | 测试名称 | Mock | 说明 | 输入 | 预期输出 | +|---|---------|------|------|------|---------| +| 5 | `TestLoader_MainConfig_UnknownField` | 无 | 主配置 unknown field 拒绝 | `config.yaml` 含未定义字段 | 返回 `structural error` | +| 6 | `TestLoader_MainConfig_ComponentsTypeInvalid` | 无 | components 项类型约束 | `components.x` 为对象/数字 | 返回 `structural error`(仅允许 string/[]string) | +| 7 | `TestLoader_MainConfig_ComponentsArrayItemInvalid` | 无 | components 数组项类型约束 | `components.x` 为 `["a.yaml", 1]` | 返回 `structural error` | +| 8 | `TestLoader_MainConfig_DefaultSchemaDir` | 无 | SCHEMA_DIR 默认路径 | 环境变量未设置 | 默认使用 `schema/json` 成功加载 | + +### Component Structural + +**文件**: `config/loader_test.go` + +| # | 测试名称 | Mock | 说明 | 输入 | 预期输出 | +|---|---------|------|------|------|---------| +| 9 | `TestLoader_Component_MissingType` | 无 | 组件 type 必填 | 组件 YAML 缺失 `type` | 返回 `structural error` | +| 10 | `TestLoader_Component_MissingSpec` | 无 | 组件 spec 必填 | 组件 YAML 缺失 `spec` | 返回 `structural error` | +| 11 | `TestLoader_Component_UnknownTopField` | 无 | top-level unknown field 拒绝 | 组件 YAML 含未定义字段 | 返回 `structural error` | +| 12 | `TestLoader_Component_DefaultInjection_Server` | 无 | server 默认值注入 | 仅给最小 server 配置 | decode 后包含 port/host/timeout 默认值 | +| 13 | 
`TestLoader_Component_DefaultInjection_Agent` | 无 | agent 阶段默认值注入 | stage 省略 temperature/top_p 等 | decode 后包含默认值 | + +### Conditional Required + +**文件**: `config/loader_test.go` + +| # | 测试名称 | Mock | 说明 | 输入 | 预期输出 | +|---|---------|------|------|------|---------| +| 14 | `TestLoader_Tools_MCPEnabled_RequireHost` | 无 | 条件必填(tools) | `enable_mcp_tools=true` 且缺 `mcp_host_name` | 返回 `structural error` | +| 15 | `TestLoader_RAG_RerankerEnabled_RequireAPIKey` | 无 | 条件必填(rag) | `reranker.enabled=true` 且缺 `api_key` | 返回 `structural error` | +| 16 | `TestLoader_RAG_Splitter_OneOfBranchValidation` | 无 | oneOf 分支结构约束 | splitter spec 与 type 不匹配 | 返回 `structural error` | + +--- + +## Tier 3:Runtime 调度层(6个) + +### Runtime Orchestration + +**文件**: `runtime/runtime_test.go` + +| # | 测试名称 | Mock | 说明 | 输入 | 预期输出 | +|---|---------|------|------|------|---------| +| 17 | `TestRuntime_RegisterFactory_Duplicate` | 无 | 重复注册覆盖 | 同类型注册两次 | 第二次覆盖生效;包含重复注册提示 | +| 18 | `TestRuntime_RegisterFactory_Concurrent` | 无 | 并发注册安全 | 100 goroutine 并发注册 | 无 data race;数量正确 | +| 19 | `TestRuntime_GetFactoryFn_NotFound` | 无 | 未注册工厂 | 类型名 `test` | 返回 error,包含 `not registered` | +| 20 | `TestRuntime_GetComponent_NotFound` | 无 | 未注册组件 | 名称 `agent` | 返回 error,包含 `component not found` | +| 21 | `TestRuntime_ComponentInitOrder` | Stub Component | 初始化顺序 | 注册多个工厂 | 按 `factoryOrder` 顺序 `Validate -> Init` | +| 22 | `TestBootstrap_ValidateFailStopsInit` | Stub Component | 语义失败中断 | 某组件 `Validate()` 返回 error | `Bootstrap` 返回 `failed to validate `,不执行 Init | + +--- + +## Tier 4:Component 语义层(8个) + +### Component Validate + +**文件**: `component/*/test/*_test.go` + +| # | 测试名称 | Mock | 说明 | 输入 | 预期输出 | +|---|---------|------|------|------|---------| +| 23 | `TestServerComponent_Validate_PortRange` | 无 | server 端口语义 | `port=70000` | `Validate()` 返回 error | +| 24 | `TestServerComponent_Validate_TimeoutPositive` | 无 | server 超时语义 | `read_timeout<=0` 或 `write_timeout<=0` | `Validate()` 返回 error | +| 25 | 
`TestMemoryComponent_Validate_MaxTurns` | 无 | memory 轮次语义 | `max_turns<=0` | `Validate()` 返回 error | +| 26 | `TestToolsComponent_Validate_MCPConfig` | 无 | tools MCP 语义 | MCP enabled 且 host 为空 | `Validate()` 返回 error | +| 27 | `TestModelsComponent_Validate_Providers` | 无 | models 语义一致性 | providers 为空或 base_url 为空 | `Validate()` 返回 error | +| 28 | `TestRAGComponent_Validate_SplitterSemantic` | 无 | rag 分块语义 | `overlap_size >= chunk_size` | `Validate()` 返回 error | +| 29 | `TestAgentComponent_Validate_StageFlowType` | 无 | agent 阶段语义 | 非法 `flow_type` | `Validate()` 返回 error | +| 30 | `TestAgentComponent_Validate_StagePromptRequired` | 无 | agent 阶段语义 | 缺 `prompt_file` | `Validate()` 返回 error | + +--- + +## Tier 5:Business Workflows(业务流程保留项,8个) + +### RAG Workflow + +**文件**: `component/rag/test/workflow_test.go` + +| # | 测试名称 | Mock | 说明 | 输入 | 预期输出 | +|---|---------|------|------|------|---------| +| 31 | `TestRAGWorkflow_Index_Retrieve` | Mock Retriever/Indexer | 索引后可检索 | 文档 `Dubbo is RPC`,查询 `RPC` | 检索结果包含 `Dubbo` | +| 32 | `TestRAGWorkflow_Split_Index` | Mock Splitter/Indexer | 分块后索引 | 长文档 | `len(chunks)>1` 且全部进入索引 | +| 33 | `TestRAGWorkflow_Namespace` | Mock Retriever | 命名空间隔离 | ns1/ns2 各索引文档 | ns1 查询不返回 ns2 内容 | +| 34 | `TestRAG_Retrieve_EmptyQuery` | Mock Retriever | 空查询处理 | `queries=nil` | 返回空 map(非nil),无 error | + +### Agent Workflow + +**文件**: `component/agent/react/test/flow_test.go` + +| # | 测试名称 | Mock | 说明 | 输入 | 预期输出 | +|---|---------|------|------|------|---------| +| 35 | `TestActFlow_GeneralInquiry_NoTools` | Mock Prompt | 一般询问不调工具 | `Intent=GeneralInquiry` | 返回 `ToolOutputs`,`len(Outputs)=0` | +| 36 | `TestActFlow_WithToolCall_ReturnsToolOutputs` | Mock Prompt+Tool | 工具调用主流程 | `Intent=PerformanceInvestigation` | 返回 `ToolOutputs`,至少1条结果 | +| 37 | `TestActFlow_ToolErrorHandling` | Mock Prompt/工具 | 工具错误处理 | 工具返回 error | 返回 error,包含工具名 | +| 38 | `TestThinkFlow_ExecuteError_NoNilDeref` | Mock Prompt | think 异常路径健壮性 | `Execute` 返回 error | 不应因 `resp.Text()` 引发 
nil deref | + +--- + +## Tier 6:并发与边界(保留项,8个) + +### Memory Concurrency & State + +**文件**: `component/memory/test/history_test.go` + +| # | 测试名称 | Mock | 说明 | 输入 | 预期输出 | +|---|---------|------|------|------|---------| +| 39 | `TestHistoryMemory_AddHistory_UserMessage` | 无 | 添加用户消息 | session-1 + user message | 进入当前 turn 的 `UserMessages` | +| 40 | `TestHistoryMemory_NextTurn_ArchivesCurrentTurn` | 无 | 推进会话归档当前 turn | 已有1个turn | 旧 turn 进入 history,window 前移 | +| 41 | `TestHistoryMemory_NextTurn_WhenSessionFull` | 无 | 会话窗口满时行为 | 将窗口填满后 `NextTurn` | 返回 error,包含 `context is full` | +| 42 | `TestHistoryMemory_ConcurrentAddHistory` | 无 | 并发写历史安全 | 100 goroutine 写入 | 无 panic,无 race | +| 43 | `TestHistoryMemory_ConcurrentReadWrite` | 无 | 并发读写安全 | 10写+10读 goroutine | 不 panic,无 data race | +| 44 | `TestHistoryMemory_NextTurn_EmptyWindowSafety` | 无 | 空窗口推进安全 | session 被 pop 空后重复 `NextTurn` | 固定当前 panic 风险或修复后断言 error | + +### Runtime/Bootstrap Boundary + +**文件**: `runtime/runtime_test.go`, `component/*/test/*_test.go` + +| # | 测试名称 | Mock | 说明 | 输入 | 预期输出 | +|---|---------|------|------|------|---------| +| 45 | `TestBootstrap_MissingFactoryForConfiguredType` | 无 | 配置类型无工厂 | 配置 type 未注册 | error 包含 `no factory for` | +| 46 | `TestRuntime_GetRuntime_NotInitialized` | 无 | 全局Runtime未初始化 | 直接调用 `GetRuntime()` | 触发 panic `Runtime not initialized` | + +--- + +## 按层统计 + +| 层/域 | 数量 | 说明 | +|------|------|------| +| Parser | 4 | 只验证语法失败归属 | +| Loader+Schema | 12 | 只验证结构校验、条件必填、默认注入 | +| Runtime | 6 | 只验证调度顺序与生命周期 | +| Component Semantic | 8 | 只验证语义规则 | +| Business Workflows | 8 | 保留原有高价值业务流程测试 | +| Robustness & Concurrency | 8 | 保留原有并发与边界保护测试 | +| **总计** | **46** | 三层重构 + 原有单测统一清单 | + +--- + +## 编写规范 + +- 错误断言按层级关键字匹配: + - 语法层:`parse error` + - 结构层:`structural error` + - 语义层:`failed to validate` 或组件语义错误信息 +- 业务流程测试不承担结构层职责断言。 +- 并发测试建议配合 `-race` 在 CI 中执行。 + +--- + +## Mock 说明 + +| Mock对象 | 用途 | 实现方式 | +|---------|------|---------| +| Mock Genkit Registry | 模拟模型注册与组件初始化 | 使用 
`testutils.CreateMockGenkitRegistry()` 或等价 stub | +| Mock Prompt | 模拟 LLM 返回(含 ToolRequests) | 自定义 `MockPrompt` / stub prompt,覆盖成功与失败分支 | +| Mock Tool | 模拟工具调用成功/失败/超时 | 本地 mock struct + 可配置返回值 | +| Mock Retriever/Indexer/Splitter | 控制 RAG 行为并断言参数 | 本地 mock struct + 调用计数/入参记录 | +| Stub Component | 验证 Runtime 初始化顺序/错误传播 | 自定义实现 `runtime.Component`,可注入 `Validate/Init/Start/Stop` 行为 | +| 临时配置文件夹(Fixture Dir) | 隔离配置输入与路径解析 | `t.TempDir()` 下写入最小 YAML/Schema 夹具 | + +--- + +## Go 表格驱动单测编写规范 + +### 1) 基本模板 + +```go +func TestXXX(t *testing.T) { + tests := []struct { + name string + input any + wantErr bool + errLike string + }{ + { + name: "valid case", + input: ..., + wantErr: false, + }, + { + name: "invalid case", + input: ..., + wantErr: true, + errLike: "structural error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // arrange + // act + // assert + }) + } +} +``` + +### 2) 断言规范 + +- 错误断言使用关键子串匹配,避免脆弱的全量字符串匹配。 +- 每条测试只验证一个主行为(单一断言目标)。 +- 分层测试中禁止跨层断言: + - Parser 测试不验证语义 + - Loader 测试不验证组件语义 + - Component 测试假设结构已通过 + +### 3) 并发与稳定性 + +- 并发测试统一使用 `go test -race` 验证。 +- 避免真实网络/端口依赖,全部使用本地 mock。 +- 使用 `t.Helper()` 封装重复构建逻辑,提高可读性。 + +### 4) 夹具与命名 + +- 测试名采用 `Test<模块>_<行为>_<预期>` 风格。 +- 配置夹具优先最小化:只保留触发当前断言所需字段。 +- 使用 `t.TempDir()` 管理临时文件,避免污染仓库。 + +### 5) 回归要求 + +- 每个缺陷修复至少补 1 条失败重现用例。 +- 先写失败断言,再修复实现(红-绿流程)。 diff --git a/ai/agent/react/react_test.go b/ai/agent/react/react_test.go deleted file mode 100644 index 4ce79197..00000000 --- a/ai/agent/react/react_test.go +++ /dev/null @@ -1,155 +0,0 @@ -package react - -import ( - "encoding/json" - "fmt" - "testing" - - "dubbo-admin-ai/agent" - "dubbo-admin-ai/config" - "dubbo-admin-ai/manager" - "dubbo-admin-ai/memory" - "dubbo-admin-ai/plugins/dashscope" - "dubbo-admin-ai/schema" - "dubbo-admin-ai/tools" -) - -var ( - reActAgent *ReActAgent - chatHistoryCtx = memory.NewMemoryContext(memory.ChatHistoryKey) -) - -func init() { - reActAgent, _ = Create(manager.Registry(dashscope.Qwen3_coder.Key(), 
config.PROJECT_ROOT+"/.env", manager.ProductionLogger())) -} - -func TestThinkWithToolReq(t *testing.T) { - input := ActOut{ - Outputs: []tools.ToolOutput{ - { - ToolName: "prometheus_query_service_latency", - Summary: "服务 order-service 在过去10分钟内的 P95 延迟为 3500ms", - Result: map[string]any{ - "quantile": 0.95, - "value_millis": 3500, - }, - }, - { - ToolName: "prometheus_query_service_traffic", - Summary: "服务 order-service 的 QPS 为 250.0, 错误率为 5.2%", - Result: map[string]any{ - "error_rate_percentage": 5.2, - "request_rate_qps": 250, - }, - }, - }, - } - channels := agent.NewChannels(config.STAGE_CHANNEL_BUFFER_SIZE) - - defer func() { - channels.Close() - channels.Close() - }() - reActAgent.orchestrator.RunStage(chatHistoryCtx, agent.ThinkFlowName, input, channels) - -} - -func TestAct(t *testing.T) { - actInJson := `{ - "tool_requests": [ - { - "tool_name": "prometheus_query_service_latency", - "parameter": { - "service_name": "order-service", - "time_range_minutes": 15, - "quantile": 0.95 - } - }, - { - "tool_name": "prometheus_query_service_traffic", - "parameter": { - "service_name": "order-service", - "time_range_minutes": 15 - } - }, - { - "tool_name": "trace_dependency_view", - "parameter": { - "service_name": "order-service" - } - }, - { - "tool_name": "dubbo_service_status", - "parameter": { - "service_name": "order-service" - } - } - ], - "status": "CONTINUED", - "thought": "开始对 order-service 运行缓慢的问题进行系统性诊断。首先需要了解当前服务的整体性能表现,包括延迟、流量和错误率等关键指标。同时,获取服务的实例状态和拓扑依赖关系,以判断是否存在明显的异常。由于多个数据源可独立查询,为提高效率,将并行执行多个工具调用:查询延迟指标、服务流量、服务依赖关系和服务实例状态。" -}` - actIn := ActIn{} - if err := json.Unmarshal([]byte(actInJson), &actIn); err != nil { - t.Fatalf("failed to unmarshal actInJson: %v", err) - } - - reActAgent.orchestrator.RunStage(chatHistoryCtx, agent.ActFlowName, actIn, nil) - // resp := actOuts - // if err != nil { - // t.Fatalf("failed to run act flow: %v", err) - // } - // if resp == nil { - // t.Fatal("expected non-nil response") - // } - -} - -// func TestIntent(t 
*testing.T) { -// userInput := schema.UserInput{ -// Content: "我的微服务 order-service 运行缓慢,请帮助我诊断原因", -// } - -// channels := agent.NewChannels(config.STAGE_CHANNEL_BUFFER_SIZE) -// go reActAgent.orchestrator.RunStage(chatHistoryCtx, agent.IntentFlowName, userInput, channels) - -// for !channels.Closed() { -// select { -// case stream := <-channels.UserRespChan: -// fmt.Print(stream.Text) - -// case flowData := <-channels.FlowChan: -// fmt.Println(flowData) -// } -// } -// } - -func TestAgent(t *testing.T) { - agentInput := schema.UserInput{ - Content: "我的微服务 order-service 运行缓慢,请帮助我诊断原因", - } - channels := reActAgent.Interact(&agentInput, "session-test") - for !channels.Closed() { - select { - case err, ok := <-channels.ErrorChan: - if !ok { - channels.ErrorChan = nil - continue - } - if err != nil { - t.Fatalf("agent interaction error: %v", err) - channels.Close() - return - } - case chunk, ok := <-channels.UserRespChan: - if !ok { - channels.UserRespChan = nil - continue - } - if chunk != nil { - fmt.Print(chunk.Text()) - } - default: - } - } - -} diff --git a/ai/cmd/index.go b/ai/cmd/index.go new file mode 100644 index 00000000..4aa80bad --- /dev/null +++ b/ai/cmd/index.go @@ -0,0 +1,177 @@ +package main + +import ( + "context" + compRag "dubbo-admin-ai/component/rag" + appconfig "dubbo-admin-ai/config" + "flag" + "fmt" + "log" + "os" + "path/filepath" + "strings" + + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/compat_oai" + "github.com/firebase/genkit/go/plugins/googlegenai" + "github.com/firebase/genkit/go/plugins/pinecone" + "github.com/openai/openai-go/option" +) + +type IndexCommand struct { + Directory string + ConfigPath string +} + +func main() { + cmd := parseFlags() + + if err := validateCommand(cmd); err != nil { + log.Fatalf("Validation error: %v", err) + } + + if err := executeIndexing(cmd); err != nil { + log.Fatalf("Indexing failed: %v", err) + } + + fmt.Println("Indexing 
completed successfully") +} + +func parseFlags() *IndexCommand { + cmd := &IndexCommand{} + + flag.StringVar(&cmd.Directory, "dir", "/Users/liwener/programming/ospp/dubbo-admin/ai/reference/k8s_docs/concepts", "Directory to index (required)") + flag.StringVar(&cmd.ConfigPath, "config", "component/rag/rag.yaml", "Configuration file path") + + flag.Parse() + return cmd +} + +func validateCommand(cmd *IndexCommand) error { + if cmd.Directory == "" { + return fmt.Errorf("directory parameter is required") + } + + // Check if directory exists + info, err := os.Stat(cmd.Directory) + if err != nil { + return fmt.Errorf("failed to access directory %s: %w", cmd.Directory, err) + } + + if !info.IsDir() { + return fmt.Errorf("%s is not a directory", cmd.Directory) + } + + // Check if config file exists + if _, err := os.Stat(cmd.ConfigPath); err != nil { + return fmt.Errorf("failed to access config file %s: %w", cmd.ConfigPath, err) + } + + return nil +} + +func executeIndexing(cmd *IndexCommand) error { + ctx := context.Background() + + // Load configuration from YAML file + cfg, err := loadRAGConfig(cmd.ConfigPath) + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + if err := cfg.Validate(); err != nil { + return fmt.Errorf("invalid rag config: %w", err) + } + + // Initialize Genkit registry independently + plugins := []api.Plugin{} + // Initialize plugins from environment variables + if key := os.Getenv("DASHSCOPE_API_KEY"); key != "" { + plugins = append(plugins, &compat_oai.OpenAICompatible{ + Provider: "dashscope", + Opts: []option.RequestOption{ + option.WithAPIKey(key), + option.WithBaseURL("https://dashscope.aliyuncs.com/compatible-mode/v1"), + }, + }) + } + if key := os.Getenv("GEMINI_API_KEY"); key != "" { + plugins = append(plugins, &googlegenai.GoogleAI{APIKey: key}) + } + if key := os.Getenv("SILICONFLOW_API_KEY"); key != "" { + plugins = append(plugins, &compat_oai.OpenAICompatible{ + Provider: "siliconflow", + Opts: 
[]option.RequestOption{ + option.WithAPIKey(key), + option.WithBaseURL("https://api.siliconflow.cn/v1"), + }, + }) + } + if key := os.Getenv("PINECONE_API_KEY"); key != "" { + plugins = append(plugins, &pinecone.Pinecone{APIKey: key}) + } + + g := genkit.Init(ctx, genkit.WithPlugins(plugins...)) + + // Build-time target index selection: default to directory name for this CLI + targetIndex := getNamespace("", cmd.Directory) + + sys, err := compRag.BuildRAGFromSpec(ctx, g, cfg) + if err != nil { + return fmt.Errorf("failed to build RAG system: %w", err) + } + + // Load documents via the built loader + docs, err := compRag.LoadDirectory(ctx, sys.Loader, cmd.Directory) + if err != nil { + return fmt.Errorf("failed to load directory: %w", err) + } + + fmt.Printf("Loaded %d documents from %s\n", len(docs), cmd.Directory) + + if len(docs) == 0 { + fmt.Printf("No supported documents found in directory: %s\n", cmd.Directory) + return nil + } + + // Split documents (semantic/header or recursive based on config) + splitDocs, err := sys.Split(ctx, docs) + if err != nil { + return fmt.Errorf("failed to split documents: %w", err) + } + + // Index documents (namespace is a per-call runtime parameter) + namespace := targetIndex + if _, err := sys.Index(ctx, namespace, splitDocs, compRag.WithIndexerTargetIndex(targetIndex)); err != nil { + return fmt.Errorf("failed to index documents: %w", err) + } + + return nil +} + +// getNamespace generates a namespace from the provided parameter or directory name +func getNamespace(namespace, directory string) string { + if namespace != "" { + return namespace + } + return filepath.Base(strings.TrimSpace(directory)) +} + +// loadRAGConfig loads RAG configuration from a YAML file +func loadRAGConfig(configPath string) (*compRag.RAGSpec, error) { + loader := appconfig.NewLoader("config.yaml") + componentCfg, err := loader.LoadComponent(configPath) + if err != nil { + return nil, err + } + if componentCfg.Type != "rag" { + return nil, 
fmt.Errorf("structural error: component type must be rag, got %s", componentCfg.Type) + } + + var cfg compRag.RAGSpec + if err := componentCfg.Spec.Decode(&cfg); err != nil { + return nil, fmt.Errorf("failed to decode rag config: %w", err) + } + + return &cfg, nil +} diff --git a/ai/cmd/rag.go b/ai/cmd/rag.go deleted file mode 100644 index f569050f..00000000 --- a/ai/cmd/rag.go +++ /dev/null @@ -1,21 +0,0 @@ -package main - -import ( - "dubbo-admin-ai/config" - "dubbo-admin-ai/manager" - "dubbo-admin-ai/plugins/dashscope" - "dubbo-admin-ai/utils" -) - -func main() { - mdDir := "./reference/k8s_docs/concepts" - chunks, err := utils.ProcessMarkdownDirectory(mdDir) - if err != nil { - panic(err) - } - g := manager.Registry(dashscope.Qwen3.Key(), config.PROJECT_ROOT+"/.env", manager.ProductionLogger()) - err = utils.IndexInPinecone(g, "kube-docs", "concepts", dashscope.Qwen3_embedding.Key(), nil, chunks) - if err != nil { - panic(err) - } -} diff --git a/ai/agent/agent.go b/ai/component/agent/agent.go similarity index 87% rename from ai/agent/agent.go rename to ai/component/agent/agent.go index 43e4d49d..4d0b9e2e 100644 --- a/ai/agent/agent.go +++ b/ai/component/agent/agent.go @@ -5,8 +5,7 @@ import ( "errors" "fmt" - "dubbo-admin-ai/config" - "dubbo-admin-ai/memory" + "dubbo-admin-ai/component/memory" "dubbo-admin-ai/schema" "github.com/firebase/genkit/go/core" @@ -34,7 +33,7 @@ const ( type Agent interface { Interact(*schema.UserInput, string) *Channels - GetMemory() *memory.History + GetMemory() *memory.HistoryMemory } type Channels struct { @@ -58,6 +57,8 @@ func (chans *Channels) Reset() { chans.closed = false } +// This method won't destroy the Channels because it will be reused for each interaction. +// If you want to completely destroy the Channels, please call Destroy() method. 
func (chans *Channels) Close() { chans.closed = true } @@ -149,14 +150,15 @@ type Orchestrator interface { } type OrderOrchestrator struct { - stages map[string]*Stage - beforeLoop []string - loop []string - afterLoop []string + stages map[string]*Stage + beforeLoop []string + loop []string + afterLoop []string + maxIterations int } // The order of stages is the order in which they are executed -func NewOrderOrchestrator(stages ...*Stage) *OrderOrchestrator { +func NewOrderOrchestrator(maxIterations int, stages ...*Stage) *OrderOrchestrator { stagesMap := make(map[string]*Stage, len(stages)) loop := make([]string, 0, len(stages)) beforeLoop := make([]string, 0, len(stages)) @@ -182,10 +184,11 @@ func NewOrderOrchestrator(stages ...*Stage) *OrderOrchestrator { } return &OrderOrchestrator{ - stages: stagesMap, - loop: loop, - beforeLoop: beforeLoop, - afterLoop: afterLoop, + stages: stagesMap, + loop: loop, + beforeLoop: beforeLoop, + afterLoop: afterLoop, + maxIterations: maxIterations, } } @@ -217,7 +220,7 @@ func (orchestrator *OrderOrchestrator) Run(ctx context.Context, input schema.Sch // Iterate until reaching maximum iterations or status is Finished var finalOutput schema.Observation Outer: - for range config.MAX_REACT_ITERATIONS { + for i := 0; i < orchestrator.maxIterations; i++ { for _, order := range orchestrator.loop { // Execute current stage curStage, ok := orchestrator.stages[order] @@ -230,6 +233,11 @@ Outer: } output := <-chans.FlowChan + // Validate output is not nil to prevent nil pointer panic + if output == nil { + return fmt.Errorf("stage %s returned nil output", order) + } + // Check if LLM returned final answer if out, ok := output.(schema.Observation); ok { if !out.Heartbeat && out.FinalAnswer != "" { @@ -257,6 +265,11 @@ Outer: } output := <-chans.FlowChan + // Validate output is not nil to prevent nil pointer panic + if output == nil { + return fmt.Errorf("after-loop stage %s returned nil output", key) + } + input = output } diff --git 
a/ai/component/agent/agent.yaml b/ai/component/agent/agent.yaml new file mode 100644 index 00000000..6af3825d --- /dev/null +++ b/ai/component/agent/agent.yaml @@ -0,0 +1,46 @@ +type: agent +spec: + agent_type: "react" + default_model: "dashscope/qwen-max" + prompt_base_path: "./prompts" + max_iterations: 10 + stage_channel_buffer_size: 5 + mcp_host_name: "mcp_host" + + stages: + - name: "agentThinking" + flow_type: "think" + prompt_file: "agentThink.txt" + temperature: 0.7 + top_p: 0.9 + max_tokens: 2000 + timeout: 60 + enable_tools: true + # extra_prompt: "available tools: [...]" + + - name: "agentTool" + flow_type: "act" + prompt_file: "agentAct.txt" + temperature: 0.7 + top_p: 0.9 + max_tokens: 3000 + timeout: 90 + enable_tools: true + + - name: "observe" + flow_type: "observe" + prompt_file: "agentObserve.txt" + temperature: 0.7 + top_p: 0.9 + max_tokens: 2000 + timeout: 60 + enable_tools: false + + - name: "agentFeedback" + flow_type: "feedback" + prompt_file: "agentFeedback.txt" + temperature: 0.7 + top_p: 0.9 + max_tokens: 1000 + timeout: 30 + enable_tools: false diff --git a/ai/component/agent/react/component.go b/ai/component/agent/react/component.go new file mode 100644 index 00000000..60b4b082 --- /dev/null +++ b/ai/component/agent/react/component.go @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package react + +import ( + "dubbo-admin-ai/component/tools" + "dubbo-admin-ai/runtime" + "fmt" +) + +// AgentComponent Agent component implementation +type AgentComponent struct { + Agent *ReActAgent + agentType string + defaultModel string + promptBasePath string + maxIterations int + stageChannelBufferSize int + mcpHostName string + stages []StageInfo +} + +func NewAgentComponent( + agentType string, + defaultModel string, + promptBasePath string, + maxIterations int, + stageChannelBufferSize int, + mcpHostName string, + stages []StageInfo, +) (runtime.Component, error) { + return &AgentComponent{ + agentType: agentType, + defaultModel: defaultModel, + promptBasePath: promptBasePath, + maxIterations: maxIterations, + stageChannelBufferSize: stageChannelBufferSize, + mcpHostName: mcpHostName, + stages: stages, + }, nil +} + +func (a *AgentComponent) Name() string { + return "agent" +} + +func (a *AgentComponent) Validate() error { + cfg := AgentSpec{ + AgentType: a.agentType, + DefaultModel: a.defaultModel, + PromptBasePath: a.promptBasePath, + MaxIterations: a.maxIterations, + StageChannelBufferSize: a.stageChannelBufferSize, + MCPHostName: a.mcpHostName, + Stages: a.stages, + } + return cfg.Validate() +} + +func (a *AgentComponent) Init(rt *runtime.Runtime) error { + toolsComp, err := rt.GetComponent("tools") + if err != nil { + return fmt.Errorf("tools component not found: %w", err) + } + tools, ok := toolsComp.(*tools.ToolsComponent) + if !ok { + return fmt.Errorf("invalid tools component type") + } + toolRefs := tools.GetToolRefs() + + 
reactAgent, err := NewReactAgent(rt.GetGenkitRegistry(), a.promptBasePath, a.defaultModel, a.maxIterations, a.stages, toolRefs) + if err != nil { + return fmt.Errorf("failed to create ReAct agent: %w", err) + } + + a.Agent = reactAgent + + rt.GetLogger().Info("Agent component initialized", + "agent_type", a.agentType, + "default_model", a.defaultModel, + "max_iterations", a.maxIterations, + "stages", len(a.stages)) + + return nil +} + +func (a *AgentComponent) Start() error { + return nil +} + +func (a *AgentComponent) Stop() error { + return nil +} diff --git a/ai/component/agent/react/config.go b/ai/component/agent/react/config.go new file mode 100644 index 00000000..54fcba85 --- /dev/null +++ b/ai/component/agent/react/config.go @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
// AgentTypeReAct is the component type identifier for the ReAct agent.
const AgentTypeReAct = "react"

// AgentSpec is the YAML configuration of a ReAct agent component.
type AgentSpec struct {
	AgentType              string      `yaml:"agent_type"`                // must be "react"
	DefaultModel           string      `yaml:"default_model"`             // fallback model for stages without their own
	PromptBasePath         string      `yaml:"prompt_base_path"`          // directory holding the stage prompt files
	MaxIterations          int         `yaml:"max_iterations"`            // upper bound on agent loop iterations
	StageChannelBufferSize int         `yaml:"stage_channel_buffer_size"` // buffer size of inter-stage channels
	MCPHostName            string      `yaml:"mcp_host_name"`
	Stages                 []StageInfo `yaml:"stages"`
}

// StageInfo configures a single stage of the ReAct pipeline.
type StageInfo struct {
	Name        string  `yaml:"name"`
	FlowType    string  `yaml:"flow_type"`       // one of: think, act, observe, feedback
	Model       string  `yaml:"model,omitempty"` // optional; empty means "use DefaultModel"
	PromptFile  string  `yaml:"prompt_file"`
	Temperature float64 `yaml:"temperature"`
	TopP        float64 `yaml:"top_p,omitempty"` // optional; zero means "unset"
	MaxTokens   int     `yaml:"max_tokens"`
	Timeout     int     `yaml:"timeout"`
	EnableTools bool    `yaml:"enable_tools"`
	ExtraPrompt string  `yaml:"extra_prompt,omitempty"`
}

// ReActDefaultSpec returns default ReAct Agent configuration
func ReActDefaultSpec() *AgentSpec {
	return &AgentSpec{
		// Use the exported constant instead of repeating the "react" literal
		// so the default spec cannot drift from the registered agent type.
		AgentType:              AgentTypeReAct,
		DefaultModel:           "qwen-max",
		PromptBasePath:         "./prompts",
		MaxIterations:          10,
		StageChannelBufferSize: 5,
		MCPHostName:            "mcp_host",
		Stages: []StageInfo{
			{
				Name:        "agentThinking",
				FlowType:    "think",
				PromptFile:  "agentThink.txt",
				Temperature: 0.7,
				TopP:        0.9,
				MaxTokens:   2000,
				Timeout:     60,
				EnableTools: true,
			},
			{
				Name:        "agentTool",
				FlowType:    "act",
				PromptFile:  "agentAct.txt",
				Temperature: 0.7,
				TopP:        0.9,
				MaxTokens:   3000,
				Timeout:     90,
				EnableTools: true,
			},
			{
				Name:        "agentFeedback",
				FlowType:    "feedback",
				PromptFile:  "agentFeedback.txt",
				Temperature: 0.7,
				TopP:        0.9,
				MaxTokens:   1000,
				Timeout:     30,
				EnableTools: false,
			},
			{
				Name:        "observe",
				FlowType:    "observe",
				PromptFile:  "agentObserve.txt",
				Temperature: 0.7,
				TopP:        0.9,
				MaxTokens:   2000,
				Timeout:     60,
				EnableTools: false,
			},
		},
	}
}

// Validate validates the configuration
func (c *AgentSpec) Validate() error {
	if c.AgentType == "" {
		return fmt.Errorf("agent_type is required")
	}
	if c.DefaultModel == "" {
		return fmt.Errorf("default_model is required")
	}
	if c.PromptBasePath == "" {
		return fmt.Errorf("prompt_base_path is required")
	}
	if c.MaxIterations <= 0 {
		return fmt.Errorf("max_iterations must be greater than 0")
	}
	if c.StageChannelBufferSize <= 0 {
		return fmt.Errorf("stage_channel_buffer_size must be greater than 0")
	}
	if len(c.Stages) == 0 {
		return fmt.Errorf("stages is required")
	}

	for i, stage := range c.Stages {
		if err := stage.Validate(i); err != nil {
			return err
		}
	}

	return nil
}

// Validate validates the stage configuration
func (s *StageInfo) Validate(index int) error {
	// Validate name
	if s.Name == "" {
		return fmt.Errorf("stage[%d]: name is required", index)
	}

	// Validate flow type
	validFlowTypes := map[string]bool{
		"think":    true,
		"act":      true,
		"observe":  true,
		"feedback": true,
	}
	if !validFlowTypes[s.FlowType] {
		return fmt.Errorf("stage[%d]: invalid flow_type '%s', must be one of: think, act, observe, feedback", index, s.FlowType)
	}

	if s.PromptFile == "" {
		return fmt.Errorf("stage[%d]: prompt_file is required", index)
	}

	if s.Temperature <= 0 || s.Temperature > 2.0 {
		return fmt.Errorf("stage[%d]: temperature must be in (0, 2.0]", index)
	}

	// FIX: top_p is tagged `omitempty` (optional) and never consumed
	// downstream, yet the previous check rejected the zero value — a YAML
	// config that simply omitted top_p failed validation. A zero value now
	// means "unset" and is accepted; non-zero values must be in (0, 1.0].
	if s.TopP != 0 && (s.TopP < 0 || s.TopP > 1.0) {
		return fmt.Errorf("stage[%d]: top_p must be in (0, 1.0]", index)
	}

	if s.MaxTokens <= 0 {
		return fmt.Errorf("stage[%d]: max_tokens must be greater than 0", index)
	}

	if s.Timeout <= 0 {
		return fmt.Errorf("stage[%d]: timeout must be greater than 0", index)
	}

	return nil
}
package react

import (
	"dubbo-admin-ai/runtime"
	"fmt"

	"gopkg.in/yaml.v3"
)

// AgentFactory creates an agent component (explicit registration, no init)
//
// The raw YAML `spec` node is decoded into an AgentSpec and its fields are
// handed individually to NewAgentComponent; any decode failure is wrapped
// and returned so the runtime can report which component spec was invalid.
// Note: the decoded cfg is not validated here — presumably NewAgentComponent
// or the runtime performs Validate; confirm against the component lifecycle.
func AgentFactory(spec *yaml.Node) (runtime.Component, error) {
	var cfg AgentSpec
	if err := spec.Decode(&cfg); err != nil {
		return nil, fmt.Errorf("failed to decode agent spec: %w", err)
	}

	return NewAgentComponent(
		cfg.AgentType,
		cfg.DefaultModel,
		cfg.PromptBasePath,
		cfg.MaxIterations,
		cfg.StageChannelBufferSize,
		cfg.MCPHostName,
		cfg.Stages,
	)
}
schema.ThinkOutput type ActIn = ThinkOut type ActOut = schema.ToolOutputs -// ReActAgent implements Agent interface type ReActAgent struct { registry *genkit.Genkit memoryCtx context.Context orchestrator agent.Orchestrator channels *agent.Channels + + defaultModel string // Default model in "provider/model" format (e.g., "dashscope/qwen-max") + promptBasePath string + maxIterations int } func onStreaming2User(channels *agent.Channels, chunk schema.StreamChunk) error { @@ -49,73 +52,111 @@ func onOutput2Flow(channels *agent.Channels, output schema.Schema) error { return nil } -func Create(g *genkit.Genkit) (*ReActAgent, error) { - var ( - thinkPrompt ai.Prompt - feedbackPrompt ai.Prompt - toolPrompt ai.Prompt - observePrompt ai.Prompt - err error - ) +// stageTypeInfo defines metadata for each stage type +type stageTypeInfo struct { + inType any + outType any + needsTools bool + isStreaming bool +} - memoryCtx := memory.NewMemoryContext(memory.ChatHistoryKey) - history, ok := memoryCtx.Value(memory.ChatHistoryKey).(*memory.History) - if !ok { - return nil, fmt.Errorf("failed to get history from context") - } - // Get Available Tools - var toolManagers []tools.ToolManager - // mcpToolManager, err := tools.NewMCPToolManager(g, config.MCP_HOST_NAME) - // if err != nil { - // return nil, fmt.Errorf("failed to create MCP tool manager: %v", err) - // } - toolManagers = append(toolManagers, - tools.NewMockToolManager(g), - tools.NewInternalToolManager(g, history), - // mcpToolManager, - ) - toolRefs := tools.NewToolRegistry(toolManagers...).AllToolRefs() +var stageTypeRegistry = map[string]stageTypeInfo{ + "think": {inType: ThinkIn{}, outType: ThinkOut{}, needsTools: true, isStreaming: false}, + "act": {inType: ThinkOut{}, outType: nil, needsTools: true, isStreaming: false}, + "observe": {inType: nil, outType: schema.Observation{}, needsTools: false, isStreaming: true}, + "feedback": {inType: ThinkIn{}, outType: nil, needsTools: false, isStreaming: false}, +} - // Build 
and Register ReAct think prompt - if thinkPrompt, err = buildThinkPrompt(g, toolRefs...); err != nil { - return nil, err - } - if feedbackPrompt, err = buildFeedBackPrompt(g); err != nil { - return nil, err - } - if toolPrompt, err = buildToolSelectionPrompt(g, toolRefs...); err != nil { - return nil, err - } - if observePrompt, err = buildObservePrompt(g); err != nil { +func NewReactAgent(g *genkit.Genkit, promptBasePath string, defaultModel string, maxIterations int, stagesCfg []StageInfo, toolRefs []ai.ToolRef) (*ReActAgent, error) { + memoryCtx := memory.NewMemoryContext(memory.ChatHistoryKey) + channels := agent.NewChannels(len(stagesCfg)) + stages, err := buildStagesFromConfig(g, stagesCfg, promptBasePath, defaultModel, toolRefs) + if err != nil { return nil, err } - channels := agent.NewChannels(config.STAGE_CHANNEL_BUFFER_SIZE) + return &ReActAgent{ + registry: g, + orchestrator: agent.NewOrderOrchestrator(maxIterations, stages...), + memoryCtx: memoryCtx, + channels: channels, + defaultModel: defaultModel, + promptBasePath: promptBasePath, + maxIterations: maxIterations, + }, nil +} + +func buildStagesFromConfig(g *genkit.Genkit, stagesCfg []StageInfo, promptBasePath string, defaultModel string, toolRefs []ai.ToolRef) ([]*agent.Stage, error) { + var stages []*agent.Stage + var observePrompt, feedbackPrompt ai.Prompt - thinkStage := agent.NewStage( - think(g, thinkPrompt), - agent.InLoop, - ) + for _, stageCfg := range stagesCfg { + // 1. Get type information + typeInfo, ok := stageTypeRegistry[stageCfg.FlowType] + if !ok { + return nil, fmt.Errorf("unknown flow type: %s", stageCfg.FlowType) + } - actStage := agent.NewStage( - act(g, nil, toolPrompt), - agent.InLoop, - ) - observerStage := agent.NewStreamStage( - observe(g, observePrompt, feedbackPrompt), - agent.InLoop, - onStreaming2User, - onOutput2Flow, - ) + // 2. 
Read and build prompt + promptPath := path.Join(promptBasePath, stageCfg.PromptFile) + systemPrompt, err := os.ReadFile(promptPath) + if err != nil { + return nil, fmt.Errorf("failed to read prompt file %s: %w", promptPath, err) + } - orchestrator := agent.NewOrderOrchestrator(thinkStage, actStage, observerStage) + // Prepare tools + var tools []ai.ToolRef + if typeInfo.needsTools && stageCfg.EnableTools { + tools = toolRefs + } - return &ReActAgent{ - registry: g, - orchestrator: orchestrator, - memoryCtx: memoryCtx, - channels: channels, - }, nil + // Prepare additional prompt + extraPrompt := stageCfg.ExtraPrompt + if stageCfg.FlowType == "think" && extraPrompt == "" { + toolsJson, err := json.Marshal(tools) + if err != nil { + return nil, fmt.Errorf("failed to marshal tools: %w", err) + } + extraPrompt = fmt.Sprintf("available tools: %s", string(toolsJson)) + } + + // When building prompt, use default model if not specified in configuration + model := stageCfg.Model + if model == "" { + model = defaultModel + } + + prompt, err := buildPrompt(g, typeInfo.inType, typeInfo.outType, stageCfg.Name, + string(systemPrompt), stageCfg.Temperature, model, extraPrompt, tools...) + if err != nil { + return nil, fmt.Errorf("failed to build prompt for stage %s: %w", stageCfg.Name, err) + } + + // 3. 
Create stage + var stage *agent.Stage + switch stageCfg.FlowType { + case "think": + stage = agent.NewStage(ThinkFlow(g, prompt), agent.InLoop) + case "act": + stage = agent.NewStage(ActFlow(g, nil, prompt), agent.InLoop) + case "observe": + observePrompt = prompt + continue + case "feedback": + feedbackPrompt = prompt + if observePrompt != nil { + stage = agent.NewStreamStage(observe(g, observePrompt, feedbackPrompt), + agent.InLoop, onStreaming2User, onOutput2Flow) + observePrompt, feedbackPrompt = nil, nil + } + } + + if stage != nil { + stages = append(stages, stage) + } + } + + return stages, nil } func (ra *ReActAgent) Interact(input *schema.UserInput, sessionID string) *agent.Channels { @@ -131,7 +172,7 @@ func (ra *ReActAgent) Interact(input *schema.UserInput, sessionID string) *agent // Add user input to history ra.memoryCtx = context.WithValue(ra.memoryCtx, memory.SessionIDKey, sessionID) - history, ok := ra.memoryCtx.Value(memory.ChatHistoryKey).(*memory.History) + history, ok := ra.memoryCtx.Value(memory.ChatHistoryKey).(*memory.HistoryMemory) if !ok { err = fmt.Errorf("failed to get history from context") ra.channels.ErrorChan <- err @@ -154,77 +195,36 @@ func (ra *ReActAgent) Interact(input *schema.UserInput, sessionID string) *agent return ra.channels } -func (ra *ReActAgent) GetMemory() *memory.History { - h, err := memory.GetHistory(ra.memoryCtx, memory.ChatHistoryKey) +func (ra *ReActAgent) GetMemory() *memory.HistoryMemory { + h, err := memory.GetHistoryMemory(ra.memoryCtx, memory.ChatHistoryKey) if err != nil { return nil } return h } -func buildThinkPrompt(registry *genkit.Genkit, tools ...ai.ToolRef) (ai.Prompt, error) { - // Load system prompt from filesystem - data, err := os.ReadFile(config.PROMPT_DIR_PATH + "/agentThink.txt") - if err != nil { - return nil, fmt.Errorf("failed to read agentThink prompt: %w", err) - } - systemPromptText := string(data) - - toolsJson, err := json.Marshal(tools) - if err != nil { - return nil, 
// buildPrompt defines and registers a genkit prompt named tag.
//
// inType/outType (pass nil to omit) become the prompt's input and output
// schemas; prompt is the system prompt text, temp the sampling temperature,
// model the model name passed to ai.WithModelName, and extraPrompt an
// optional user-prompt suffix. When tools are supplied the prompt is also
// configured to RETURN tool requests rather than execute them, so the caller
// (the act flow) can dispatch tools itself.
func buildPrompt(registry *genkit.Genkit, inType, outType any, tag, prompt string, temp float64, model string, extraPrompt string, tools ...ai.ToolRef) (ai.Prompt, error) {
	opts := []ai.PromptOption{
		ai.WithSystem(prompt),
		ai.WithConfig(&openai.ChatCompletionNewParams{
			Temperature: openai.Float(temp),
		}),
		ai.WithModelName(model),
	}
	if inType != nil {
		opts = append(opts, ai.WithInputType(inType))
	}
	if outType != nil {
		opts = append(opts, ai.WithOutputType(outType))
	}
	if extraPrompt != "" {
		opts = append(opts, ai.WithPrompt(extraPrompt))
	}
	// NOTE(review): a non-nil but empty tools slice would still enable
	// WithReturnToolRequests — callers currently pass nil to disable tools;
	// confirm no caller passes an empty non-nil slice.
	if tools != nil {
		opts = append(opts, ai.WithTools(tools...), ai.WithReturnToolRequests(true))
	}

	return genkit.DefinePrompt(registry, tag, opts...), nil
}
"response", resp.Text()) if err != nil { return nil, fmt.Errorf("failed to execute agentThink prompt: %w", err) } + if resp == nil { + return nil, fmt.Errorf("failed to execute agentThink prompt: empty response") + } + runtime.GetLogger().Info("Think response:", "response", resp.Text()) // Parse output var thinkOut ThinkOut @@ -300,12 +302,12 @@ func think( }) } -func act(g *genkit.Genkit, mcpToolManager *tools.MCPToolManager, toolPrompt ai.Prompt) agent.NormalFlow { +func ActFlow(g *genkit.Genkit, mcpToolManager *toolEngine.MCPToolManager, toolPrompt ai.Prompt) agent.NormalFlow { return genkit.DefineFlow(g, agent.ActFlowName, func(ctx context.Context, in schema.Schema) (out schema.Schema, err error) { - manager.GetLogger().Info("Acting...", "input", in) + runtime.GetLogger().Info("Acting...", "input", in) defer func() { - manager.GetLogger().Info("Act Done.", "output", out, "error", err) + runtime.GetLogger().Info("Act Done.", "output", out, "error", err) }() // Beacause the input is in the history, so don't need to use, just check the type @@ -317,7 +319,7 @@ func act(g *genkit.Genkit, mcpToolManager *tools.MCPToolManager, toolPrompt ai.P return ActOut{}, nil } - history, ok := ctx.Value(memory.ChatHistoryKey).(*memory.History) + history, ok := ctx.Value(memory.ChatHistoryKey).(*memory.HistoryMemory) if !ok { return nil, fmt.Errorf("failed to get history from context") } @@ -326,7 +328,7 @@ func act(g *genkit.Genkit, mcpToolManager *tools.MCPToolManager, toolPrompt ai.P return nil, fmt.Errorf("session id not found in context") } - // Get tool requests form LLM + // Get tool requests from LLM if history.IsEmpty(sessionID) { return nil, fmt.Errorf("history is empty") } @@ -339,14 +341,14 @@ func act(g *genkit.Genkit, mcpToolManager *tools.MCPToolManager, toolPrompt ai.P if len(toolReqs.ToolRequests()) == 0 { return ActOut{Thought: fmt.Sprintf("have unavailable tools in %v, please check available tools list", input.SuggestedTools)}, nil } - 
manager.GetLogger().Info("tool requests:", "req", toolReqs.ToolRequests()) + runtime.GetLogger().Info("tool requests:", "req", toolReqs.ToolRequests()) // Call tool requests and collect outputs var parts []*ai.Part var actOuts ActOut actOuts.UsageInfo = &ai.GenerationUsage{} for _, req := range toolReqs.ToolRequests() { - output, err := tools.Call(g, mcpToolManager, req.Name, req.Input) + output, err := toolEngine.Call(g, mcpToolManager, req.Name, req.Input) if err != nil { return nil, fmt.Errorf("failed to call tool %s: %w", req.Name, err) } @@ -358,7 +360,7 @@ func act(g *genkit.Genkit, mcpToolManager *tools.MCPToolManager, toolPrompt ai.P parts = append(parts, ai.NewJSONPart(string(outputJson))) actOuts.Add(&output) } - manager.GetLogger().Info("act out:", "out", actOuts) + runtime.GetLogger().Info("act out:", "out", actOuts) // ai.RoleTool's messages will be ingored by ai.WithMessages history.AddHistory(sessionID, ai.NewMessage(ai.RoleModel, nil, parts...)) schema.AccumulateUsage(actOuts.UsageInfo, toolReqs.Usage, in.Usage()) @@ -370,12 +372,12 @@ func act(g *genkit.Genkit, mcpToolManager *tools.MCPToolManager, toolPrompt ai.P func observe(g *genkit.Genkit, observePrompt ai.Prompt, feedbackPrompt ai.Prompt) agent.StreamFlow { return genkit.DefineStreamingFlow(g, agent.ObserveFlowName, func(ctx context.Context, in schema.Schema, cb core.StreamCallback[schema.StreamChunk]) (out schema.Schema, err error) { - manager.GetLogger().Info("Observing...", "input", in) + runtime.GetLogger().Info("Observing...", "input", in) defer func() { - manager.GetLogger().Info("Observe Done.", "output", out, "error", err) + runtime.GetLogger().Info("Observe Done.", "output", out, "error", err) }() - history, ok := ctx.Value(memory.ChatHistoryKey).(*memory.History) + history, ok := ctx.Value(memory.ChatHistoryKey).(*memory.HistoryMemory) if !ok { return nil, fmt.Errorf("failed to get history from context") } @@ -403,7 +405,7 @@ func observe(g *genkit.Genkit, observePrompt ai.Prompt, 
package reacttest

import (
	"strings"
	"testing"

	compReact "dubbo-admin-ai/component/agent/react"
)

// validAgentSpec returns a minimal AgentSpec that passes Validate; each test
// case below mutates exactly one field of a fresh copy to make it invalid.
func validAgentSpec() *compReact.AgentSpec {
	return &compReact.AgentSpec{
		AgentType:              compReact.AgentTypeReAct,
		DefaultModel:           "qwen-max",
		PromptBasePath:         "./prompts",
		MaxIterations:          5,
		StageChannelBufferSize: 2,
		MCPHostName:            "mcp_host",
		Stages: []compReact.StageInfo{{
			Name:        "thinking",
			FlowType:    "think",
			PromptFile:  "agentThink.txt",
			Temperature: 0.7,
			TopP:        0.9,
			MaxTokens:   1000,
			Timeout:     30,
		}},
	}
}

// TestAgentComponent_ValidateStage verifies that stage-level validation
// failures surface through AgentSpec.Validate with the expected message
// fragment.
func TestAgentComponent_ValidateStage(t *testing.T) {
	tests := []struct {
		name       string
		mutate     func(*compReact.AgentSpec) // applied to a fresh valid spec
		errContain string                     // substring expected in the returned error
	}{
		{name: "invalid_flow_type", mutate: func(c *compReact.AgentSpec) { c.Stages[0].FlowType = "invalid" }, errContain: "invalid flow_type"},
		{name: "prompt_required", mutate: func(c *compReact.AgentSpec) { c.Stages[0].PromptFile = "" }, errContain: "prompt_file is required"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := validAgentSpec()
			tt.mutate(cfg)
			if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), tt.errContain) {
				t.Fatalf("expected error containing %q, got %v", tt.errContain, err)
			}
		})
	}
}
00000000..56403707 --- /dev/null +++ b/ai/component/agent/react/test/workflow_test.go @@ -0,0 +1,129 @@ +package reacttest + +import ( + "context" + "errors" + "strings" + "testing" + + compReact "dubbo-admin-ai/component/agent/react" + "dubbo-admin-ai/component/memory" + "dubbo-admin-ai/schema" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" +) + +type stubPrompt struct { + resp *ai.ModelResponse + err error +} + +func (s *stubPrompt) Name() string { return "stub" } +func (s *stubPrompt) Execute(ctx context.Context, opts ...ai.PromptExecuteOption) (*ai.ModelResponse, error) { + return s.resp, s.err +} +func (s *stubPrompt) Render(ctx context.Context, input any) (*ai.GenerateActionOptions, error) { + return &ai.GenerateActionOptions{}, nil +} + +func contextWithHistory(sessionID string) context.Context { + ctx := memory.NewMemoryContext(memory.ChatHistoryKey) + history, _ := memory.GetHistoryMemory(ctx, memory.ChatHistoryKey) + history.AddHistory(sessionID, ai.NewUserMessage(ai.NewTextPart("hello"))) + ctx = context.WithValue(ctx, memory.SessionIDKey, sessionID) + return ctx +} + +func TestActFlow(t *testing.T) { + tests := []struct { + name string + setup func(*genkit.Genkit) ai.Prompt + input schema.ThinkOutput + errContain string + assertFn func(t *testing.T, out schema.Schema) + }{ + { + name: "general_inquiry_no_tools", + setup: func(*genkit.Genkit) ai.Prompt { return &stubPrompt{} }, + input: schema.ThinkOutput{Intent: schema.GeneralInquiry}, + assertFn: func(t *testing.T, out schema.Schema) { + actOut, ok := out.(schema.ToolOutputs) + if !ok { + t.Fatalf("output type = %T, want schema.ToolOutputs", out) + } + if len(actOut.Outputs) != 0 { + t.Fatalf("expected no tool outputs, got %d", len(actOut.Outputs)) + } + }, + }, + { + name: "with_tool_call_returns_outputs", + setup: func(g *genkit.Genkit) ai.Prompt { + genkit.DefineTool(g, "mock_tool", "mock tool", func(ctx *ai.ToolContext, input map[string]any) (map[string]any, error) 
{ + return map[string]any{"tool_name": "mock_tool", "summary": "ok", "result": map[string]any{"echo": input["q"]}}, nil + }) + return &stubPrompt{resp: &ai.ModelResponse{Message: ai.NewMessage(ai.RoleModel, nil, + ai.NewToolRequestPart(&ai.ToolRequest{Name: "mock_tool", Input: map[string]any{"q": "ping"}}), + )}} + }, + input: schema.ThinkOutput{Intent: schema.PerformanceInvestigation, SuggestedTools: []string{"mock_tool"}}, + assertFn: func(t *testing.T, out schema.Schema) { + actOut, ok := out.(schema.ToolOutputs) + if !ok { + t.Fatalf("output type = %T, want schema.ToolOutputs", out) + } + if len(actOut.Outputs) < 1 || actOut.Outputs[0].ToolName != "mock_tool" { + t.Fatalf("unexpected tool outputs: %+v", actOut.Outputs) + } + }, + }, + { + name: "tool_error_handling", + setup: func(g *genkit.Genkit) ai.Prompt { + genkit.DefineTool(g, "broken_tool", "broken", func(ctx *ai.ToolContext, input map[string]any) (map[string]any, error) { + return nil, errors.New("boom") + }) + return &stubPrompt{resp: &ai.ModelResponse{Message: ai.NewMessage(ai.RoleModel, nil, + ai.NewToolRequestPart(&ai.ToolRequest{Name: "broken_tool", Input: map[string]any{"x": 1}}), + )}} + }, + input: schema.ThinkOutput{Intent: schema.PerformanceInvestigation, SuggestedTools: []string{"broken_tool"}}, + errContain: "broken_tool", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := genkit.Init(context.Background()) + flow := compReact.ActFlow(g, nil, tt.setup(g)) + out, err := flow.Run(contextWithHistory("session"), tt.input) + if tt.errContain != "" { + if err == nil || !strings.Contains(err.Error(), tt.errContain) { + t.Fatalf("expected error containing %q, got %v", tt.errContain, err) + } + return + } + if err != nil { + t.Fatalf("act flow error: %v", err) + } + tt.assertFn(t, out) + }) + } +} + +func TestThinkFlow(t *testing.T) { + g := genkit.Init(context.Background()) + flow := compReact.ThinkFlow(g, &stubPrompt{resp: nil, err: errors.New("execute failed")}) + 
+ defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected panic: %v", r) + } + }() + + _, err := flow.Run(contextWithHistory("s3"), schema.ThinkInput{SessionID: "s3", UserInput: &schema.UserInput{Content: "hi"}}) + if err == nil || !strings.Contains(err.Error(), "failed to execute agentThink prompt") { + t.Fatalf("expected wrapped execute error, got %v", err) + } +} diff --git a/ai/component/logger/component.go b/ai/component/logger/component.go new file mode 100644 index 00000000..1950acf1 --- /dev/null +++ b/ai/component/logger/component.go @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package logger + +import ( + "dubbo-admin-ai/runtime" + "fmt" + "log/slog" + + "github.com/dusted-go/logging/prettylog" +) + +// LoggerComponent implements the logger component +type LoggerComponent struct { + // Runtime configuration fields (converted from cfg, original cfg not retained) + level string +} + +// NewLoggerComponent creates a logger component instance +func NewLoggerComponent(level string) (runtime.Component, error) { + return &LoggerComponent{ + level: level, + }, nil +} + +// Name returns the component name +func (l *LoggerComponent) Name() string { + return "logger" +} + +func (l *LoggerComponent) Validate() error { + switch l.level { + case "debug", "info", "warn", "error": + return nil + default: + return fmt.Errorf("invalid logger level: %s", l.level) + } +} + +func (l *LoggerComponent) Init(rt *runtime.Runtime) error { + var logger *slog.Logger + switch l.level { + case "debug": + logger = l.debugLogger() + case "info": + logger = l.infoLogger() + default: + logger = l.infoLogger() + } + slog.SetDefault(logger) + slog.Info("Logger component initialized", + "level", l.level) + + return nil +} + +func (l *LoggerComponent) Start() error { + return nil +} + +func (l *LoggerComponent) Stop() error { + return nil +} + +func (l *LoggerComponent) debugLogger() *slog.Logger { + slog.SetDefault( + slog.New( + prettylog.NewHandler(&slog.HandlerOptions{ + Level: slog.LevelDebug, + AddSource: true, + ReplaceAttr: nil, + }), + ), + ) + return slog.Default() +} + +func (l *LoggerComponent) infoLogger() *slog.Logger { + slog.SetDefault( + slog.New( + prettylog.NewHandler(&slog.HandlerOptions{ + Level: slog.LevelInfo, + AddSource: false, + ReplaceAttr: nil, + }), + ), + ) + return slog.Default() +} diff --git a/ai/component/logger/config.go b/ai/component/logger/config.go new file mode 100644 index 00000000..35a86483 --- /dev/null +++ b/ai/component/logger/config.go @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or 
// LoggerSpec defines logger configuration
type LoggerSpec struct {
	// Level is the logging level: debug, info, warn, or error.
	Level string `yaml:"level"`
}

// DefaultLoggerSpec returns default logger configuration (info level).
func DefaultLoggerSpec() *LoggerSpec {
	spec := new(LoggerSpec)
	spec.Level = "info"
	return spec
}
package logger

import (
	"dubbo-admin-ai/runtime"
	"fmt"

	"gopkg.in/yaml.v3"
)

// LoggerFactory creates a logger component (explicit registration, no init)
//
// The raw YAML `spec` node is decoded into a LoggerSpec; only its Level is
// forwarded to the constructor. Decode failures are wrapped so the runtime
// can report which component spec was invalid.
func LoggerFactory(spec *yaml.Node) (runtime.Component, error) {
	var cfg LoggerSpec
	if err := spec.Decode(&cfg); err != nil {
		return nil, fmt.Errorf("failed to decode logger spec: %w", err)
	}

	// Call constructor, converting cfg fields to specific parameters
	return NewLoggerComponent(cfg.Level)
}
+ */ + +package memory + +import ( + "context" + "dubbo-admin-ai/runtime" + "fmt" +) + +// MemoryComponent implements the memory component +// TODO(memory, 2026-02-24): Inject unified memory interface to support different memory implementations +// Current implementation uses HistoryMemory directly +type MemoryComponent struct { + historyKey HistoryKey + maxTurns int + memoryCtx context.Context + memory *HistoryMemory +} + +func NewMemoryComponent(historyKey HistoryKey, maxTurns ...int) (runtime.Component, error) { + limit := 100 + if len(maxTurns) > 0 { + limit = maxTurns[0] + } + return &MemoryComponent{ + historyKey: historyKey, + maxTurns: limit, + }, nil +} + +func (m *MemoryComponent) Name() string { + return "memory" +} + +func (m *MemoryComponent) Validate() error { + if m.maxTurns <= 0 { + return fmt.Errorf("max_turns must be greater than 0") + } + return nil +} + +func (m *MemoryComponent) Init(rt *runtime.Runtime) error { + m.memoryCtx = NewMemoryContext(m.historyKey) + history, err := GetHistoryMemory(m.memoryCtx, m.historyKey) + if err != nil { + return fmt.Errorf("failed to initialize history: %w", err) + } + m.memory = history + + rt.GetLogger().Info("Memory component initialized", + "history_key", m.historyKey) + + return nil +} + +func (m *MemoryComponent) Start() error { + return nil +} + +func (m *MemoryComponent) Stop() error { + return nil +} + +// GetContext returns the memory context +func (m *MemoryComponent) GetContext() context.Context { + return m.memoryCtx +} + +// TODO(memory, 2026-02-24): Provide unified interface for different memory types (HistoryMemory, VectorMemory, etc.) 
+// GetMemory returns the underlying HistoryMemory instance +func (m *MemoryComponent) GetMemory() (*HistoryMemory, error) { + if m.memory == nil { + return nil, fmt.Errorf("history not initialized") + } + return m.memory, nil +} diff --git a/ai/component/memory/config.go b/ai/component/memory/config.go new file mode 100644 index 00000000..e5263fbe --- /dev/null +++ b/ai/component/memory/config.go @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package memory + +type MemorySpec struct { + HistoryKey HistoryKey `yaml:"history_key"` // History key name + MaxTurns int `yaml:"max_turns"` // Maximum conversation turns +} + +func DefaultMemorySpec() *MemorySpec { + return &MemorySpec{ + HistoryKey: ChatHistoryKey, + MaxTurns: 100, + } +} diff --git a/ai/component/memory/factory.go b/ai/component/memory/factory.go new file mode 100644 index 00000000..51a53f73 --- /dev/null +++ b/ai/component/memory/factory.go @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package memory
+
+import (
+	"dubbo-admin-ai/runtime"
+	"fmt"
+
+	"gopkg.in/yaml.v3"
+)
+
+// MemoryFactory creates a memory component (explicit registration, no init)
+func MemoryFactory(spec *yaml.Node) (runtime.Component, error) {
+	cfg := *DefaultMemorySpec() // omitted YAML keys keep defaults; bare spec no longer passes max_turns=0, which would fail Validate
+	if err := spec.Decode(&cfg); err != nil {
+		return nil, fmt.Errorf("failed to decode memory spec: %w", err)
+	}
+	return NewMemoryComponent(cfg.HistoryKey, cfg.MaxTurns)
+}
diff --git a/ai/memory/history.go b/ai/component/memory/history.go
similarity index 82%
rename from ai/memory/history.go
rename to ai/component/memory/history.go
index d23f5eaa..9c7cf13b 100644
--- a/ai/memory/history.go
+++ b/ai/component/memory/history.go
@@ -45,7 +45,7 @@ func (t *Turn) Messages() []*ai.Message {
 	return msgs
 }
 
-type History struct {
+type HistoryMemory struct {
 	mu            sync.RWMutex
 	windowMemory  map[string]*utils.Window[*Turn]
 	historyMemory map[string][]*Turn
@@ -55,23 +55,23 @@ func NewMemoryContext(key HistoryKey) context.Context {
 	return context.WithValue(
 		context.Background(),
 		key,
-		&History{
+		&HistoryMemory{
 			windowMemory:  make(map[string]*utils.Window[*Turn]),
 			historyMemory: make(map[string][]*Turn),
 		},
 	)
 }
 
-// GetHistory 从上下文中获取 History
-func GetHistory(ctx context.Context, key HistoryKey) (*History, error) {
-	history, ok := ctx.Value(key).(*History)
+// GetHistoryMemory retrieves HistoryMemory from context
+func GetHistoryMemory(ctx
context.Context, key HistoryKey) (*HistoryMemory, error) { + history, ok := ctx.Value(key).(*HistoryMemory) if !ok { return nil, fmt.Errorf("failed to get history from context") } return history, nil } -func (h *History) AddHistory(sessionID string, message ...*ai.Message) { +func (h *HistoryMemory) AddHistory(sessionID string, message ...*ai.Message) { h.mu.Lock() defer h.mu.Unlock() @@ -107,15 +107,15 @@ func (h *History) AddHistory(sessionID string, message ...*ai.Message) { turn.ModelMessages = append(turn.ModelMessages, modelMsgs...) } -func (h *History) IsEmpty(sessionID string) bool { +func (h *HistoryMemory) IsEmpty(sessionID string) bool { h.mu.RLock() defer h.mu.RUnlock() - // 检查该 session 的窗口是否为空 + // Check if the window for this session is empty return h.windowMemory == nil || h.windowMemory[sessionID] == nil || h.windowMemory[sessionID].IsEmpty() } -func (h *History) AllMemory(sessionID string) []*ai.Message { +func (h *HistoryMemory) AllMemory(sessionID string) []*ai.Message { h.mu.RLock() defer h.mu.RUnlock() @@ -137,7 +137,7 @@ func (h *History) AllMemory(sessionID string) []*ai.Message { return result } -func (h *History) WindowMemory(sessionID string) []*ai.Message { +func (h *HistoryMemory) WindowMemory(sessionID string) []*ai.Message { h.mu.RLock() defer h.mu.RUnlock() @@ -155,7 +155,7 @@ func (h *History) WindowMemory(sessionID string) []*ai.Message { return result } -func (h *History) Clear(sessionID string) { +func (h *HistoryMemory) Clear(sessionID string) { h.mu.Lock() defer h.mu.Unlock() @@ -164,14 +164,14 @@ func (h *History) Clear(sessionID string) { } } -func (h *History) ClearAll() { +func (h *HistoryMemory) ClearAll() { h.mu.Lock() defer h.mu.Unlock() h.windowMemory = make(map[string]*utils.Window[*Turn]) } -func (h *History) GetAllSessions() []string { +func (h *HistoryMemory) GetAllSessions() []string { h.mu.RLock() defer h.mu.RUnlock() @@ -186,7 +186,7 @@ func (h *History) GetAllSessions() []string { return sessions } -func (h 
*History) SystemMemory(sessionID string) []*ai.Message { +func (h *HistoryMemory) SystemMemory(sessionID string) []*ai.Message { h.mu.RLock() defer h.mu.RUnlock() @@ -203,7 +203,7 @@ func (h *History) SystemMemory(sessionID string) []*ai.Message { return result } -func (h *History) UserMemory(sessionID string) []*ai.Message { +func (h *HistoryMemory) UserMemory(sessionID string) []*ai.Message { h.mu.RLock() defer h.mu.RUnlock() @@ -220,8 +220,8 @@ func (h *History) UserMemory(sessionID string) []*ai.Message { return result } -// ModelMemory 获取指定 session 的模型消息历史 -func (h *History) ModelMemory(sessionID string) []*ai.Message { +// ModelMemory retrieves the model message history for the specified session +func (h *HistoryMemory) ModelMemory(sessionID string) []*ai.Message { h.mu.RLock() defer h.mu.RUnlock() @@ -238,7 +238,7 @@ func (h *History) ModelMemory(sessionID string) []*ai.Message { return result } -func (h *History) NextTurn(sessionID string) error { +func (h *HistoryMemory) NextTurn(sessionID string) error { h.mu.Lock() defer h.mu.Unlock() diff --git a/ai/component/memory/memory.yaml b/ai/component/memory/memory.yaml new file mode 100644 index 00000000..e39f5a48 --- /dev/null +++ b/ai/component/memory/memory.yaml @@ -0,0 +1,4 @@ +type: memory +spec: + history_key: "chat_history" + max_turns: 100 diff --git a/ai/component/memory/test/history_test.go b/ai/component/memory/test/history_test.go new file mode 100644 index 00000000..fadd86f7 --- /dev/null +++ b/ai/component/memory/test/history_test.go @@ -0,0 +1,171 @@ +package memorytest + +import ( + "fmt" + "strings" + "sync" + "testing" + + compMemory "dubbo-admin-ai/component/memory" + + "github.com/firebase/genkit/go/ai" +) + +func newHistoryMemory(t *testing.T) *compMemory.HistoryMemory { + t.Helper() + ctx := compMemory.NewMemoryContext(compMemory.ChatHistoryKey) + h, err := compMemory.GetHistoryMemory(ctx, compMemory.ChatHistoryKey) + if err != nil { + t.Fatalf("GetHistoryMemory() error: %v", err) + } + 
return h +} + +func TestHistoryMemory_AddHistory(t *testing.T) { + h := newHistoryMemory(t) + sid := "session-1" + h.AddHistory(sid, ai.NewUserMessage(ai.NewTextPart("hello"))) + + got := h.UserMemory(sid) + if len(got) != 1 { + t.Fatalf("user memory len = %d, want 1", len(got)) + } +} + +func TestHistoryMemory_NextTurn(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T, h *compMemory.HistoryMemory) + }{ + { + name: "archives_current_turn", + run: func(t *testing.T, h *compMemory.HistoryMemory) { + sid := "session-1" + h.AddHistory(sid, ai.NewUserMessage(ai.NewTextPart("first"))) + if err := h.NextTurn(sid); err != nil { + t.Fatalf("NextTurn() error: %v", err) + } + if len(h.WindowMemory(sid)) != 0 { + t.Fatalf("window memory should be empty after archive") + } + if len(h.AllMemory(sid)) == 0 { + t.Fatalf("history memory should contain archived turn") + } + }, + }, + { + name: "session_full", + run: func(t *testing.T, h *compMemory.HistoryMemory) { + sid := "session-full" + h.AddHistory(sid, ai.NewUserMessage(ai.NewTextPart("seed"))) + for i := 0; i < compMemory.TurnLimit; i++ { + if err := h.NextTurn(sid); err != nil { + if strings.Contains(err.Error(), "context is full") { + return + } + t.Fatalf("unexpected error at step %d: %v", i, err) + } + h.AddHistory(sid, ai.NewUserMessage(ai.NewTextPart(fmt.Sprintf("turn-%d", i)))) + } + if err := h.NextTurn(sid); err == nil || !strings.Contains(err.Error(), "context is full") { + t.Fatalf("expected context full error, got %v", err) + } + }, + }, + { + name: "empty_window_safety", + run: func(t *testing.T, h *compMemory.HistoryMemory) { + sid := "session-empty" + h.AddHistory(sid, ai.NewUserMessage(ai.NewTextPart("seed"))) + if err := h.NextTurn(sid); err != nil { + t.Fatalf("first NextTurn() error: %v", err) + } + defer func() { + r := recover() + if r == nil { + t.Fatalf("expected panic on empty window") + } + if !strings.Contains(fmt.Sprint(r), "window is empty") { + t.Fatalf("unexpected panic: 
%v", r) + } + }() + _ = h.NextTurn(sid) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t, newHistoryMemory(t)) + }) + } +} + +func TestHistoryMemory_Concurrency(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T, h *compMemory.HistoryMemory) + }{ + { + name: "concurrent_add_history", + run: func(t *testing.T, h *compMemory.HistoryMemory) { + sid := "session-concurrent" + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + h.AddHistory(sid, ai.NewUserMessage(ai.NewTextPart(fmt.Sprintf("m-%d", i)))) + }(i) + } + wg.Wait() + if len(h.UserMemory(sid)) == 0 { + t.Fatalf("expected user messages after concurrent writes") + } + }, + }, + { + name: "concurrent_read_write", + run: func(t *testing.T, h *compMemory.HistoryMemory) { + sid := "session-rw" + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + for j := 0; j < 50; j++ { + h.AddHistory(sid, ai.NewUserMessage(ai.NewTextPart(fmt.Sprintf("w-%d-%d", i, j)))) + } + }(i) + } + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 50; j++ { + _ = h.AllMemory(sid) + _ = h.WindowMemory(sid) + } + }() + } + wg.Wait() + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t, newHistoryMemory(t)) + }) + } +} + +func TestMemoryComponent_Validate(t *testing.T) { + comp, err := compMemory.NewMemoryComponent(compMemory.ChatHistoryKey, 0) + if err != nil { + t.Fatalf("NewMemoryComponent() error: %v", err) + } + if err := comp.Validate(); err == nil || !strings.Contains(err.Error(), "max_turns") { + t.Fatalf("expected max_turns validation error, got %v", err) + } +} diff --git a/ai/component/models/component.go b/ai/component/models/component.go new file mode 100644 index 00000000..d59242f2 --- /dev/null +++ b/ai/component/models/component.go @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package models + +import ( + "dubbo-admin-ai/runtime" + "fmt" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/compat_oai" + "github.com/openai/openai-go/option" +) + +type ModelsComponent struct { + defaultModel string + defaultEmbedding string + providers map[string]ProviderConfig +} + +// NewModelsComponent creates a Models component instance +func NewModelsComponent( + defaultModel string, + defaultEmbedding string, + providers map[string]ProviderConfig, +) (runtime.Component, error) { + return &ModelsComponent{ + defaultModel: defaultModel, + defaultEmbedding: defaultEmbedding, + providers: providers, + }, nil +} + +func (m *ModelsComponent) Name() string { + return "models" +} + +func (m *ModelsComponent) Validate() error { + if m.defaultModel == "" { + return fmt.Errorf("default_model is required") + } + if m.defaultEmbedding == "" { + return fmt.Errorf("default_embedding is required") + } + if len(m.providers) == 0 { + return fmt.Errorf("at least one provider must be configured") + } + for name, provider := range m.providers { + if provider.BaseURL == "" { + return 
fmt.Errorf("provider %s base_url is required", name) + } + } + return nil +} + +func (m *ModelsComponent) Init(rt *runtime.Runtime) error { + var plugins []api.Plugin + + for providerName, cfg := range m.providers { + if cfg.APIKey == "" { + continue + } + plugin := createProviderPlugin(providerName, cfg) + + if plugin != nil { + plugins = append(plugins, plugin) + } + } + + ctx := rt.GetContext() + genkitRegistry := genkit.Init(ctx, + genkit.WithPlugins(plugins...), + genkit.WithDefaultModel(m.defaultModel), + ) + + if rt.GetGenkitRegistry() == nil { + rt.SetGenkitRegistry(genkitRegistry) + } else { + rt.GetLogger().Warn("Genkit registry already set, skipping initialization") + } + + totalModels := 0 + totalEmbedders := 0 + + for _, plugin := range plugins { + if oaiCompat, ok := plugin.(*compat_oai.OpenAICompatible); ok { + providerName := oaiCompat.Provider + providerCfg, exists := m.providers[providerName] + if !exists { + continue + } + // Register all models + for _, modelCfg := range providerCfg.Models { + registerModel(oaiCompat, providerName, modelCfg) + totalModels++ + } + // Register all embedding models + for _, embedderCfg := range providerCfg.Embedders { + registerEmbedder(oaiCompat, providerName, embedderCfg) + totalEmbedders++ + } + } + } + + rt.GetLogger().Info("Models component initialized", + "default_model", m.defaultModel, + "default_embedding", m.defaultEmbedding, + "providers", len(m.providers), + "total_models", totalModels, + "total_embedders", totalEmbedders) + + return nil +} + +func (m *ModelsComponent) Start() error { + return nil +} + +func (m *ModelsComponent) Stop() error { + return nil +} + +func registerModel(oaiCompat *compat_oai.OpenAICompatible, providerName string, cfg ModelInfo) { + var supports *ai.ModelSupports + switch cfg.Type { + case "chat": + supports = &compat_oai.BasicText + case "multimodal": + supports = &compat_oai.Multimodal + case "code": + supports = &compat_oai.BasicText + default: + supports = 
&compat_oai.BasicText + } + + oaiCompat.DefineModel(providerName, cfg.Key, ai.ModelOptions{ + Label: cfg.Name, + Supports: supports, + Versions: []string{cfg.Key}, + }) +} + +func registerEmbedder(oaiCompat *compat_oai.OpenAICompatible, providerName string, cfg EmbedderInfo) { + var inputTypes []string + embedderType := cfg.Type + if embedderType == "" { + embedderType = "text" + } + + switch embedderType { + case "text": + inputTypes = []string{"text"} + case "image": + inputTypes = []string{"image", "text"} + case "audio": + inputTypes = []string{"audio"} + case "multimodal": + inputTypes = []string{"text", "image", "audio"} + default: + inputTypes = []string{"text"} + } + + oaiCompat.DefineEmbedder(providerName, cfg.Key, &ai.EmbedderOptions{ + Label: cfg.Name, + Supports: &ai.EmbedderSupports{Input: inputTypes}, + Dimensions: cfg.Dimensions, + }) +} + +// createProviderPlugin creates a single provider plugin +func createProviderPlugin(providerName string, cfg ProviderConfig) api.Plugin { + // Gemini requires special handling (not currently supported) + if providerName == "gemini" { + return nil + } + + // Check API Key + if cfg.APIKey == "" { + return nil + } + + // Use OpenAICompatible directly + return &compat_oai.OpenAICompatible{ + Provider: providerName, + Opts: []option.RequestOption{ + option.WithAPIKey(cfg.APIKey), + option.WithBaseURL(cfg.BaseURL), + }, + } +} diff --git a/ai/component/models/config.go b/ai/component/models/config.go new file mode 100644 index 00000000..ea1504cf --- /dev/null +++ b/ai/component/models/config.go @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package models + +// ModelsSpec Models configuration +type ModelsSpec struct { + DefaultModel string `yaml:"default_model"` + DefaultEmbedding string `yaml:"default_embedding"` + Providers map[string]ProviderConfig `yaml:"providers"` +} + +// ProviderConfig Provider configuration +type ProviderConfig struct { + APIKey string `yaml:"api_key"` + BaseURL string `yaml:"base_url"` + Models []ModelInfo `yaml:"models,omitempty"` + Embedders []EmbedderInfo `yaml:"embedders,omitempty"` + Config map[string]any `yaml:"config,omitempty"` +} + +// ModelInfo Model information +type ModelInfo struct { + Name string `yaml:"name"` + Key string `yaml:"key"` + Type string `yaml:"type,omitempty"` + Config map[string]any `yaml:"config,omitempty"` +} + +// EmbedderInfo Embedder information +type EmbedderInfo struct { + Name string `yaml:"name"` + Key string `yaml:"key"` + Type string `yaml:"type,omitempty"` + Dimensions int `yaml:"dimensions"` + Config map[string]any `yaml:"config,omitempty"` +} + +func DefaultModelsSpec() *ModelsSpec { + return &ModelsSpec{ + DefaultModel: "dashscope/qwen-max", + DefaultEmbedding: "dashscope/qwen3-embedding", + Providers: make(map[string]ProviderConfig), + } +} diff --git a/ai/component/models/factory.go b/ai/component/models/factory.go new file mode 100644 index 00000000..567b8c51 --- /dev/null +++ b/ai/component/models/factory.go @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package models
+
+import (
+	"dubbo-admin-ai/runtime"
+	"fmt"
+
+	"gopkg.in/yaml.v3"
+)
+
+// ModelsFactory component factory function (explicit registration, does not use init)
+func ModelsFactory(spec *yaml.Node) (runtime.Component, error) {
+	cfg := *DefaultModelsSpec() // omitted YAML keys keep declared defaults (default_model/default_embedding)
+	if err := spec.Decode(&cfg); err != nil {
+		return nil, fmt.Errorf("failed to decode models spec: %w", err)
+	}
+
+	return NewModelsComponent(
+		cfg.DefaultModel,
+		cfg.DefaultEmbedding,
+		cfg.Providers,
+	)
+}
diff --git a/ai/component/models/models.yaml b/ai/component/models/models.yaml
new file mode 100644
index 00000000..03dc9055
--- /dev/null
+++ b/ai/component/models/models.yaml
@@ -0,0 +1,58 @@
+type: models
+spec:
+  default_model: "dashscope/qwen-max"
+  default_embedding: "dashscope/qwen3-embedding"
+  providers:
+    dashscope:
+      api_key: "${DASHSCOPE_API_KEY}"
+      base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
+      models:
+        - name: "qwen-max"
+          key: "qwen-max"
+          type: "chat"
+        - name: "qwen-plus"
+          key: "qwen-plus"
+          type: "chat"
+        - name: "qwen-flash"
+          key: "qwen-flash"
+          type: "chat"
+        - name: "qwen3-coder"
+          key: "qwen3-coder-plus"
+          type: "code"
+      embedders:
+        - name: "qwen3-embedding"
+          key: "text-embedding-v4"
+          type: "text"
+          dimensions: 1024
+
+    gemini:
+      api_key: "${GEMINI_API_KEY}"
+      
base_url: "https://generativelanguage.googleapis.com/v1beta" + models: + - name: "gemini-pro" + key: "gemini-pro" + type: "chat" + - name: "gemini-pro-vision" + key: "gemini-pro-vision" + type: "multimodal" + embedders: + - name: "gemini-embedding" + key: "text-embedding-004" + type: "text" + dimensions: 768 + + siliconflow: + api_key: "${SILICONFLOW_API_KEY}" + base_url: "https://api.siliconflow.cn/v1" + models: + - name: "gpt-3.5-turbo" + key: "gpt-3.5-turbo" + type: "chat" + - name: "gpt-4" + key: "gpt-4" + type: "chat" + embedders: + - name: "text-embedding-ada-002" + key: "text-embedding-ada-002" + type: "text" + dimensions: 1536 diff --git a/ai/component/models/test/models_test.go b/ai/component/models/test/models_test.go new file mode 100644 index 00000000..54c6bd2d --- /dev/null +++ b/ai/component/models/test/models_test.go @@ -0,0 +1,27 @@ +package modelstest + +import ( + "dubbo-admin-ai/component/models" + "strings" + "testing" +) + +func TestModelsComponent_Validate(t *testing.T) { + comp, err := models.NewModelsComponent("dashscope/qwen-max", "dashscope/qwen3-embedding", map[string]models.ProviderConfig{}) + if err != nil { + t.Fatalf("NewModelsComponent() error: %v", err) + } + if err := comp.Validate(); err == nil || !strings.Contains(err.Error(), "at least one provider") { + t.Fatalf("expected providers validation error, got %v", err) + } + + comp2, err := models.NewModelsComponent("dashscope/qwen-max", "dashscope/qwen3-embedding", map[string]models.ProviderConfig{ + "dashscope": {APIKey: "x", BaseURL: ""}, + }) + if err != nil { + t.Fatalf("NewModelsComponent() error: %v", err) + } + if err := comp2.Validate(); err == nil || !strings.Contains(err.Error(), "base_url") { + t.Fatalf("expected base_url validation error, got %v", err) + } +} diff --git a/ai/component/rag/component.go b/ai/component/rag/component.go new file mode 100644 index 00000000..18db962c --- /dev/null +++ b/ai/component/rag/component.go @@ -0,0 +1,372 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rag + +import ( + "context" + "dubbo-admin-ai/config" + "dubbo-admin-ai/runtime" + "fmt" + + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/components/retriever" +) + +// Reranker 重排序器接口 +type Reranker interface { + Rerank(ctx context.Context, query string, docs any, opts ...RerankOption) ([]*RetrieveResult, error) +} + +// ============= 子组件包装器 ============= + +// loaderComponent Loader 组件包装器 +type loaderComponent struct { + cfg *config.Config + loader document.Loader +} + +func newLoaderComponent(cfg *config.Config) *loaderComponent { + return &loaderComponent{cfg: cfg} +} + +func (c *loaderComponent) Name() string { return "loader" } +func (c *loaderComponent) Validate() error { + return nil +} + +func (c *loaderComponent) Init(rt *runtime.Runtime) error { + loader, err := newLoader(c.cfg) + if err != nil { + return fmt.Errorf("failed to create loader: %w", err) + } + c.loader = loader + rt.GetLogger().Info("Loader component initialized", "type", c.cfg.Type) + return nil +} + +func (c *loaderComponent) Start() error { return nil } +func (c *loaderComponent) Stop() error { return nil } +func (c 
*loaderComponent) get() document.Loader { + return c.loader +} + +// splitterComponent Splitter 组件包装器 +type splitterComponent struct { + cfg *config.Config + splitter document.Transformer +} + +func newSplitterComponent(cfg *config.Config) *splitterComponent { + return &splitterComponent{cfg: cfg} +} + +func (c *splitterComponent) Name() string { return "splitter" } +func (c *splitterComponent) Validate() error { + return nil +} + +func (c *splitterComponent) Init(rt *runtime.Runtime) error { + splitter, err := newSplitter(c.cfg) + if err != nil { + return fmt.Errorf("failed to create splitter: %w", err) + } + c.splitter = splitter + + var spec SplitterSpec + if c.cfg.Spec.Decode(&spec) == nil { + rt.GetLogger().Info("Splitter component initialized", "type", c.cfg.Type, "chunk_size", spec.ChunkSize, "overlap_size", spec.OverlapSize) + } else { + rt.GetLogger().Info("Splitter component initialized", "type", c.cfg.Type) + } + return nil +} + +func (c *splitterComponent) Start() error { return nil } +func (c *splitterComponent) Stop() error { return nil } +func (c *splitterComponent) get() document.Transformer { + return c.splitter +} + +// indexerComponent Indexer 组件包装器 +type indexerComponent struct { + cfg *config.Config + embedderName string + indexer indexer.Indexer +} + +func newIndexerComponent(cfg *config.Config, embedderName string) *indexerComponent { + return &indexerComponent{cfg: cfg, embedderName: embedderName} +} + +func (c *indexerComponent) Name() string { return "indexer" } +func (c *indexerComponent) Validate() error { + return nil +} + +func (c *indexerComponent) Init(rt *runtime.Runtime) error { + registry := rt.GetGenkitRegistry() + if registry == nil { + return fmt.Errorf("genkit registry not initialized") + } + + idx, err := newIndexer(registry, c.cfg.Type, c.embedderName) + if err != nil { + return fmt.Errorf("failed to create indexer: %w", err) + } + c.indexer = idx + + rt.GetLogger().Info("Indexer component initialized", "type", c.cfg.Type, 
"embedder", c.embedderName) + return nil +} + +func (c *indexerComponent) Start() error { return nil } +func (c *indexerComponent) Stop() error { return nil } +func (c *indexerComponent) get() indexer.Indexer { + return c.indexer +} + +// retrieverComponent Retriever 组件包装器 +type retrieverComponent struct { + cfg *config.Config + embedderName string + retriever retriever.Retriever +} + +func newRetrieverComponent(cfg *config.Config, embedderName string) *retrieverComponent { + return &retrieverComponent{cfg: cfg, embedderName: embedderName} +} + +func (c *retrieverComponent) Name() string { return "retriever" } +func (c *retrieverComponent) Validate() error { + return nil +} + +func (c *retrieverComponent) Init(rt *runtime.Runtime) error { + registry := rt.GetGenkitRegistry() + if registry == nil { + return fmt.Errorf("genkit registry not initialized") + } + + rtv, err := newRetriever(registry, c.cfg.Type, c.embedderName) + if err != nil { + return fmt.Errorf("failed to create retriever: %w", err) + } + c.retriever = rtv + + rt.GetLogger().Info("Retriever component initialized", "type", c.cfg.Type, "embedder", c.embedderName) + return nil +} + +func (c *retrieverComponent) Start() error { return nil } +func (c *retrieverComponent) Stop() error { return nil } +func (c *retrieverComponent) get() retriever.Retriever { + return c.retriever +} + +// rerankerComponent Reranker 组件包装器 +type rerankerComponent struct { + enabled bool + model string + apiKey string + reranker Reranker +} + +func newRerankerComponent(enabled bool, model, apiKey string) *rerankerComponent { + return &rerankerComponent{enabled: enabled, model: model, apiKey: apiKey} +} + +func (c *rerankerComponent) Name() string { return "reranker" } +func (c *rerankerComponent) Validate() error { + return nil +} + +func (c *rerankerComponent) Init(rt *runtime.Runtime) error { + if !c.enabled { + rt.GetLogger().Info("Reranker component disabled") + return nil + } + + reranker, err := newReranker(c.enabled, 
c.model, c.apiKey) + if err != nil { + return fmt.Errorf("failed to create reranker: %w", err) + } + c.reranker = reranker + + rt.GetLogger().Info("Reranker component initialized", "model", c.model) + return nil +} + +func (c *rerankerComponent) Start() error { return nil } +func (c *rerankerComponent) Stop() error { return nil } +func (c *rerankerComponent) get() Reranker { + return c.reranker +} + +// ============= RAGComponent 主组件 ============= + +// RAGComponent RAG 系统组件 +type RAGComponent struct { + cfg *RAGSpec + embedderName string + loader *loaderComponent + splitter *splitterComponent + indexer *indexerComponent + retriever *retrieverComponent + reranker *rerankerComponent +} + +func (r *RAGComponent) Name() string { + return "rag" +} + +func (r *RAGComponent) Validate() error { + return r.cfg.Validate() +} + +func (r *RAGComponent) Init(rt *runtime.Runtime) error { + // 获取 embedder 模型名称 + var embedderSpec EmbedderSpec + if err := r.cfg.Embedder.Spec.Decode(&embedderSpec); err != nil { + return fmt.Errorf("failed to parse embedder spec: %w", err) + } + r.embedderName = embedderSpec.Model + + // 创建子组件 + r.loader = newLoaderComponent(r.cfg.Loader) + r.splitter = newSplitterComponent(r.cfg.Splitter) + r.indexer = newIndexerComponent(r.cfg.Indexer, r.embedderName) + r.retriever = newRetrieverComponent(r.cfg.Retriever, r.embedderName) + r.reranker = newRerankerComponent( + getRerankerEnabled(r.cfg.Reranker), + getRerankerModel(r.cfg.Reranker), + getRerankerAPIKey(r.cfg.Reranker), + ) + + // 初始化所有子组件 + components := []runtime.Component{r.loader, r.splitter, r.indexer, r.retriever, r.reranker} + for _, comp := range components { + if err := comp.Init(rt); err != nil { + return fmt.Errorf("failed to init %s: %w", comp.Name(), err) + } + } + + rt.GetLogger().Info("RAG component initialized", + "embedder", r.embedderName, + "indexer", r.cfg.Indexer.Type, + "retriever", r.cfg.Retriever.Type, + "splitter", r.cfg.Splitter.Type, + "reranker_enabled", r.cfg.Reranker != 
nil) + + return nil +} + +func (r *RAGComponent) Start() error { + components := []runtime.Component{r.loader, r.splitter, r.indexer, r.retriever, r.reranker} + for _, comp := range components { + if err := comp.Start(); err != nil { + return fmt.Errorf("failed to start %s: %w", comp.Name(), err) + } + } + return nil +} + +func (r *RAGComponent) Stop() error { + components := []runtime.Component{r.reranker, r.retriever, r.indexer, r.splitter, r.loader} + for _, comp := range components { + if err := comp.Stop(); err != nil { + return fmt.Errorf("failed to stop %s: %w", comp.Name(), err) + } + } + return nil +} + +// Getter 方法 +func (r *RAGComponent) GetLoader() document.Loader { + if r.loader != nil { + return r.loader.get() + } + return nil +} + +func (r *RAGComponent) GetSplitter() document.Transformer { + if r.splitter != nil { + return r.splitter.get() + } + return nil +} + +func (r *RAGComponent) GetIndexer() indexer.Indexer { + if r.indexer != nil { + return r.indexer.get() + } + return nil +} + +func (r *RAGComponent) GetRetriever() retriever.Retriever { + if r.retriever != nil { + return r.retriever.get() + } + return nil +} + +func (r *RAGComponent) GetReranker() Reranker { + if r.reranker != nil { + return r.reranker.get() + } + return nil +} + +func (r *RAGComponent) GetEmbedderName() string { + return r.embedderName +} + +// ============= 辅助函数 ============= + +func getRerankerEnabled(cfg *config.Config) bool { + if cfg == nil { + return false + } + var spec RerankerSpec + if err := cfg.Spec.Decode(&spec); err != nil { + return false + } + return spec.Enabled +} + +func getRerankerModel(cfg *config.Config) string { + if cfg == nil { + return "" + } + var spec RerankerSpec + if err := cfg.Spec.Decode(&spec); err != nil { + return "" + } + return spec.Model +} + +func getRerankerAPIKey(cfg *config.Config) string { + if cfg == nil { + return "" + } + var spec RerankerSpec + if err := cfg.Spec.Decode(&spec); err != nil { + return "" + } + return spec.APIKey 
+} diff --git a/ai/component/rag/config.go b/ai/component/rag/config.go new file mode 100644 index 00000000..7fee29c4 --- /dev/null +++ b/ai/component/rag/config.go @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rag + +import ( + "dubbo-admin-ai/config" + "fmt" +) + +// RAGSpec defines RAG component configuration with recursive structure +// Each subcomponent uses the standard Config pattern (type + spec) +type RAGSpec struct { + Embedder *config.Config `yaml:"embedder"` + Loader *config.Config `yaml:"loader"` + Splitter *config.Config `yaml:"splitter"` + Indexer *config.Config `yaml:"indexer"` + Retriever *config.Config `yaml:"retriever"` + Reranker *config.Config `yaml:"reranker,omitempty"` +} + +// EmbedderSpec defines embedder specific parameters +type EmbedderSpec struct { + Model string `yaml:"model"` +} + +// LoaderSpec defines loader specific parameters +// Loader has no specific parameters, configuration is through spec +type LoaderSpec struct { + // Loader has no concrete parameters +} + +// SplitterSpec defines splitter specific parameters +type SplitterSpec struct { + ChunkSize int `yaml:"chunk_size"` + OverlapSize int `yaml:"overlap_size"` +} + +// IndexerSpec defines 
indexer specific parameters +type IndexerSpec struct { + StoragePath string `yaml:"storage_path"` + IndexFormat string `yaml:"index_format"` + Dimension int `yaml:"dimension"` +} + +// RetrieverSpec defines retriever specific parameters +type RetrieverSpec struct { + StoragePath string `yaml:"storage_path"` + IndexFormat string `yaml:"index_format"` + Dimension int `yaml:"dimension"` +} + +// RerankerSpec defines reranker specific parameters +type RerankerSpec struct { + Enabled bool `yaml:"enabled"` + Model string `yaml:"model"` + APIKey string `yaml:"api_key,omitempty"` +} + +// DefaultEmbedderSpec returns default embedder configuration +func DefaultEmbedderSpec() *EmbedderSpec { + return &EmbedderSpec{Model: "dashscope/qwen3-embedding"} +} + +// DefaultSplitterSpec returns default splitter configuration +func DefaultSplitterSpec() *SplitterSpec { + return &SplitterSpec{ChunkSize: 1000, OverlapSize: 100} +} + +// DefaultIndexerSpec returns default indexer configuration +func DefaultIndexerSpec() *IndexerSpec { + return &IndexerSpec{ + StoragePath: "../../data/ai/index", + IndexFormat: "sqlite", + Dimension: 1536, + } +} + +// DefaultRetrieverSpec returns default retriever configuration +func DefaultRetrieverSpec() *RetrieverSpec { + return &RetrieverSpec{ + StoragePath: "../../data/ai/index", + IndexFormat: "sqlite", + Dimension: 1536, + } +} + +// DefaultRerankerSpec returns default reranker configuration +func DefaultRerankerSpec() *RerankerSpec { + return &RerankerSpec{ + Enabled: false, + Model: "rerank-english-v3.0", + } +} + +// Validate validates RAG configuration +func (c *RAGSpec) Validate() error { + if c == nil { + return fmt.Errorf("rag config is nil") + } + if c.Splitter != nil && c.Splitter.Type == "recursive" { + var splitter SplitterSpec + if err := c.Splitter.Spec.Decode(&splitter); err != nil { + return fmt.Errorf("failed to decode splitter spec: %w", err) + } + if splitter.ChunkSize <= 0 { + return fmt.Errorf("splitter.chunk_size must be 
greater than 0") + } + if splitter.OverlapSize < 0 { + return fmt.Errorf("splitter.overlap_size must be >= 0") + } + if splitter.OverlapSize >= splitter.ChunkSize { + return fmt.Errorf("splitter.overlap_size must be less than chunk_size") + } + } + if c.Indexer != nil { + switch c.Indexer.Type { + case "dev", "pinecone": + default: + return fmt.Errorf("unsupported indexer type: %s", c.Indexer.Type) + } + } + if c.Retriever != nil { + switch c.Retriever.Type { + case "dev", "pinecone": + default: + return fmt.Errorf("unsupported retriever type: %s", c.Retriever.Type) + } + } + return nil +} + +// --- Exported types for rag package to avoid circular dependency --- + +// CallOptions defines per-call options structure +type CallOptions struct { + TopK *int + TopN *int + Namespace *string + TargetIndex *string +} + +// RerankOption defines reranker option function type +type RerankOption func(*CallOptions) + +// RetrieveResult defines the unified result structure for RAG queries. +type RetrieveResult struct { + Content string `json:"content"` + RelevanceScore float64 `json:"relevance_score"` +} diff --git a/ai/component/rag/factory.go b/ai/component/rag/factory.go new file mode 100644 index 00000000..6abbe1cd --- /dev/null +++ b/ai/component/rag/factory.go @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rag + +import ( + "context" + "dubbo-admin-ai/config" + "dubbo-admin-ai/runtime" + "fmt" + + "github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown" + "github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive" + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/components/retriever" + "github.com/firebase/genkit/go/genkit" + "gopkg.in/yaml.v3" +) + +// RAGFactory creates a RAG component from configuration. +func RAGFactory(spec *yaml.Node) (runtime.Component, error) { + var cfg RAGSpec + if err := spec.Decode(&cfg); err != nil { + return nil, fmt.Errorf("failed to decode rag spec: %w", err) + } + return &RAGComponent{cfg: &cfg}, nil +} + +// ============= 创建函数 ============= + +func newLoader(cfg *config.Config) (document.Loader, error) { + if cfg == nil { + cfg = &config.Config{Type: "local"} + } + ctx := context.Background() + switch cfg.Type { + case "", "local": + return newLocalFileLoader(ctx) + default: + return nil, fmt.Errorf("unsupported loader type: %s", cfg.Type) + } +} + +func newSplitter(cfg *config.Config) (document.Transformer, error) { + if cfg == nil { + cfg = &config.Config{Type: "recursive"} + } + ctx := context.Background() + switch cfg.Type { + case "markdown_header": + var spec struct { + Headers map[string]string `yaml:"headers"` + TrimHeaders bool `yaml:"trim_headers"` + } + if err := cfg.Spec.Decode(&spec); err != nil { + return nil, fmt.Errorf("failed to decode markdown splitter spec: %w", err) + } + headers := spec.Headers + if len(headers) == 0 { + headers = map[string]string{"#": "h1", "##": "h2", "###": "h3", "####": "h4"} + } + return markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{Headers: headers, TrimHeaders: spec.TrimHeaders}) + case "", "recursive": + var spec SplitterSpec + if 
err := cfg.Spec.Decode(&spec); err != nil { + spec = *DefaultSplitterSpec() + } + chunkSize := spec.ChunkSize + if chunkSize <= 0 { + chunkSize = 1000 + } + overlap := spec.OverlapSize + if overlap <= 0 { + overlap = 100 + } + return recursive.NewSplitter(ctx, &recursive.Config{ChunkSize: chunkSize, OverlapSize: overlap}) + default: + return nil, fmt.Errorf("unsupported splitter type: %s", cfg.Type) + } +} + +func newIndexer(g *genkit.Genkit, indexerType, embedderModel string) (indexer.Indexer, error) { + const targetIndex = "default" + switch indexerType { + case "dev": + return newDevIndexer(g, embedderModel, targetIndex, 100), nil + case "pinecone": + return newPineconeIndexer(g, embedderModel, targetIndex, 100), nil + default: + return nil, fmt.Errorf("unsupported indexer type: %s", indexerType) + } +} + +func newRetriever(g *genkit.Genkit, retrieverType, embedderModel string) (retriever.Retriever, error) { + const targetIndex = "default" + const defaultTopK = 3 + switch retrieverType { + case "dev": + return newDevRetriever(g, embedderModel, targetIndex, defaultTopK), nil + case "pinecone": + return newPineconeRetriever(g, embedderModel, targetIndex, defaultTopK), nil + default: + return nil, fmt.Errorf("unsupported retriever type: %s", retrieverType) + } +} + +func newReranker(enabled bool, model, apiKey string) (Reranker, error) { + if !enabled { + return nil, nil + } + if model == "" { + model = "rerank-english-v3.0" + } + return &cohereReranker{cfg: &cohereRerankerConfig{APIKey: apiKey, Model: model, TopN: 3}}, nil +} + +// BuildRAGFromSpec 创建独立 RAG 实例(用于 CLI/工具) +func BuildRAGFromSpec(ctx context.Context, g *genkit.Genkit, cfg *RAGSpec) (*RAG, error) { + if g == nil { + return nil, fmt.Errorf("genkit registry is nil") + } + if cfg == nil { + return nil, fmt.Errorf("rag config is nil") + } + + if err := cfg.Validate(); err != nil { + return nil, err + } + + var embedderSpec EmbedderSpec + if err := cfg.Embedder.Spec.Decode(&embedderSpec); err != nil { + 
return nil, fmt.Errorf("failed to parse embedder spec: %w", err) + } + + loader, err := newLoader(cfg.Loader) + if err != nil { + return nil, fmt.Errorf("failed to create loader: %w", err) + } + + splitter, err := newSplitter(cfg.Splitter) + if err != nil { + return nil, fmt.Errorf("failed to create splitter: %w", err) + } + + idx, err := newIndexer(g, cfg.Indexer.Type, embedderSpec.Model) + if err != nil { + return nil, fmt.Errorf("failed to create indexer: %w", err) + } + + rtv, err := newRetriever(g, cfg.Retriever.Type, embedderSpec.Model) + if err != nil { + return nil, fmt.Errorf("failed to create retriever: %w", err) + } + + rr, err := newReranker(getRerankerEnabled(cfg.Reranker), getRerankerModel(cfg.Reranker), getRerankerAPIKey(cfg.Reranker)) + if err != nil { + return nil, fmt.Errorf("failed to create reranker: %w", err) + } + + return &RAG{ + Loader: loader, + Splitter: splitter, + Indexer: idx, + Retriever: rtv, + Reranker: rr, + }, nil +} diff --git a/ai/component/rag/indexer.go b/ai/component/rag/indexer.go new file mode 100644 index 00000000..d85fa808 --- /dev/null +++ b/ai/component/rag/indexer.go @@ -0,0 +1,203 @@ +package rag + +import ( + "context" + + "dubbo-admin-ai/utils" + "fmt" + "sync" + + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/schema" + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/localvec" + "github.com/firebase/genkit/go/plugins/pinecone" +) + +// --- Indexer --- +type PineconeIndexer struct { + g *genkit.Genkit + embedder string + target string + batchSz int + mu sync.Mutex + docstore map[string]*pinecone.Docstore // keyed by target index +} + +func newPineconeIndexer(g *genkit.Genkit, embedderModel string, targetIndex string, batchSize int) *PineconeIndexer { + return &PineconeIndexer{ + g: g, + embedder: embedderModel, + target: targetIndex, + batchSz: batchSize, + } +} + +func (idx 
*PineconeIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) { + // Handle options + implOpts := indexer.GetImplSpecificOptions(&CommonIndexerOptions{}, opts...) + namespace := implOpts.Namespace + effectiveTarget := idx.target + if implOpts.TargetIndex != nil && *implOpts.TargetIndex != "" { + effectiveTarget = *implOpts.TargetIndex + } + + // TODO(indexer, 2026-02-24): Validate namespace if needed for multi-tenancy support + // Initialize indexer docstore for this target if not already done + idx.mu.Lock() + if idx.docstore == nil { + idx.docstore = make(map[string]*pinecone.Docstore) + } + docstore := idx.docstore[effectiveTarget] + idx.mu.Unlock() + if docstore == nil { + embedder := genkit.LookupEmbedder(idx.g, idx.embedder) + if embedder == nil { + return nil, fmt.Errorf("failed to find embedder %s", idx.embedder) + } + + // Configure Pinecone connection + pineconeConfig := pinecone.Config{ + IndexID: effectiveTarget, + Embedder: embedder, + } + + newDocstore, _, err := pinecone.DefineRetriever(ctx, idx.g, + pineconeConfig, + &ai.RetrieverOptions{ + Label: effectiveTarget, + ConfigSchema: core.InferSchemaMap(pinecone.PineconeRetrieverOptions{}), + }) + if err != nil { + return nil, fmt.Errorf("failed to setup retriever for indexer: %w", err) + } + + idx.mu.Lock() + if idx.docstore == nil { + idx.docstore = make(map[string]*pinecone.Docstore) + } + if idx.docstore[effectiveTarget] == nil { + idx.docstore[effectiveTarget] = newDocstore + } + docstore = idx.docstore[effectiveTarget] + idx.mu.Unlock() + } + + // Convert to Genkit documents + genkitDocs := utils.ToGenkitDocuments(docs) + + // Index in batches + batchSize := idx.batchSz + if implOpts.BatchSize != nil && *implOpts.BatchSize > 0 { + batchSize = *implOpts.BatchSize + } + if batchSize <= 0 { + return nil, fmt.Errorf("batch size must be positive") + } + for i := 0; i < len(genkitDocs); i += batchSize { + end := min(i+batchSize, len(genkitDocs)) + batch 
:= genkitDocs[i:end] + if err := pinecone.Index(ctx, batch, docstore, namespace); err != nil { + return nil, fmt.Errorf("failed to index documents batch %d-%d: %w", i+1, end, err) + } + } + + return nil, nil +} + +// --- DevIndexer --- +type DevIndexer struct { + g *genkit.Genkit + embedder string + target string + batchSz int + mu sync.Mutex + docstore map[string]*localvec.DocStore // keyed by target index +} + +func newDevIndexer(g *genkit.Genkit, embedderModel string, targetIndex string, batchSize int) *DevIndexer { + return &DevIndexer{ + g: g, + embedder: embedderModel, + target: targetIndex, + batchSz: batchSize, + } +} + +func (idx *DevIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) { + implOpts := indexer.GetImplSpecificOptions(&CommonIndexerOptions{}, opts...) + _ = implOpts.Namespace + effectiveTarget := idx.target + if implOpts.TargetIndex != nil && *implOpts.TargetIndex != "" { + effectiveTarget = *implOpts.TargetIndex + } + + // Initialize indexer docstore for this target if not already done + idx.mu.Lock() + if idx.docstore == nil { + idx.docstore = make(map[string]*localvec.DocStore) + } + docstore := idx.docstore[effectiveTarget] + idx.mu.Unlock() + if docstore == nil { + embedder := genkit.LookupEmbedder(idx.g, idx.embedder) + if embedder == nil { + return nil, fmt.Errorf("failed to find embedder %s", idx.embedder) + } + + // Initialize localvec if needed (idempotent) + if err := localvec.Init(); err != nil { + return nil, fmt.Errorf("failed to init localvec: %w", err) + } + + // Configure localvec with Dev-specific settings + localvecConfig := localvec.Config{ + Embedder: embedder, + } + + var err error + docstore, _, err = localvec.DefineRetriever(idx.g, effectiveTarget, localvecConfig, nil) + if err != nil { + return nil, fmt.Errorf("failed to define localvec retriever: %w", err) + } + + idx.mu.Lock() + if idx.docstore == nil { + idx.docstore = make(map[string]*localvec.DocStore) + } + if 
existing := idx.docstore[effectiveTarget]; existing != nil { + docstore = existing + } else { + idx.docstore[effectiveTarget] = docstore + } + idx.mu.Unlock() + } + + // Convert to Genkit documents + genkitDocs := utils.ToGenkitDocuments(docs) + + // Index documents in batches + batchSize := idx.batchSz + if implOpts.BatchSize != nil && *implOpts.BatchSize > 0 { + batchSize = *implOpts.BatchSize + } + if batchSize <= 0 { + return nil, fmt.Errorf("batch size must be positive") + } + for i := 0; i < len(genkitDocs); i += batchSize { + end := min(i+batchSize, len(genkitDocs)) + batch := genkitDocs[i:end] + if err := localvec.Index(ctx, batch, docstore); err != nil { + return nil, fmt.Errorf("failed to index documents batch %d-%d: %w", i+1, end, err) + } + } + + // Return IDs (localvec doesn't return IDs on Index, so we extract from docs) + ids := make([]string, len(docs)) + for i, doc := range docs { + ids[i] = doc.ID + } + return ids, nil +} diff --git a/ai/component/rag/loader.go b/ai/component/rag/loader.go new file mode 100644 index 00000000..78db7829 --- /dev/null +++ b/ai/component/rag/loader.go @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package rag + +import ( + "context" + "fmt" + "os" + "path/filepath" + "slices" + "strings" + + "github.com/cloudwego/eino-ext/components/document/loader/file" + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/document/parser" + "github.com/cloudwego/eino/schema" +) + +type ExtType = string + +const ( + DotPDF ExtType = ".pdf" + DotMD ExtType = ".md" + DotTxt ExtType = ".txt" +) + +// newLocalFileLoader creates a FileLoader with an ExtParser that supports PDF and Markdown. +func newLocalFileLoader(ctx context.Context) (*file.FileLoader, error) { + // 1. Create Parsers + pdfParser, err := newPDFParserWrapper(ctx) + if err != nil { + return nil, err + } + mdParser := newMarkdownParser() + plainParser := parser.TextParser{} + + // 2. Create ExtParser + extParser, err := parser.NewExtParser(ctx, &parser.ExtParserConfig{ + Parsers: map[string]parser.Parser{ + DotPDF: pdfParser, + DotMD: mdParser, + DotTxt: plainParser, + }, + FallbackParser: plainParser, // Fallback to text parser for other files + }) + if err != nil { + return nil, fmt.Errorf("failed to create ext parser: %w", err) + } + + // 3. Create FileLoader + loader, err := file.NewFileLoader(ctx, &file.FileLoaderConfig{ + Parser: extParser, + }) + if err != nil { + return nil, fmt.Errorf("failed to create file loader: %w", err) + } + + return loader, nil +} + +// LoadDirectory loads all supported files from a directory recursively. +func LoadDirectory(ctx context.Context, loader document.Loader, dirPath string, opts ...LoaderOption) ([]*schema.Document, error) { + lo := defaultLoaderOptions() + for _, opt := range opts { + if opt != nil { + opt(&lo) + } + } + + // Normalize extensions + var targetExts []string + if len(lo.TargetExtensions) > 0 { + for _, e := range lo.TargetExtensions { + e = strings.ToLower(e) + if !strings.HasPrefix(e, ".") { + e = "." 
+ e + } + targetExts = append(targetExts, e) + } + } + + var allDocs []*schema.Document + + err := filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + return nil + } + + // Check extension if provided + if len(targetExts) > 0 { + ext := strings.ToLower(filepath.Ext(path)) + if !slices.Contains(targetExts, ext) { + return nil + } + } + + // Load file + docs, err := loader.Load(ctx, document.Source{URI: path}) + if err != nil { + return fmt.Errorf("failed to load file %s: %w", path, err) + } + + allDocs = append(allDocs, docs...) + return nil + }) + + if err != nil { + return nil, err + } + + return allDocs, nil +} + +// ---- Loader options ---- + +type LoaderOptions struct { + TargetExtensions []string +} + +type LoaderOption func(*LoaderOptions) + +func defaultLoaderOptions() LoaderOptions { + return LoaderOptions{TargetExtensions: []string{".md", ".pdf", ".txt"}} +} + +// WithLoaderTargetExtensions sets the file extensions to include when loading directories. +// Extensions are normalized to lowercase and ensured to start with '.'. +func WithLoaderTargetExtensions(exts ...string) LoaderOption { + return func(o *LoaderOptions) { + if len(exts) == 0 { + o.TargetExtensions = nil + return + } + norm := make([]string, 0, len(exts)) + for _, e := range exts { + e = strings.ToLower(strings.TrimSpace(e)) + if e == "" { + continue + } + if !strings.HasPrefix(e, ".") { + e = "." + e + } + norm = append(norm, e) + } + o.TargetExtensions = norm + } +} diff --git a/ai/component/rag/options.go b/ai/component/rag/options.go new file mode 100644 index 00000000..dd3bb1ac --- /dev/null +++ b/ai/component/rag/options.go @@ -0,0 +1,72 @@ +package rag + +import ( + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/components/retriever" +) + +// RetrieveOption defines per-call retrieve/rerank options. 
+type RetrieveOption = RerankOption + +func WithTopK(topK int) RetrieveOption { + return func(o *CallOptions) { o.TopK = &topK } +} + +func WithTopN(topN int) RetrieveOption { + return func(o *CallOptions) { o.TopN = &topN } +} + +func WithTargetIndex(index string) RetrieveOption { + return func(o *CallOptions) { o.TargetIndex = &index } +} + +func WithRetrieverTargetIndex(index string) RetrieveOption { + return WithTargetIndex(index) +} + +func WithRetrieverNamespace(namespace string) RetrieveOption { + return func(o *CallOptions) { o.Namespace = &namespace } +} + +// CommonIndexerOptions are per-call indexing options. +type CommonIndexerOptions struct { + Namespace string + BatchSize *int + TargetIndex *string +} + +func WithIndexerNamespace(ns string) indexer.Option { + return indexer.WrapImplSpecificOptFn(func(opts *CommonIndexerOptions) { + opts.Namespace = ns + }) +} + +func WithIndexerBatchSize(batchSize int) indexer.Option { + return indexer.WrapImplSpecificOptFn(func(opts *CommonIndexerOptions) { + opts.BatchSize = &batchSize + }) +} + +func WithIndexerTargetIndex(targetIndex string) indexer.Option { + return indexer.WrapImplSpecificOptFn(func(opts *CommonIndexerOptions) { + opts.TargetIndex = &targetIndex + }) +} + +// CommonRetrieverOptions are per-call retrieval options. 
+type CommonRetrieverOptions struct { + Namespace string + TargetIndex *string +} + +func WithRetrieverImplNamespace(ns string) retriever.Option { + return retriever.WrapImplSpecificOptFn(func(opts *CommonRetrieverOptions) { + opts.Namespace = ns + }) +} + +func WithRetrieverImplTargetIndex(targetIndex string) retriever.Option { + return retriever.WrapImplSpecificOptFn(func(opts *CommonRetrieverOptions) { + opts.TargetIndex = &targetIndex + }) +} diff --git a/ai/component/rag/parser.go b/ai/component/rag/parser.go new file mode 100644 index 00000000..cbd64ea3 --- /dev/null +++ b/ai/component/rag/parser.go @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rag + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/cloudwego/eino-ext/components/document/parser/pdf" + "github.com/cloudwego/eino/components/document/parser" + "github.com/cloudwego/eino/schema" +) + +// --- Options --- + +type ParserConfig struct { + Preprocessors []PreprocessorFunc +} + +type ParserOption func(*ParserConfig) + +func WithPreprocessors(funcs ...PreprocessorFunc) ParserOption { + return func(c *ParserConfig) { + c.Preprocessors = append(c.Preprocessors, funcs...) 
+ } +} + +// --- Parsers --- + +// MarkdownParser uses MarkdownCleaner and optional preprocessors +type MarkdownParser struct { + preprocessor *Preprocessor +} + +func newMarkdownParser(opts ...ParserOption) *MarkdownParser { + config := &ParserConfig{ + Preprocessors: []PreprocessorFunc{newMarkdownCleaner().Clean}, + } + for _, opt := range opts { + opt(config) + } + + return &MarkdownParser{ + preprocessor: newPreprocessor(config.Preprocessors...), + } +} + +func (p *MarkdownParser) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) { + content, err := io.ReadAll(reader) + if err != nil { + return nil, err + } + + cleaned := p.preprocessor.Process(string(content)) + if strings.TrimSpace(cleaned) == "" { + return nil, fmt.Errorf("empty content after cleaning") + } + + // Apply common options (like extra meta) + commonOpts := parser.GetCommonOptions(nil, opts...) + + return []*schema.Document{ + { + Content: cleaned, + MetaData: commonOpts.ExtraMeta, + }, + }, nil +} + +// PDFParserWrapper wraps eino-ext PDF parser and adds preprocessing +type PDFParserWrapper struct { + internalParser *pdf.PDFParser + preprocessor *Preprocessor +} + +func newPDFParserWrapper(ctx context.Context, opts ...ParserOption) (*PDFParserWrapper, error) { + p, err := pdf.NewPDFParser(ctx, nil) + if err != nil { + return nil, err + } + + config := &ParserConfig{ + Preprocessors: []PreprocessorFunc{PDFTextCleaner}, + } + for _, opt := range opts { + opt(config) + } + + return &PDFParserWrapper{ + internalParser: p, + preprocessor: newPreprocessor(config.Preprocessors...), + }, nil +} + +func (p *PDFParserWrapper) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) { + docs, err := p.internalParser.Parse(ctx, reader, opts...) 
+ if err != nil { + return nil, err + } + + // Apply preprocessing + for _, doc := range docs { + doc.Content = p.preprocessor.Process(doc.Content) + } + + return docs, nil +} diff --git a/ai/component/rag/preprocessor.go b/ai/component/rag/preprocessor.go new file mode 100644 index 00000000..37c7ccb5 --- /dev/null +++ b/ai/component/rag/preprocessor.go @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rag + +import ( + "regexp" + "strings" + "unicode" + "unicode/utf8" + + "github.com/gomarkdown/markdown/ast" + mdparser "github.com/gomarkdown/markdown/parser" +) + +// PreprocessorFunc defines a function that processes text. +type PreprocessorFunc func(string) string + +// Preprocessor manages a chain of preprocessing functions. +type Preprocessor struct { + funcs []PreprocessorFunc +} + +// newPreprocessor creates a new Preprocessor with the given functions. +func newPreprocessor(funcs ...PreprocessorFunc) *Preprocessor { + return &Preprocessor{ + funcs: funcs, + } +} + +// Process applies all preprocessing functions in order. 
+func (p *Preprocessor) Process(text string) string { + for _, f := range p.funcs { + text = f(text) + } + return text +} + +// Common Preprocessors + +// PDFTextCleaner cleans text extracted from PDFs. +func PDFTextCleaner(text string) string { + // 1. Remove control characters and non-printable characters (keep newlines, tabs, and spaces) + cleaned := "" + for _, r := range text { + if r == '\n' || r == '\t' || r == ' ' || (r >= 32 && r < 127) || r > 127 { + cleaned += string(r) + } + } + + // 2. Remove extra whitespace and newlines + cleaned = strings.ReplaceAll(cleaned, "\n \n", "\n") + cleaned = strings.ReplaceAll(cleaned, " \n", "\n") + cleaned = strings.ReplaceAll(cleaned, "\n ", "\n") + + // 3. Merge multiple newlines + multipleNewlines := regexp.MustCompile(`\n{3,}`) + cleaned = multipleNewlines.ReplaceAllString(cleaned, "\n\n") + + // 4. Remove single character lines (likely PDF parsing errors) + lines := strings.Split(cleaned, "\n") + var cleanedLines []string + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + if len(line) <= 1 { + continue + } + if regexp.MustCompile(`^[^\w\s]+$`).MatchString(line) { + continue + } + cleanedLines = append(cleanedLines, line) + } + + // 5. Reassemble text + result := strings.Join(cleanedLines, "\n") + + // 6. Clean common PDF parsing artifacts + result = regexp.MustCompile(`(?m)^\d+$`).ReplaceAllString(result, "") + result = regexp.MustCompile(`\s+`).ReplaceAllString(result, " ") + result = strings.ReplaceAll(result, " \n", "\n") + result = strings.ReplaceAll(result, "\n ", "\n") + + // 7. Final trim + return strings.TrimSpace(result) +} + +// MarkdownCleaner cleans Markdown content for RAG. +type MarkdownCleaner struct { + preserveCodeContent bool + preserveListStructure bool + preserveTableContent bool + maxLineLength int + result strings.Builder + inList bool + listDepth int + inTable bool +} + +// newMarkdownCleaner creates a new MarkdownCleaner. 
+func newMarkdownCleaner() *MarkdownCleaner { + return &MarkdownCleaner{ + preserveCodeContent: true, + preserveListStructure: true, + preserveTableContent: true, + maxLineLength: 500, + } +} + +// Clean cleans the markdown text. +func (c *MarkdownCleaner) Clean(markdown string) string { + c.result.Reset() + c.inList = false + c.listDepth = 0 + c.inTable = false + + markdown = c.removeFrontmatter(markdown) + hugoRe := regexp.MustCompile(`{{<[^>]+>}}|{{%[^%]+%}}`) + markdown = hugoRe.ReplaceAllString(markdown, "") + + extensions := mdparser.CommonExtensions | mdparser.Mmark | mdparser.Footnotes + p := mdparser.NewWithExtensions(extensions) + doc := p.Parse([]byte(markdown)) + + c.walkAST(doc) + + return c.postProcess(c.result.String()) +} + +// Helper methods for MarkdownCleaner (private) + +func (c *MarkdownCleaner) walkAST(node ast.Node) { + if node == nil { + return + } + + switch n := node.(type) { + case *ast.Document: + c.processChildren(n) + case *ast.Heading: + c.result.WriteString("\r\n") + c.processHeading(n) + case *ast.Paragraph: + c.processParagraph(n) + c.result.WriteString("\r\n") + case *ast.List: + c.processList(n) + c.result.WriteString("\n") + case *ast.ListItem: + c.processListItem(n) + case *ast.CodeBlock: + c.processCodeBlock(n) + c.result.WriteString("\n") + case *ast.Table: + c.result.WriteString("\r\n") + c.processTable(n) + c.result.WriteString("\r\n") + case *ast.TableRow: + c.processTableRow(n) + c.result.WriteString("\n") + case *ast.TableCell: + c.processTableCell(n) + case *ast.Text: + c.processText(n) + case *ast.Emph, *ast.Strong: + c.processChildren(n) + case *ast.Link: + c.processLink(n) + case *ast.Image: + c.processImage(n) + case *ast.Code: + c.processInlineCode(n) + case *ast.Softbreak, *ast.Hardbreak: + c.result.WriteString(" ") + case *ast.BlockQuote: + c.processBlockQuote(n) + case *ast.HorizontalRule: + c.result.WriteString("\n") + default: + c.result.WriteString("") + } +} + +func (c *MarkdownCleaner) processChildren(node 
ast.Node) { + children := node.GetChildren() + for _, child := range children { + c.walkAST(child) + } +} + +func (c *MarkdownCleaner) processHeading(h *ast.Heading) { + c.result.WriteString(strings.Repeat("#", h.Level) + " ") + c.processChildren(h) + c.result.WriteString("\n") +} + +func (c *MarkdownCleaner) processParagraph(p *ast.Paragraph) { + c.processChildren(p) +} + +func (c *MarkdownCleaner) processList(l *ast.List) { + if !c.preserveListStructure { + c.processChildren(l) + return + } + wasInList := c.inList + c.inList = true + c.listDepth++ + c.processChildren(l) + c.listDepth-- + if c.listDepth == 0 { + c.inList = wasInList + } +} + +func (c *MarkdownCleaner) processListItem(li *ast.ListItem) { + if c.preserveListStructure { + indent := strings.Repeat(" ", c.listDepth-1) + c.result.WriteString(indent + "- ") + } + c.processChildren(li) +} + +func (c *MarkdownCleaner) processCodeBlock(cb *ast.CodeBlock) { + if !c.preserveCodeContent { + return + } + code := string(cb.Literal) + if string(cb.Info) != "" { + c.result.WriteString(string(cb.Info) + ":") + } + cleanCode := c.cleanCodeContent(code) + c.result.WriteString(cleanCode) +} + +func (c *MarkdownCleaner) processTable(t *ast.Table) { + if !c.preserveTableContent { + return + } + c.inTable = true + c.processChildren(t) + c.inTable = false +} + +func (c *MarkdownCleaner) processTableRow(tr *ast.TableRow) { + if !c.preserveTableContent { + return + } + var cellContents []string + children := tr.GetChildren() + for _, cell := range children { + if tableCell, ok := cell.(*ast.TableCell); ok { + var cellBuilder strings.Builder + tempResult := c.result + c.result = cellBuilder + c.processChildren(tableCell) + c.result = tempResult + content := strings.TrimSpace(cellBuilder.String()) + if content != "" { + cellContents = append(cellContents, content) + } + } + } + if len(cellContents) > 0 { + c.result.WriteString(strings.Join(cellContents, " | ")) + } +} + +func (c *MarkdownCleaner) processTableCell(tc 
*ast.TableCell) { + c.processChildren(tc) +} + +func (c *MarkdownCleaner) processText(t *ast.Text) { + text := string(t.Literal) + cleanText := c.cleanText(text) + c.result.WriteString(cleanText) +} + +func (c *MarkdownCleaner) processLink(l *ast.Link) { + c.result.WriteString(" [") + c.processChildren(l) + c.result.WriteString("] ") +} + +func (c *MarkdownCleaner) processImage(img *ast.Image) { + c.processChildren(img) +} + +func (c *MarkdownCleaner) processInlineCode(code *ast.Code) { + if c.preserveCodeContent { + cleanCode := c.cleanText(string(code.Literal)) + c.result.WriteString(cleanCode) + } +} + +func (c *MarkdownCleaner) processBlockQuote(bq *ast.BlockQuote) { + c.processChildren(bq) +} + +func (c *MarkdownCleaner) cleanText(text string) string { + htmlRe := regexp.MustCompile(`<[^>]*>`) + text = htmlRe.ReplaceAllString(text, "") + spaceRe := regexp.MustCompile(`\s+`) + text = spaceRe.ReplaceAllString(text, " ") + text = strings.Map(func(r rune) rune { + if unicode.IsControl(r) && r != '\n' && r != '\r' && r != '\t' { + return -1 + } + return r + }, text) + return strings.TrimSpace(text) +} + +func (c *MarkdownCleaner) cleanCodeContent(code string) string { + lines := strings.Split(code, "\n") + var cleanLines []string + for _, line := range lines { + line = strings.TrimSpace(line) + if len(line) == 0 { + continue + } + cleanLines = append(cleanLines, line) + } + return strings.Join(cleanLines, " ") +} + +func (c *MarkdownCleaner) removeFrontmatter(markdown string) string { + if !strings.HasPrefix(markdown, "---") { + return markdown + } + lines := strings.Split(markdown, "\n") + if len(lines) < 3 { + return markdown + } + endIndex := -1 + for i := 1; i < len(lines); i++ { + if strings.TrimSpace(lines[i]) == "---" { + endIndex = i + break + } + } + if endIndex > 0 { + remainingLines := lines[endIndex+1:] + for len(remainingLines) > 0 && strings.TrimSpace(remainingLines[0]) == "" { + remainingLines = remainingLines[1:] + } + return 
strings.Join(remainingLines, "\n") + } + return markdown +} + +func (c *MarkdownCleaner) postProcess(text string) string { + multiNewlineRe := regexp.MustCompile(`\n{3,}`) + text = multiNewlineRe.ReplaceAllString(text, "\n\n") + text = strings.TrimSpace(text) + if !utf8.ValidString(text) { + text = strings.ToValidUTF8(text, "") + } + return text +} diff --git a/ai/component/rag/rag.go b/ai/component/rag/rag.go new file mode 100644 index 00000000..5fd08cfe --- /dev/null +++ b/ai/component/rag/rag.go @@ -0,0 +1,101 @@ +package rag + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" +) + +// RAG provides runtime-facing document split, index and retrieve operations. +type RAG struct { + Loader document.Loader + Splitter document.Transformer + Indexer indexer.Indexer + Retriever retriever.Retriever + Reranker Reranker +} + +func (s *RAG) Split(ctx context.Context, docs []*schema.Document) ([]*schema.Document, error) { + if s.Splitter == nil { + return docs, nil + } + return s.Splitter.Transform(ctx, docs) +} + +func (s *RAG) Index(ctx context.Context, namespace string, docs []*schema.Document, opts ...indexer.Option) ([]string, error) { + if s.Indexer == nil { + return nil, fmt.Errorf("indexer is nil") + } + if namespace == "" { + return s.Indexer.Store(ctx, docs, opts...) + } + all := append([]indexer.Option{WithIndexerNamespace(namespace)}, opts...) + return s.Indexer.Store(ctx, docs, all...) 
+} + +func (s *RAG) Retrieve(ctx context.Context, namespace string, queries []string, opts ...RetrieveOption) (map[string][]*RetrieveResult, error) { + if s.Retriever == nil { + return nil, fmt.Errorf("retriever is nil") + } + if len(queries) == 0 { + return map[string][]*RetrieveResult{}, nil + } + + var co CallOptions + for _, opt := range opts { + if opt != nil { + opt(&co) + } + } + + retrieveOpts := make([]retriever.Option, 0, 2) + if co.TopK != nil { + retrieveOpts = append(retrieveOpts, retriever.WithTopK(*co.TopK)) + } + if co.TargetIndex != nil && *co.TargetIndex != "" { + retrieveOpts = append(retrieveOpts, WithRetrieverImplTargetIndex(*co.TargetIndex)) + } + effectiveNamespace := namespace + if co.Namespace != nil { + effectiveNamespace = *co.Namespace + } + if effectiveNamespace != "" { + retrieveOpts = append(retrieveOpts, WithRetrieverImplNamespace(effectiveNamespace)) + } + + resp := make(map[string][]*RetrieveResult, len(queries)) + for _, query := range queries { + docs, err := s.Retriever.Retrieve(ctx, query, retrieveOpts...) + if err != nil { + return nil, fmt.Errorf("failed to retrieve for query %q: %w", query, err) + } + results := make([]*RetrieveResult, 0, len(docs)) + for _, doc := range docs { + results = append(results, &RetrieveResult{Content: doc.Content, RelevanceScore: 0}) + } + resp[query] = results + } + + if s.Reranker == nil { + return resp, nil + } + + final := make(map[string][]*RetrieveResult, len(resp)) + for query, raw := range resp { + docs := make([]*schema.Document, 0, len(raw)) + for _, r := range raw { + docs = append(docs, &schema.Document{Content: r.Content}) + } + reranked, err := s.Reranker.Rerank(ctx, query, docs, opts...) 
+ if err != nil { + return nil, err + } + final[query] = reranked + } + + return final, nil +} diff --git a/ai/component/rag/rag.yaml b/ai/component/rag/rag.yaml new file mode 100644 index 00000000..40a1c3a5 --- /dev/null +++ b/ai/component/rag/rag.yaml @@ -0,0 +1,37 @@ +type: rag +spec: + embedder: + type: genkit + spec: + model: dashscope/qwen3-embedding + + loader: + type: local + spec: {} + + splitter: + type: recursive + spec: + chunk_size: 1000 + overlap_size: 100 + + indexer: + type: dev + spec: + storage_path: "../../data/ai/index" + index_format: sqlite + dimension: 1536 + + retriever: + type: dev + spec: + storage_path: "../../data/ai/index" + index_format: sqlite + dimension: 1536 + + reranker: + type: cohere + spec: + enabled: true + model: rerank-english-v3.0 + api_key: "${COHERE_API_KEY}" diff --git a/ai/component/rag/rerank.go b/ai/component/rag/rerank.go new file mode 100644 index 00000000..dd96d156 --- /dev/null +++ b/ai/component/rag/rerank.go @@ -0,0 +1,121 @@ +package rag + +import ( + "context" + + "fmt" + "os" + + "github.com/cloudwego/eino/schema" + cohere "github.com/cohere-ai/cohere-go/v2" + cohereClient "github.com/cohere-ai/cohere-go/v2/client" +) + +type cohereReranker struct { + cfg *cohereRerankerConfig +} + +type cohereRerankerConfig struct { + APIKey string + Model string + TopN int +} + +func (r *cohereReranker) Rerank(ctx context.Context, query string, docs any, opts ...RerankOption) ([]*RetrieveResult, error) { + if r == nil || r.cfg == nil { + return nil, fmt.Errorf("rerank config is nil") + } + if query == "" { + return nil, fmt.Errorf("query is empty") + } + + // Convert docs to []*schema.Document + var schemaDocs []*schema.Document + switch v := docs.(type) { + case []*schema.Document: + schemaDocs = v + case []any: + schemaDocs = make([]*schema.Document, 0, len(v)) + for _, item := range v { + if doc, ok := item.(*schema.Document); ok { + schemaDocs = append(schemaDocs, doc) + } + } + default: + return nil, 
fmt.Errorf("unsupported docs type: %T", docs)
+	}
+
+	if len(schemaDocs) == 0 {
+		return []*RetrieveResult{}, nil
+	}
+
+	apiKey := r.cfg.APIKey
+	if apiKey == "" {
+		apiKey = os.Getenv("COHERE_API_KEY")
+	}
+	if apiKey == "" {
+		return nil, fmt.Errorf("COHERE_API_KEY is not set")
+	}
+
+	var co CallOptions
+	for _, opt := range opts {
+		if opt != nil {
+			opt(&co)
+		}
+	}
+
+	topN := r.cfg.TopN
+	if co.TopN != nil {
+		topN = *co.TopN
+	}
+	if topN <= 0 {
+		topN = 3
+	}
+
+	texts := make([]*string, 0, len(schemaDocs))
+	for _, d := range schemaDocs {
+		c := d.Content
+		texts = append(texts, &c)
+	}
+
+	// Propagate the caller's context so cancellation and deadlines reach the
+	// remote rerank call.
+	res, err := rerank(ctx, apiKey, r.cfg.Model, query, texts, topN)
+	if err != nil {
+		return nil, err
+	}
+
+	out := make([]*RetrieveResult, 0, len(res))
+	for _, item := range res {
+		if item.Index < 0 || item.Index >= len(schemaDocs) {
+			continue
+		}
+		out = append(out, &RetrieveResult{Content: schemaDocs[item.Index].Content, RelevanceScore: item.RelevanceScore})
+	}
+
+	return out, nil
+}
+
+// rerank calls the Cohere rerank API for query against documents and returns
+// up to topN results ordered by relevance. It takes the caller's context;
+// the previous version used context.Background(), which made the remote call
+// uncancellable and ignored deadlines.
+func rerank(ctx context.Context, apiKey, model, query string, documents []*string, topN int) ([]*cohere.RerankResponseResultsItem, error) {
+	client := cohereClient.NewClient(cohereClient.WithToken(apiKey))
+
+	var rerankDocs []*cohere.RerankRequestDocumentsItem
+	for _, doc := range documents {
+		rerankDoc := &cohere.RerankRequestDocumentsItem{}
+		rerankDoc.String = *doc
+		rerankDocs = append(rerankDocs, rerankDoc)
+	}
+
+	rerankResponse, err := client.Rerank(
+		ctx,
+		&cohere.RerankRequest{
+			Query:     query,
+			Documents: rerankDocs,
+			TopN:      &topN,
+			Model:     &model,
+		},
+	)
+	if err != nil {
+		return nil, fmt.Errorf("failed to call rerank API: %w", err)
+	}
+
+	return rerankResponse.Results, nil
+}
diff --git a/ai/component/rag/retriever.go b/ai/component/rag/retriever.go
new file mode 100644
index 00000000..d9320a3b
--- /dev/null
+++ b/ai/component/rag/retriever.go
@@ -0,0 +1,237 @@
+package rag
+
+import (
+	"context"
+
+	"dubbo-admin-ai/utils"
+	"fmt"
+	"sync"
+
+
"github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/localvec" + "github.com/firebase/genkit/go/plugins/pinecone" +) + +// --- Retriever --- +type PineconeRetriever struct { + g *genkit.Genkit + embedder string + target string + defaultK int + retriever map[string]ai.Retriever // keyed by target index +} + +func newPineconeRetriever(g *genkit.Genkit, embedderModel string, targetIndex string, topK int) *PineconeRetriever { + return &PineconeRetriever{ + g: g, + embedder: embedderModel, + target: targetIndex, + defaultK: topK, + } +} + +func (r *PineconeRetriever) getRetriever(ctx context.Context, targetIndex string) (ai.Retriever, error) { + if targetIndex == "" { + targetIndex = "default" + } + + if r.retriever == nil { + r.retriever = make(map[string]ai.Retriever) + } + ret := r.retriever[targetIndex] + if ret != nil { + return ret, nil + } + + embedder := genkit.LookupEmbedder(r.g, r.embedder) + if embedder == nil { + return nil, fmt.Errorf("failed to find embedder %s", r.embedder) + } + + var err error + if !pinecone.IsDefinedRetriever(r.g, targetIndex) { + _, ret, err = pinecone.DefineRetriever(ctx, r.g, + pinecone.Config{ + IndexID: targetIndex, + Embedder: embedder, + }, + &ai.RetrieverOptions{ + Label: targetIndex, + ConfigSchema: core.InferSchemaMap(pinecone.PineconeRetrieverOptions{}), + }) + } else { + ret = pinecone.Retriever(r.g, targetIndex) + } + if err != nil { + return nil, fmt.Errorf("failed to define retriever: %w", err) + } + + if r.retriever == nil { + r.retriever = make(map[string]ai.Retriever) + } + if existing := r.retriever[targetIndex]; existing != nil { + ret = existing + } else { + r.retriever[targetIndex] = ret + } + + return ret, nil +} + +func (r *PineconeRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) 
([]*schema.Document, error) { + impl := retriever.GetImplSpecificOptions(&CommonRetrieverOptions{}, opts...) + effectiveTarget := r.target + if impl.TargetIndex != nil && *impl.TargetIndex != "" { + effectiveTarget = *impl.TargetIndex + } + ret, err := r.getRetriever(ctx, effectiveTarget) + if err != nil { + return nil, err + } + + // Options handling + // Default options + defaultK := r.defaultK + pineconeOpts := &pinecone.PineconeRetrieverOptions{ + K: defaultK, // Default TopK + } + + // Apply Eino common options + commonOpts := retriever.GetCommonOptions(&retriever.Options{ + TopK: &defaultK, + }, opts...) + + if commonOpts.TopK != nil { + pineconeOpts.K = *commonOpts.TopK + } + + // Apply implementation specific options (for Namespace) + if impl.Namespace != "" { + pineconeOpts.Namespace = impl.Namespace + } + + // Retrieve + resp, err := ret.Retrieve(ctx, &ai.RetrieverRequest{ + Query: ai.DocumentFromText(query, nil), + Options: pineconeOpts, + }) + if err != nil { + return nil, fmt.Errorf("failed to retrieve: %w", err) + } + + docs := utils.ToEinoDocuments(resp.Documents) + + return docs, nil +} + +type DevRetriever struct { + g *genkit.Genkit + embedder string + target string + defaultK int + mu sync.Mutex + retriever map[string]ai.Retriever // keyed by target index +} + +func newDevRetriever(g *genkit.Genkit, embedderModel string, targetIndex string, topK int) *DevRetriever { + return &DevRetriever{ + g: g, + embedder: embedderModel, + target: targetIndex, + defaultK: topK, + } +} + +func (r *DevRetriever) getRetriever(ctx context.Context, targetIndex string) (ai.Retriever, error) { + if targetIndex == "" { + targetIndex = "default" + } + + r.mu.Lock() + if r.retriever == nil { + r.retriever = make(map[string]ai.Retriever) + } + ret := r.retriever[targetIndex] + r.mu.Unlock() + if ret != nil { + return ret, nil + } + + embedder := genkit.LookupEmbedder(r.g, r.embedder) + if embedder == nil { + return nil, fmt.Errorf("failed to find embedder %s", 
r.embedder) + } + + if err := localvec.Init(); err != nil { + return nil, fmt.Errorf("failed to init localvec: %w", err) + } + + localvecConfig := localvec.Config{Embedder: embedder} + + var err error + if localvec.IsDefinedRetriever(r.g, targetIndex) { + ret = localvec.Retriever(r.g, targetIndex) + } else { + _, ret, err = localvec.DefineRetriever(r.g, targetIndex, localvecConfig, nil) + } + if err != nil { + return nil, fmt.Errorf("failed to define localvec retriever: %w", err) + } + + r.mu.Lock() + if r.retriever == nil { + r.retriever = make(map[string]ai.Retriever) + } + if existing := r.retriever[targetIndex]; existing != nil { + ret = existing + } else { + r.retriever[targetIndex] = ret + } + r.mu.Unlock() + + return ret, nil +} + +func (r *DevRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { + impl := retriever.GetImplSpecificOptions(&CommonRetrieverOptions{}, opts...) + effectiveTarget := r.target + if impl.TargetIndex != nil && *impl.TargetIndex != "" { + effectiveTarget = *impl.TargetIndex + } + ret, err := r.getRetriever(ctx, effectiveTarget) + if err != nil { + return nil, err + } + + // Options handling + defaultK := r.defaultK + // Apply Eino common options + commonOpts := retriever.GetCommonOptions(&retriever.Options{ + TopK: &defaultK, + }, opts...) 
+ + k := defaultK + if commonOpts.TopK != nil { + k = *commonOpts.TopK + } + + // Retrieve + retrieverReq := &ai.RetrieverRequest{ + Query: ai.DocumentFromText(query, nil), + Options: &localvec.RetrieverOptions{ + K: k, + }, + } + resp, err := ret.Retrieve(ctx, retrieverReq) + if err != nil { + return nil, fmt.Errorf("failed to retrieve: %w", err) + } + + docs := utils.ToEinoDocuments(resp.Documents) + + return docs, nil +} diff --git a/ai/component/rag/test/rag_config_test.go b/ai/component/rag/test/rag_config_test.go new file mode 100644 index 00000000..4b40c7d4 --- /dev/null +++ b/ai/component/rag/test/rag_config_test.go @@ -0,0 +1,31 @@ +package ragtest + +import ( + compRag "dubbo-admin-ai/component/rag" + "dubbo-admin-ai/config" + "strings" + "testing" + + "gopkg.in/yaml.v3" +) + +func encodeToYAMLNode(v any) yaml.Node { + var node yaml.Node + node.Encode(v) + return node +} + +func TestRAGComponent_Validate(t *testing.T) { + cfg := &compRag.RAGSpec{ + Embedder: &config.Config{Type: "genkit", Spec: encodeToYAMLNode(&compRag.EmbedderSpec{Model: "dashscope/qwen3-embedding"})}, + Loader: &config.Config{Type: "local", Spec: encodeToYAMLNode(&compRag.LoaderSpec{})}, + Splitter: &config.Config{Type: "recursive", Spec: encodeToYAMLNode(&compRag.SplitterSpec{ChunkSize: 100, OverlapSize: 100})}, + Indexer: &config.Config{Type: "dev", Spec: encodeToYAMLNode(compRag.DefaultIndexerSpec())}, + Retriever: &config.Config{Type: "dev", Spec: encodeToYAMLNode(compRag.DefaultRetrieverSpec())}, + } + + err := cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "overlap_size") { + t.Fatalf("expected splitter semantic validation error, got %v", err) + } +} diff --git a/ai/component/rag/test/workflow_test.go b/ai/component/rag/test/workflow_test.go new file mode 100644 index 00000000..ccd65dda --- /dev/null +++ b/ai/component/rag/test/workflow_test.go @@ -0,0 +1,149 @@ +package ragtest + +import ( + "context" + "strings" + "sync" + "testing" + + compRag 
"dubbo-admin-ai/component/rag" + + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/indexer" + einoRetriever "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" +) + +type workflowStore struct { + mu sync.RWMutex + docs map[string][]*schema.Document +} + +func newWorkflowStore() *workflowStore { + return &workflowStore{docs: make(map[string][]*schema.Document)} +} + +type workflowIndexer struct { + store *workflowStore + lastCount int +} + +func (w *workflowIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) { + impl := indexer.GetImplSpecificOptions(&compRag.CommonIndexerOptions{}, opts...) + ns := impl.Namespace + w.store.mu.Lock() + w.store.docs[ns] = append(w.store.docs[ns], docs...) + w.store.mu.Unlock() + w.lastCount = len(docs) + ids := make([]string, len(docs)) + for i := range docs { + ids[i] = docs[i].ID + } + return ids, nil +} + +type workflowRetriever struct{ store *workflowStore } + +func (w *workflowRetriever) Retrieve(ctx context.Context, query string, opts ...einoRetriever.Option) ([]*schema.Document, error) { + impl := einoRetriever.GetImplSpecificOptions(&compRag.CommonRetrieverOptions{}, opts...) 
+ ns := impl.Namespace + w.store.mu.RLock() + defer w.store.mu.RUnlock() + all := w.store.docs[ns] + out := make([]*schema.Document, 0) + for _, d := range all { + if strings.Contains(strings.ToLower(d.Content), strings.ToLower(query)) { + out = append(out, d) + } + } + return out, nil +} + +type workflowSplitter struct{} + +func (w *workflowSplitter) Transform(ctx context.Context, src []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) { + if len(src) == 0 { + return src, nil + } + return []*schema.Document{{ID: "c1", Content: src[0].Content[:len(src[0].Content)/2]}, {ID: "c2", Content: src[0].Content[len(src[0].Content)/2:]}}, nil +} + +func TestRAGWorkflow(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + { + name: "index_retrieve", + run: func(t *testing.T) { + ctx := context.Background() + store := newWorkflowStore() + r := &compRag.RAG{Indexer: &workflowIndexer{store: store}, Retriever: &workflowRetriever{store: store}} + _, err := r.Index(ctx, "ns", []*schema.Document{{ID: "1", Content: "Dubbo is RPC"}}) + if err != nil { + t.Fatalf("Index() error: %v", err) + } + got, err := r.Retrieve(ctx, "ns", []string{"RPC"}) + if err != nil { + t.Fatalf("Retrieve() error: %v", err) + } + if len(got["RPC"]) == 0 || !strings.Contains(got["RPC"][0].Content, "Dubbo") { + t.Fatalf("expected retrieval to include Dubbo, got %+v", got) + } + }, + }, + { + name: "split_index", + run: func(t *testing.T) { + ctx := context.Background() + idx := &workflowIndexer{store: newWorkflowStore()} + r := &compRag.RAG{Splitter: &workflowSplitter{}, Indexer: idx} + split, err := r.Split(ctx, []*schema.Document{{ID: "doc", Content: "abcdefghijklmnopqrstuvwxyz"}}) + if err != nil { + t.Fatalf("Split() error: %v", err) + } + if len(split) <= 1 { + t.Fatalf("expected split chunks > 1, got %d", len(split)) + } + if _, err := r.Index(ctx, "ns", split); err != nil { + t.Fatalf("Index() error: %v", err) + } + if idx.lastCount != 
len(split) { + t.Fatalf("indexed count = %d, want %d", idx.lastCount, len(split)) + } + }, + }, + { + name: "namespace", + run: func(t *testing.T) { + ctx := context.Background() + store := newWorkflowStore() + r := &compRag.RAG{Indexer: &workflowIndexer{store: store}, Retriever: &workflowRetriever{store: store}} + _, _ = r.Index(ctx, "ns1", []*schema.Document{{ID: "1", Content: "alpha only"}}) + _, _ = r.Index(ctx, "ns2", []*schema.Document{{ID: "2", Content: "beta only"}}) + got, err := r.Retrieve(ctx, "ns1", []string{"beta"}) + if err != nil { + t.Fatalf("Retrieve() error: %v", err) + } + if len(got["beta"]) != 0 { + t.Fatalf("expected namespace isolation, got %+v", got) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { tt.run(t) }) + } +} + +func TestRAG_Retrieve(t *testing.T) { + r := &compRag.RAG{Retriever: &workflowRetriever{store: newWorkflowStore()}} + got, err := r.Retrieve(context.Background(), "ns", nil) + if err != nil { + t.Fatalf("Retrieve() error: %v", err) + } + if got == nil || len(got) != 0 { + t.Fatalf("expected non-nil empty map, got %+v", got) + } +} diff --git a/ai/component/server/component.go b/ai/component/server/component.go new file mode 100644 index 00000000..dbfad150 --- /dev/null +++ b/ai/component/server/component.go @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "context" + "dubbo-admin-ai/component/agent/react" + "dubbo-admin-ai/component/server/engine" + "dubbo-admin-ai/runtime" + "fmt" + "net/http" + "time" + + "github.com/gin-gonic/gin" +) + +// ServerComponent Server component implementation +type ServerComponent struct { + rt *runtime.Runtime + srv *http.Server + + port int + host string + debug bool + corsOrigins []string + readTimeout int + writeTimeout int +} + +func NewServerComponent( + port int, + host string, + debug bool, + corsOrigins []string, + timeouts ...int, +) (runtime.Component, error) { + readTimeout := 30 + writeTimeout := 30 + if len(timeouts) > 0 { + readTimeout = timeouts[0] + } + if len(timeouts) > 1 { + writeTimeout = timeouts[1] + } + return &ServerComponent{ + port: port, + host: host, + debug: debug, + corsOrigins: corsOrigins, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + }, nil +} + +// Name returns the component name +func (s *ServerComponent) Name() string { + return "server" +} + +func (s *ServerComponent) Validate() error { + if s.host == "" { + return fmt.Errorf("host is required") + } + if s.port <= 0 || s.port > 65535 { + return fmt.Errorf("port must be between 1 and 65535") + } + if s.readTimeout <= 0 { + return fmt.Errorf("read_timeout must be greater than 0") + } + if s.writeTimeout <= 0 { + return fmt.Errorf("write_timeout must be greater than 0") + } + return nil +} + +func (s *ServerComponent) Init(rt *runtime.Runtime) error { + s.rt = rt + rt.GetLogger().Info("Server component initialized", + "port", s.port, + 
"host", s.host) + + return nil +} + +func (s *ServerComponent) Start() error { + // Retrieve Agent component from Runtime + agentComp, err := s.rt.GetComponent("agent") + if err != nil { + s.rt.GetLogger().Error("Failed to get agent component", "error", err) + return fmt.Errorf("failed to get agent component: %w", err) + } + + agentComponent, ok := agentComp.(*react.AgentComponent) + if !ok { + s.rt.GetLogger().Error("Agent component is not an AgentComponent") + return fmt.Errorf("agent component is not an AgentComponent") + } + + // Configure Gin mode + if !s.debug { + gin.SetMode(gin.ReleaseMode) + } else { + gin.SetMode(gin.DebugMode) + } + + // Create router with AI interface + router := engine.NewRouter(agentComponent.Agent) + + // Add health check endpoint + router.GetEngine().GET("/health", func(c *gin.Context) { + c.JSON(200, gin.H{ + "status": "ok", + }) + }) + + // Create HTTP server + s.srv = &http.Server{ + Addr: fmt.Sprintf("%s:%d", s.host, s.port), + Handler: router.GetEngine(), + } + + s.rt.GetLogger().Info("Server starting", + "addr", s.srv.Addr, + "debug", s.debug) + + go func() { + if err := s.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + s.rt.GetLogger().Error("Server failed", "error", err) + } + }() + + return nil +} + +func (s *ServerComponent) Stop() error { + if s.srv != nil { + s.rt.GetLogger().Info("Server shutting down") + + // Gracefully shutdown the server using Shutdown instead of Close + // Shutdown will: + // 1. Stop accepting new connections + // 2. Wait for existing requests to complete (up to 10 seconds) + // 3. 
Close all connections + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Force close if timeout occurs + if err := s.srv.Shutdown(ctx); err != nil { + s.rt.GetLogger().Error("Server shutdown timeout, forcing close", "error", err) + return s.srv.Close() + } + + s.rt.GetLogger().Info("Server shutdown gracefully completed") + return nil + } + return nil +} diff --git a/ai/component/server/config.go b/ai/component/server/config.go new file mode 100644 index 00000000..2d8aa276 --- /dev/null +++ b/ai/component/server/config.go @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package server + +// ServerSpec defines server configuration +type ServerSpec struct { + Port int `yaml:"port"` + Host string `yaml:"host"` + Debug bool `yaml:"debug"` + CORSOrigins []string `yaml:"cors_origins"` + ReadTimeout int `yaml:"read_timeout"` + WriteTimeout int `yaml:"write_timeout"` +} + +// DefaultServerSpec returns default server configuration +func DefaultServerSpec() *ServerSpec { + return &ServerSpec{ + Port: 8888, + Host: "0.0.0.0", + Debug: false, + CORSOrigins: []string{"*"}, + ReadTimeout: 30, + WriteTimeout: 30, + } +} diff --git a/ai/server/docs/openapi.yaml b/ai/component/server/engine/docs/openapi.yaml similarity index 100% rename from ai/server/docs/openapi.yaml rename to ai/component/server/engine/docs/openapi.yaml diff --git a/ai/server/handlers.go b/ai/component/server/engine/handlers.go similarity index 79% rename from ai/server/handlers.go rename to ai/component/server/engine/handlers.go index e9738143..1be96c1a 100644 --- a/ai/server/handlers.go +++ b/ai/component/server/engine/handlers.go @@ -1,26 +1,26 @@ -package server +package engine import ( "fmt" "net/http" - "dubbo-admin-ai/manager" - "dubbo-admin-ai/server/sse" + "dubbo-admin-ai/component/server/engine/session" + "dubbo-admin-ai/component/server/engine/sse" + rt "dubbo-admin-ai/runtime" - "dubbo-admin-ai/agent" + "dubbo-admin-ai/component/agent" "dubbo-admin-ai/schema" - "dubbo-admin-ai/server/session" "github.com/gin-gonic/gin" ) -// AgentHandler AI Agent处理器 +// AgentHandler handles AI Agent requests type AgentHandler struct { agent agent.Agent sessionMgr *session.Manager } -// NewAgentHandler 创建AI Agent处理器 +// NewAgentHandler creates an AI Agent handler func NewAgentHandler(agent agent.Agent, sessionMgr *session.Manager) *AgentHandler { sessionMgr.CreateMockSession() return &AgentHandler{ @@ -29,7 +29,7 @@ func NewAgentHandler(agent agent.Agent, sessionMgr *session.Manager) *AgentHandl } } -// StreamChat 流式聊天接口 +// StreamChat handles streaming chat endpoint func (h 
*AgentHandler) StreamChat(c *gin.Context) { var ( req ChatRequest @@ -41,14 +41,14 @@ func (h *AgentHandler) StreamChat(c *gin.Context) { err error ) - // 解析请求 + // Parse request if err = c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, NewErrorResponse("Invalid request: "+err.Error())) return } sessionID = req.SessionID - // 验证session存在并更新活动时间 + // Validate session exists and update activity time session, err = h.sessionMgr.GetSession(sessionID) if err != nil { c.JSON(http.StatusBadRequest, NewErrorResponse("Invalid session ID: "+err.Error())) @@ -62,7 +62,7 @@ func (h *AgentHandler) StreamChat(c *gin.Context) { } sseHandler = sse.NewStreamHandler(streamWriter, sessionID) - // 设置响应头和错误恢复 + // Set response headers and error recovery defer func() { if r := recover(); r != nil { sseHandler.HandleError("internal_error", fmt.Sprintf("internal error: %v", r)) @@ -83,7 +83,7 @@ func (h *AgentHandler) StreamChat(c *gin.Context) { } if err != nil { sseHandler.HandleError("agent_error", fmt.Sprintf("agent error: %v", err)) - manager.GetLogger().Error("Agent interaction error", "session_id", sessionID, "error", err) + rt.GetLogger().Error("Agent interaction error", "session_id", sessionID, "error", err) channels.Close() return } @@ -96,31 +96,31 @@ func (h *AgentHandler) StreamChat(c *gin.Context) { h.MessageDelta(sseHandler, feedback.Final()) } else if feedback.IsDone() { if err := sseHandler.HandleContentBlockStop(feedback.Index()); err != nil { - manager.GetLogger().Error("Failed to handle content block stop", "error", err) + rt.GetLogger().Error("Failed to handle content block stop", "error", err) } } else { if err := sseHandler.HandleText(feedback.Text(), feedback.Index()); err != nil { - manager.GetLogger().Error("Failed to handle text", "error", err) + rt.GetLogger().Error("Failed to handle text", "error", err) } } case <-c.Request.Context().Done(): - manager.GetLogger().Info("Client disconnected from stream") + rt.GetLogger().Info("Client 
disconnected from stream") return default: if channels.Closed() { if err := sseHandler.FinishStream(); err != nil { - manager.GetLogger().Error("Failed to finish stream", "error", err) + rt.GetLogger().Error("Failed to finish stream", "error", err) } - manager.GetLogger().Info("Stream processing completed", "session_id", sessionID) + rt.GetLogger().Info("Stream processing completed", "session_id", sessionID) return } } } } -// finishmessageUsage 完成流并处理使用情况 +// MessageDelta finishes stream and handles usage func (h *AgentHandler) MessageDelta(sseHandler *sse.SSEHandler, output schema.Schema) { stopReason := "end_turn" if err := sseHandler.MessageDeltaWithUsage(stopReason, output); err != nil { @@ -162,7 +162,7 @@ func (h *AgentHandler) ListSessions(c *gin.Context) { c.JSON(http.StatusOK, NewSuccessResponse(response)) } -// DeleteSession 删除会话 +// DeleteSession deletes a session func (h *AgentHandler) DeleteSession(c *gin.Context) { sessionID := c.Param("sessionId") if sessionID == "" { @@ -176,10 +176,10 @@ func (h *AgentHandler) DeleteSession(c *gin.Context) { return } - // 删除对应的 history + // Delete corresponding history if agentMemory := h.agent.GetMemory(); agentMemory != nil { agentMemory.Clear(sessionID) - manager.GetLogger().Info("Session history cleared", "session_id", sessionID) + rt.GetLogger().Info("Session history cleared", "session_id", sessionID) } c.JSON(http.StatusOK, NewSuccessResponse(map[string]string{ diff --git a/ai/component/server/engine/models.go b/ai/component/server/engine/models.go new file mode 100644 index 00000000..eecc3cce --- /dev/null +++ b/ai/component/server/engine/models.go @@ -0,0 +1,46 @@ +package engine + +import ( + "time" + + "github.com/google/uuid" +) + +// Response defines unified API response format +type Response struct { + Message string `json:"message"` // Response message + Data any `json:"data,omitempty"` // Response data + RequestID string `json:"request_id"` // Request ID for tracking + Timestamp int64 
`json:"timestamp"` // Response timestamp +} + +// NewResponse creates a response +func NewResponse(message string, data any) *Response { + return &Response{ + Message: message, + Data: data, + RequestID: generateRequestID(), + Timestamp: time.Now().Unix(), + } +} + +// NewSuccessResponse creates a success response +func NewSuccessResponse(data any) *Response { + return NewResponse("success", data) +} + +// NewErrorResponse creates an error response +func NewErrorResponse(message string) *Response { + return NewResponse(message, nil) +} + +// ChatRequest defines streaming chat request +type ChatRequest struct { + Message string `json:"message" binding:"required"` // User message + SessionID string `json:"sessionID" binding:"required"` // Session ID +} + +// generateRequestID generates a request ID +func generateRequestID() string { + return "req_" + uuid.New().String() +} diff --git a/ai/server/router.go b/ai/component/server/engine/router.go similarity index 76% rename from ai/server/router.go rename to ai/component/server/engine/router.go index 3de8c584..8f1899df 100644 --- a/ai/server/router.go +++ b/ai/component/server/engine/router.go @@ -1,8 +1,8 @@ -package server +package engine import ( - "dubbo-admin-ai/agent/react" - "dubbo-admin-ai/server/session" + "dubbo-admin-ai/component/agent/react" + "dubbo-admin-ai/component/server/engine/session" "github.com/gin-gonic/gin" ) @@ -28,29 +28,29 @@ func NewRouter(agent *react.ReActAgent) *Router { } func (r *Router) setupRoutes() { - // 添加CORS中间件 + // Add CORS middleware r.engine.Use(corsMiddleware()) - // API v1 组 + // API v1 group v1 := r.engine.Group("/api/v1/ai") { - // 聊天相关 - v1.POST("/chat/stream", r.handler.StreamChat) // 流式聊天 + // Chat endpoints + v1.POST("/chat/stream", r.handler.StreamChat) // Streaming chat - // 会话管理 - v1.POST("/sessions", r.handler.CreateSession) // 创建会话 - v1.GET("/sessions", r.handler.ListSessions) // 列出会话 - v1.GET("/sessions/:sessionId", r.handler.GetSession) // 获取会话信息 - 
v1.DELETE("/sessions/:sessionId", r.handler.DeleteSession) // 删除会话 + // Session management + v1.POST("/sessions", r.handler.CreateSession) // Create session + v1.GET("/sessions", r.handler.ListSessions) // List sessions + v1.GET("/sessions/:sessionId", r.handler.GetSession) // Get session info + v1.DELETE("/sessions/:sessionId", r.handler.DeleteSession) // Delete session } } -// GetEngine 获取Gin引擎 +// GetEngine returns the Gin engine func (r *Router) GetEngine() *gin.Engine { return r.engine } -// corsMiddleware CORS中间件 +// corsMiddleware provides CORS middleware func corsMiddleware() gin.HandlerFunc { return func(c *gin.Context) { c.Header("Access-Control-Allow-Origin", "*") diff --git a/ai/server/session/session.go b/ai/component/server/engine/session/session.go similarity index 67% rename from ai/server/session/session.go rename to ai/component/server/engine/session/session.go index 4db8fedb..4f58bd95 100644 --- a/ai/server/session/session.go +++ b/ai/component/server/engine/session/session.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "dubbo-admin-ai/manager" + rt "dubbo-admin-ai/runtime" "github.com/google/uuid" ) @@ -15,35 +15,35 @@ var ( ErrSessionExpired = errors.New("session expired") ) -// Session 简化的会话实例,不再管理history +// Session is a simplified session instance, does not manage history type Session struct { - ID string `json:"id"` // 会话ID - CreatedAt time.Time `json:"created_at"` // 创建时间 - UpdatedAt time.Time `json:"updated_at"` // 最后更新时间 - Status string `json:"status"` // 会话状态: "active", "closed" - mu sync.RWMutex `json:"-"` // 读写锁 + ID string `json:"id"` // Session ID + CreatedAt time.Time `json:"created_at"` // Creation time + UpdatedAt time.Time `json:"updated_at"` // Last update time + Status string `json:"status"` // Session status: "active", "closed" + mu sync.RWMutex `json:"-"` // Read-write lock } -// UpdateActivity 更新会话活动时间 +// UpdateActivity updates session activity time func (s *Session) UpdateActivity() { s.mu.Lock() defer s.mu.Unlock() 
s.UpdatedAt = time.Now() } -// Close 关闭会话 +// Close closes the session func (s *Session) Close() { s.mu.Lock() defer s.mu.Unlock() s.Status = "closed" } -// IsExpired 检查会话是否过期(24小时) +// IsExpired checks if session is expired (24 hours) func (s *Session) IsExpired() bool { return time.Since(s.UpdatedAt) > 24*time.Hour } -// ToSessionInfo 转换为API响应格式 +// ToSessionInfo converts to API response format func (s *Session) ToSessionInfo() map[string]any { s.mu.RLock() defer s.mu.RUnlock() @@ -56,19 +56,19 @@ func (s *Session) ToSessionInfo() map[string]any { } } -// Manager 会话管理器 +// Manager manages sessions type Manager struct { sessions map[string]*Session mu sync.RWMutex } -// NewManager 创建会话管理器 +// NewManager creates a session manager func NewManager() *Manager { m := &Manager{ sessions: make(map[string]*Session), } - // 启动定期清理过期会话的goroutine + // Start goroutine to periodically cleanup expired sessions go m.cleanupExpiredSessions() return m @@ -82,11 +82,11 @@ func (m *Manager) CreateMockSession() *Session { Status: "active", } m.sessions[session.ID] = session - manager.GetLogger().Info("Session created", "session_id", session.ID) + rt.GetLogger().Info("Session created", "session_id", session.ID) return session } -// CreateSession 创建新会话 +// CreateSession creates a new session func (m *Manager) CreateSession() *Session { m.mu.Lock() defer m.mu.Unlock() @@ -101,11 +101,11 @@ func (m *Manager) CreateSession() *Session { m.sessions[sessionID] = session - manager.GetLogger().Info("Session created", "session_id", sessionID) + rt.GetLogger().Info("Session created", "session_id", sessionID) return session } -// GetSession 获取会话 +// GetSession retrieves a session func (m *Manager) GetSession(sessionID string) (*Session, error) { m.mu.RLock() defer m.mu.RUnlock() @@ -122,7 +122,7 @@ func (m *Manager) GetSession(sessionID string) (*Session, error) { return session, nil } -// DeleteSession 删除会话 +// DeleteSession deletes a session func (m *Manager) DeleteSession(sessionID string) 
error { m.mu.Lock() defer m.mu.Unlock() @@ -135,11 +135,11 @@ func (m *Manager) DeleteSession(sessionID string) error { session.Close() delete(m.sessions, sessionID) - manager.GetLogger().Info("Session deleted", "session_id", sessionID) + rt.GetLogger().Info("Session deleted", "session_id", sessionID) return nil } -// ListSessions 列出所有活跃会话 +// ListSessions lists all active sessions func (m *Manager) ListSessions() []map[string]any { m.mu.RLock() defer m.mu.RUnlock() @@ -154,9 +154,9 @@ func (m *Manager) ListSessions() []map[string]any { return sessions } -// cleanupExpiredSessions 定期清理过期会话 +// cleanupExpiredSessions periodically cleans up expired sessions func (m *Manager) cleanupExpiredSessions() { - ticker := time.NewTicker(1 * time.Hour) // 每小时清理一次 + ticker := time.NewTicker(1 * time.Hour) // Cleanup every hour defer ticker.Stop() for range ticker.C { @@ -175,14 +175,14 @@ func (m *Manager) cleanupExpiredSessions() { } if len(expiredSessions) > 0 { - manager.GetLogger().Info("Cleaned up expired sessions", "count", len(expiredSessions)) + rt.GetLogger().Info("Cleaned up expired sessions", "count", len(expiredSessions)) } m.mu.Unlock() } } -// generateSessionID 生成会话ID +// generateSessionID generates a session ID func generateSessionID() string { return "session_" + uuid.New().String() } diff --git a/ai/server/sse/sse.go b/ai/component/server/engine/sse/sse.go similarity index 83% rename from ai/server/sse/sse.go rename to ai/component/server/engine/sse/sse.go index 7b9e13cf..ba723d0a 100644 --- a/ai/server/sse/sse.go +++ b/ai/component/server/engine/sse/sse.go @@ -72,22 +72,22 @@ type Delta struct { StopSequence *string `json:"stop_sequence,omitempty"` } -// ErrorInfo 错误信息 +// ErrorInfo defines error information type ErrorInfo struct { Type string `json:"type"` Message string `json:"message"` } -// StreamWriter SSE流写入器 +// StreamWriter handles SSE stream writing type StreamWriter struct { ctx context.Context writer gin.ResponseWriter flusher http.Flusher } -// 
NewStreamWriter 创建SSE流写入器 +// NewStreamWriter creates an SSE stream writer func NewStreamWriter(c *gin.Context) (*StreamWriter, error) { - // 设置SSE响应头 + // Set SSE response headers c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") @@ -106,21 +106,21 @@ func NewStreamWriter(c *gin.Context) (*StreamWriter, error) { }, nil } -// WriteEvent 写入SSE事件的通用方法 +// WriteEvent writes SSE event (generic method) func (sw *StreamWriter) WriteEvent(eventType SSEType, data any) error { - // 检查上下文是否已取消 + // Check if context is cancelled select { case <-sw.ctx.Done(): return sw.ctx.Err() default: } - // 构建SSE数据 + // Build SSE data eventData := map[string]any{ "type": eventType, } - // 根据事件类型添加相应数据 + // Add corresponding data based on event type if data != nil { switch eventType { case MessageStart: @@ -166,29 +166,29 @@ func (sw *StreamWriter) WriteEvent(eventType SSEType, data any) error { } } - // 将数据序列化为JSON + // Serialize data to JSON jsonData, err := json.Marshal(eventData) if err != nil { return fmt.Errorf("failed to marshal event data: %w", err) } - // 写入SSE格式数据 + // Write SSE formatted data _, err = fmt.Fprintf(sw.writer, "event: %s\ndata: %s\n\n", eventType, string(jsonData)) if err != nil { return fmt.Errorf("failed to write SSE data: %w", err) } - // 立即刷新缓冲区 + // Flush buffer immediately sw.flusher.Flush() return nil } -// WriteMessageStart 写入消息开始事件 +// WriteMessageStart writes message start event func (sw *StreamWriter) WriteMessageStart(msg *Message) error { return sw.WriteEvent(MessageStart, msg) } -// WriteContentBlockStart 写入内容块开始事件 +// WriteContentBlockStart writes content block start event func (sw *StreamWriter) WriteContentBlockStart(index int, contentBlock *ContentBlock) error { data := map[string]any{ "index": index, @@ -197,7 +197,7 @@ func (sw *StreamWriter) WriteContentBlockStart(index int, contentBlock *ContentB return sw.WriteEvent(ContentBlockStart, data) } -// WriteContentBlockDelta 
写入内容块增量事件 +// WriteContentBlockDelta writes content block delta event func (sw *StreamWriter) WriteContentBlockDelta(index int, delta *Delta) error { data := map[string]any{ "index": index, @@ -206,12 +206,12 @@ func (sw *StreamWriter) WriteContentBlockDelta(index int, delta *Delta) error { return sw.WriteEvent(ContentBlockDelta, data) } -// WriteContentBlockStop 写入内容块结束事件 +// WriteContentBlockStop writes content block stop event func (sw *StreamWriter) WriteContentBlockStop(index int) error { return sw.WriteEvent(ContentBlockStop, index) } -// WriteMessageDelta 写入消息增量事件 +// WriteMessageDelta writes message delta event func (sw *StreamWriter) WriteMessageDelta(delta *Delta, usage *ai.GenerationUsage) error { dataMap := map[string]any{ "delta": delta, @@ -220,17 +220,17 @@ func (sw *StreamWriter) WriteMessageDelta(delta *Delta, usage *ai.GenerationUsag return sw.WriteEvent(MessageDelta, dataMap) } -// WriteMessageStop 写入消息结束事件 +// WriteMessageStop writes message stop event func (sw *StreamWriter) WriteMessageStop() error { return sw.WriteEvent(MessageStop, nil) } -// WriteError 写入错误事件 +// WriteError writes error event func (sw *StreamWriter) WriteError(errorInfo *ErrorInfo) error { return sw.WriteEvent(ErrorEvent, errorInfo) } -// WritePing 写入ping事件 +// WritePing writes ping event // func (sw *StreamWriter) WritePing() error { // _, err := fmt.Fprintf(sw.writer, "event: ping\ndata: {\"type\": \"ping\"}\n\n") // if err != nil { @@ -247,9 +247,9 @@ type SSEHandler struct { ContentStarted bool } -// NewStreamHandler 创建流式处理器 +// NewStreamHandler creates a stream handler func NewStreamHandler(writer *StreamWriter, sessionID string) *SSEHandler { - // 生成消息ID + // Generate message ID messageID := fmt.Sprintf("msg_%s", uuid.New().String()) return &SSEHandler{ @@ -260,11 +260,11 @@ func NewStreamHandler(writer *StreamWriter, sessionID string) *SSEHandler { } } -// HandleText 处理纯文本消息(如最终答案) +// HandleText handles plain text messages (e.g., final answers) func (sh *SSEHandler) 
HandleText(text string, index int) error { - // 如果还没有开始内容块,先发送消息开始和内容块开始事件 + // Send message start and content block start events if content hasn't started if !sh.ContentStarted { - // 发送消息开始事件 + // Send message start event msg := &Message{ ID: sh.messageID, Type: "message", @@ -275,7 +275,7 @@ func (sh *SSEHandler) HandleText(text string, index int) error { return err } - // 发送内容块开始事件 + // Send content block start event contentBlock := &ContentBlock{ Type: "text", Text: "", @@ -287,7 +287,7 @@ func (sh *SSEHandler) HandleText(text string, index int) error { sh.ContentStarted = true } - // 发送文本内容作为增量 + // Send text content as delta if text != "" { delta := &Delta{ Type: TextDelta, @@ -301,11 +301,11 @@ func (sh *SSEHandler) HandleText(text string, index int) error { return nil } -// HandleStreamChunk 处理流式数据块 +// HandleStreamChunk handles streaming data chunks func (sh *SSEHandler) HandleStreamChunk(chunk schema.StreamChunk) error { - // 处理第一个chunk时发送消息开始和内容块开始事件 + // Send message start and content block start events when processing first chunk if !sh.ContentStarted { - // 发送消息开始事件 + // Send message start event msg := &Message{ ID: sh.messageID, Type: "message", @@ -316,7 +316,7 @@ func (sh *SSEHandler) HandleStreamChunk(chunk schema.StreamChunk) error { return err } - // 发送内容块开始事件 + // Send content block start event contentBlock := &ContentBlock{ Type: "text", Text: "", @@ -328,9 +328,9 @@ func (sh *SSEHandler) HandleStreamChunk(chunk schema.StreamChunk) error { sh.ContentStarted = true } - // 根据Agent的流式输出格式处理 + // Process based on Agent's streaming output format if chunk.Chunk != nil { - // 直接获取增量文本 + // Get delta text directly deltaText := chunk.Chunk.Text() if deltaText != "" { delta := &Delta{ @@ -346,9 +346,9 @@ func (sh *SSEHandler) HandleStreamChunk(chunk schema.StreamChunk) error { return nil } -// HandleContentBlockStop 处理内容块结束事件 +// HandleContentBlockStop handles content block stop event func (sh *SSEHandler) HandleContentBlockStop(index int) error { - // 
只有在已经开始内容块的情况下才发送 content_block_stop + // Only send content_block_stop if content has started if !sh.ContentStarted { return nil } @@ -365,7 +365,7 @@ func (sh *SSEHandler) MessageDeltaWithUsage(stopReason string, output schema.Sch return nil } -// FinishStream 完成流式响应,发送结束事件 +// FinishStream completes streaming response and sends stop event func (sh *SSEHandler) FinishStream() error { if err := sh.writer.WriteMessageStop(); err != nil { return err @@ -374,7 +374,7 @@ func (sh *SSEHandler) FinishStream() error { return nil } -// HandleError 处理错误情况,发送错误事件 +// HandleError handles error conditions and sends error event func (sh *SSEHandler) HandleError(errorType, errorMessage string) { errorInfo := &ErrorInfo{ Type: errorType, diff --git a/ai/component/server/factory.go b/ai/component/server/factory.go new file mode 100644 index 00000000..456a9ba5 --- /dev/null +++ b/ai/component/server/factory.go @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package server + +import ( + "dubbo-admin-ai/runtime" + "fmt" + + "gopkg.in/yaml.v3" +) + +// ServerFactory component factory function (explicit registration, does not use init) +func ServerFactory(spec *yaml.Node) (runtime.Component, error) { + if spec == nil { + return nil, fmt.Errorf("spec is nil") + } + + var cfg ServerSpec + if err := spec.Decode(&cfg); err != nil { + return nil, fmt.Errorf("failed to decode server spec: %w", err) + } + + return NewServerComponent( + cfg.Port, + cfg.Host, + cfg.Debug, + cfg.CORSOrigins, + cfg.ReadTimeout, + cfg.WriteTimeout, + ) +} diff --git a/ai/component/server/server.yaml b/ai/component/server/server.yaml new file mode 100644 index 00000000..3cb531d6 --- /dev/null +++ b/ai/component/server/server.yaml @@ -0,0 +1,8 @@ +type: server +spec: + port: 8880 # Server port + host: "localhost" # Server host + debug: false # Debug mode + cors_origins: ["*"] # CORS origins + read_timeout: 30 # Read timeout in seconds + write_timeout: 30 # Write timeout in seconds diff --git a/ai/component/server/test/server_test.go b/ai/component/server/test/server_test.go new file mode 100644 index 00000000..30f611e2 --- /dev/null +++ b/ai/component/server/test/server_test.go @@ -0,0 +1,34 @@ +package servertest + +import ( + "strings" + "testing" + + compServer "dubbo-admin-ai/component/server" +) + +func TestServerComponent_Validate(t *testing.T) { + tests := []struct { + name string + port int + readTimeout int + writeTimeout int + errContain string + }{ + {name: "port_range", port: 70000, readTimeout: 30, writeTimeout: 30, errContain: "port"}, + {name: "read_timeout_positive", port: 8080, readTimeout: 0, writeTimeout: 30, errContain: "timeout"}, + {name: "write_timeout_positive", port: 8080, readTimeout: 30, writeTimeout: 0, errContain: "timeout"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + comp, err := compServer.NewServerComponent(tt.port, "0.0.0.0", false, []string{"*"}, tt.readTimeout, tt.writeTimeout) 
+ if err != nil { + t.Fatalf("NewServerComponent() error: %v", err) + } + if err := comp.Validate(); err == nil || !strings.Contains(err.Error(), tt.errContain) { + t.Fatalf("expected %q validation error, got %v", tt.errContain, err) + } + }) + } +} diff --git a/ai/component/tools/component.go b/ai/component/tools/component.go new file mode 100644 index 00000000..4d56f10d --- /dev/null +++ b/ai/component/tools/component.go @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package tools + +import ( + "dubbo-admin-ai/component/memory" + "dubbo-admin-ai/component/tools/engine" + "dubbo-admin-ai/runtime" + "fmt" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" +) + +type ToolsComponent struct { + registry *genkit.Genkit + history *memory.HistoryMemory + toolManagers []engine.ToolManager + toolRefs []ai.ToolRef + config ToolConfig +} + +func NewToolsComponent(config ToolConfig) (runtime.Component, error) { + return &ToolsComponent{ + config: config, + }, nil +} + +func (t *ToolsComponent) Name() string { + return "tools" +} + +func (t *ToolsComponent) Validate() error { + if t.config.EnableMCPTools && t.config.MCPHostName == "" { + return fmt.Errorf("mcp_host_name is required when mcp tools are enabled") + } + if t.config.MCPTimeout <= 0 { + return fmt.Errorf("mcp_timeout must be greater than 0") + } + if t.config.MCPMaxRetries < 0 { + return fmt.Errorf("mcp_max_retries must be >= 0") + } + return nil +} + +func (t *ToolsComponent) Init(rt *runtime.Runtime) error { + t.registry = rt.GetGenkitRegistry() + + memoryComp, err := rt.GetComponent("memory") + if err != nil { + return fmt.Errorf("memory component not found: %w", err) + } + + memComp, ok := memoryComp.(*memory.MemoryComponent) + if !ok { + return fmt.Errorf("invalid memory component type") + } + + t.history, err = memComp.GetMemory() + if err != nil { + return fmt.Errorf("failed to get history from memory component: %w", err) + } + + // Initialize Tool Managers + if t.config.EnableMockTools { + t.toolManagers = append(t.toolManagers, + engine.NewMockToolManager(t.registry)) + } + + if t.config.EnableInternalTools { + t.toolManagers = append(t.toolManagers, + engine.NewInternalToolManager(t.registry, t.history)) + } + + if t.config.EnableMCPTools { + mcpToolManager, err := engine.NewMCPToolManager(t.registry, t.config.MCPHostName) + if err != nil { + return fmt.Errorf("failed to create MCP tool manager: %w", err) + } + t.toolManagers = 
append(t.toolManagers, mcpToolManager) + } + + t.toolRefs = engine.NewToolRegistry(t.toolManagers...).AllToolRefs() + + rt.GetLogger().Info("Tools component initialized", + "total_tools", len(t.toolRefs), + "total_managers", len(t.toolManagers), + "mcp_host", t.config.MCPHostName) + + return nil +} + +func (t *ToolsComponent) Start() error { + return nil +} + +func (t *ToolsComponent) Stop() error { + return nil +} + +func (t *ToolsComponent) GetToolRefs() []ai.ToolRef { + return t.toolRefs +} + +func (t *ToolsComponent) SetConfig(config ToolConfig) { + t.config = config +} diff --git a/ai/component/tools/config.go b/ai/component/tools/config.go new file mode 100644 index 00000000..7fe974a0 --- /dev/null +++ b/ai/component/tools/config.go @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package tools + +// ToolsConfig defines the tools configuration +type ToolsConfig struct { + Memory MemoryConfig `yaml:"memory"` + Mock MockConfig `yaml:"mock"` + Internal InternalConfig `yaml:"internal"` + MCP MCPConfig `yaml:"mcp"` +} + +// ToolConfig defines the tool configuration +type ToolConfig struct { + EnableMockTools bool `yaml:"enable_mock_tools"` + EnableInternalTools bool `yaml:"enable_internal_tools"` + EnableMCPTools bool `yaml:"enable_mcp_tools"` + MCPHostName string `yaml:"mcp_host_name"` + MCPTimeout int `yaml:"mcp_timeout"` + MCPMaxRetries int `yaml:"mcp_max_retries"` +} + +// MemoryConfig defines the memory configuration +type MemoryConfig struct { + DefaultIndex string `yaml:"default_index"` + TopK int `yaml:"top_k"` + RerankEnabled bool `yaml:"rerank_enabled"` + RerankModel string `yaml:"rerank_model"` + RerankTopN int `yaml:"rerank_top_n"` +} + +// MockConfig defines the mock configuration +type MockConfig struct { + Enabled bool `yaml:"enabled"` +} + +// InternalConfig defines the internal configuration +type InternalConfig struct { + Enabled bool `yaml:"enabled"` +} + +// MCPConfig defines the MCP configuration +type MCPConfig struct { + Enabled bool `yaml:"enabled"` + Host string `yaml:"host"` + Port int `yaml:"port"` +} + +// DefaultToolsConfig returns default tools configuration +func DefaultToolsConfig() ToolsConfig { + return ToolsConfig{ + Memory: DefaultMemoryConfig(), + Mock: MockConfig{Enabled: true}, + Internal: InternalConfig{Enabled: true}, + MCP: MCPConfig{Enabled: true, Host: "mcp_host", Port: 8080}, + } +} + +// DefaultToolConfig returns default tool configuration +func DefaultToolConfig() *ToolConfig { + return &ToolConfig{ + EnableMockTools: true, + EnableInternalTools: true, + EnableMCPTools: true, + MCPHostName: "mcp_host", + MCPTimeout: 30, + MCPMaxRetries: 3, + } +} + +// DefaultMemoryConfig returns default memory configuration +func DefaultMemoryConfig() MemoryConfig { + return MemoryConfig{ + DefaultIndex: 
"kube-docs", + TopK: 10, + RerankEnabled: true, + RerankModel: "rerank-v3.5", + RerankTopN: 2, + } +} diff --git a/ai/tools/mcp.go b/ai/component/tools/engine/mcp.go similarity index 99% rename from ai/tools/mcp.go rename to ai/component/tools/engine/mcp.go index 7b4cbb06..bb93e135 100644 --- a/ai/tools/mcp.go +++ b/ai/component/tools/engine/mcp.go @@ -1,4 +1,4 @@ -package tools +package engine import ( "context" diff --git a/ai/component/tools/engine/memory.go b/ai/component/tools/engine/memory.go new file mode 100644 index 00000000..f8b2ba78 --- /dev/null +++ b/ai/component/tools/engine/memory.go @@ -0,0 +1,136 @@ +package engine + +import ( + "context" + "dubbo-admin-ai/component/memory" + compRag "dubbo-admin-ai/component/rag" + "dubbo-admin-ai/config" + "fmt" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "gopkg.in/yaml.v3" +) + +const ( + GetAllMemoryTool string = "memory_all_by_session_id" + RetrieveBasicConceptFromK8SDocTool string = "retrieve_basic_concept_from_k8s_doc" +) + +type MemoryToolInput struct { + SessionID string `json:"session_id"` +} + +func defineMemoryTools(g *genkit.Genkit, history *memory.HistoryMemory) []ai.Tool { + tools := []ai.Tool{ + getAllMemoryBySession(g, history), + RetrieveBasicConceptFromK8SDoc(g), + } + return tools +} + +func getAllMemoryBySession(g *genkit.Genkit, history *memory.HistoryMemory) ai.Tool { + return genkit.DefineTool( + g, GetAllMemoryTool, "Get all history memory messages of a session by input `session_id`", + func(ctx *ai.ToolContext, input MemoryToolInput) (ToolOutput, error) { + if input.SessionID == "" { + return ToolOutput{}, fmt.Errorf("sessionID is required") + } + + if history.IsEmpty(input.SessionID) { + return ToolOutput{ + ToolName: GetAllMemoryTool, + Summary: "No memory available", + }, nil + } + + return ToolOutput{ + ToolName: GetAllMemoryTool, + Result: history.AllMemory(input.SessionID), + Summary: "", + }, nil + }, + ) +} + +type K8SRAGQueryInput struct { + 
Querys []string `json:"query"` +} + +const ( + K8S_CONCEPTS_NAMESPACE string = "concepts" +) + +func RetrieveBasicConceptFromK8SDoc(g *genkit.Genkit) ai.Tool { + return genkit.DefineTool( + g, RetrieveBasicConceptFromK8SDocTool, "Retrieve the basic kubernetes concepts from RAG", + func(ctx *ai.ToolContext, input K8SRAGQueryInput) (ToolOutput, error) { + if input.Querys == nil { + return ToolOutput{}, fmt.Errorf("query is required") + } + + // TODO(memory-tool, 2026-02-24): Get configuration from Runtime instead of hardcoded values + // Current: backend="dev", indexName="k8s", topK=10 + // Should be: Read from runtime config + backend := "dev" + indexName := "k8s" + topK := 10 + rerankEnabled := false + rerankTopN := 2 + rerankModel := "rerank-v3.5" + embeddingModel := "qwen3-embedding" + + // Build configuration using standard Config pattern + cfg := &compRag.RAGSpec{ + Embedder: &config.Config{ + Type: "genkit", + Spec: encodeToYAMLNode(&compRag.EmbedderSpec{Model: embeddingModel}), + }, + Indexer: &config.Config{ + Type: backend, + Spec: encodeToYAMLNode(&compRag.IndexerSpec{}), + }, + Retriever: &config.Config{ + Type: backend, + Spec: encodeToYAMLNode(&compRag.RetrieverSpec{}), + }, + } + if rerankEnabled { + cfg.Reranker = &config.Config{ + Type: "cohere", + Spec: encodeToYAMLNode(&compRag.RerankerSpec{ + Enabled: true, + Model: rerankModel, + }), + } + } + + sys, err := compRag.BuildRAGFromSpec(context.Background(), g, cfg) + if err != nil { + return ToolOutput{}, fmt.Errorf("failed to build RAG system: %w", err) + } + + retrieveOpts := []compRag.RetrieveOption{compRag.WithTopK(topK), compRag.WithTargetIndex(indexName)} + if rerankEnabled { + retrieveOpts = append(retrieveOpts, compRag.WithTopN(rerankTopN)) + } + results, err := sys.Retrieve(context.Background(), K8S_CONCEPTS_NAMESPACE, input.Querys, retrieveOpts...) 
+ if err != nil { + return ToolOutput{}, fmt.Errorf("failed to retrieve from RAG: %w", err) + } + + return ToolOutput{ + ToolName: RetrieveBasicConceptFromK8SDocTool, + Result: results, + Summary: fmt.Sprintf("Retrieved %d results", len(results)), + }, nil + }, + ) +} + +// encodeToYAMLNode converts a struct to yaml.Node +func encodeToYAMLNode(v any) yaml.Node { + var node yaml.Node + node.Encode(v) + return node +} diff --git a/ai/tools/mock_tools.go b/ai/component/tools/engine/mock_tools.go similarity index 85% rename from ai/tools/mock_tools.go rename to ai/component/tools/engine/mock_tools.go index f6cc47df..38b6e8f1 100644 --- a/ai/tools/mock_tools.go +++ b/ai/component/tools/engine/mock_tools.go @@ -1,4 +1,4 @@ -package tools +package engine import ( "fmt" @@ -44,7 +44,7 @@ func prometheusQueryServiceLatency(ctx *ai.ToolContext, input PrometheusServiceL return ToolOutput{ ToolName: "prometheus_query_service_latency", - Summary: fmt.Sprintf("服务 %s 在过去%d分钟内的 P%.0f 延迟为 %dms", input.ServiceName, input.TimeRangeMinutes, input.Quantile*100, output.ValueMillis), + Summary: fmt.Sprintf("Service %s had a P%.0f latency of %dms over the past %d minutes", input.ServiceName, input.Quantile*100, output.ValueMillis, input.TimeRangeMinutes), Result: output, }, nil } @@ -80,7 +80,7 @@ func prometheusQueryServiceTraffic(ctx *ai.ToolContext, input PrometheusServiceT return ToolOutput{ ToolName: "prometheus_query_service_traffic", - Summary: fmt.Sprintf("服务 %s 的 QPS 为 %.1f, 错误率为 %.1f%%", input.ServiceName, output.RequestRateQPS, output.ErrorRatePercentage), + Summary: fmt.Sprintf("Service %s has a QPS of %.1f and error rate of %.1f%%", input.ServiceName, output.RequestRateQPS, output.ErrorRatePercentage), Result: output, }, nil } @@ -138,7 +138,7 @@ func queryTimeseriesDatabase(ctx *ai.ToolContext, input QueryTimeseriesDatabaseI return ToolOutput{ ToolName: "query_timeseries_database", - Summary: "查询返回了 2 个时间序列", + Summary: "Query returned 2 time series", Result: output, }, nil 
} @@ -191,7 +191,7 @@ func applicationPerformanceProfiling(ctx *ai.ToolContext, input ApplicationPerfo return ToolOutput{ ToolName: "application_performance_profiling", - Summary: "性能分析显示,45.5%的CPU时间消耗在数据库查询调用链上", + Summary: "Performance analysis shows that 45.5% of CPU time is consumed by database query call chain", Result: output, }, nil } @@ -264,7 +264,7 @@ func traceDependencyView(ctx *ai.ToolContext, input TraceDependencyViewInput) (T return ToolOutput{ ToolName: "trace_dependency_view", - Summary: fmt.Sprintf("服务 %s 的上下游依赖关系查询完成", input.ServiceName), + Summary: fmt.Sprintf("Upstream and downstream dependency query for service %s completed", input.ServiceName), Result: output, }, nil } @@ -317,7 +317,7 @@ func traceLatencyAnalysis(ctx *ai.ToolContext, input TraceLatencyAnalysisInput) return ToolOutput{ ToolName: "trace_latency_analysis", - Summary: "平均总延迟为 3200ms。瓶颈已定位,95.3% 的延迟来自对下游 'mysql-orders-db' 的调用", + Summary: "Average total latency is 3200ms. Bottleneck identified: 95.3% of latency comes from calls to downstream 'mysql-orders-db'", Result: output, }, nil } @@ -356,7 +356,7 @@ func databaseConnectionPoolAnalysis(ctx *ai.ToolContext, input DatabaseConnectio return ToolOutput{ ToolName: "database_connection_pool_analysis", - Summary: "数据库连接池已完全耗尽 (100/100),当前有 58 个请求正在排队等待连接", + Summary: "Database connection pool is fully exhausted (100/100), currently 58 requests are queued waiting for connections", Result: output, }, nil } @@ -472,7 +472,7 @@ func dubboServiceStatus(ctx *ai.ToolContext, input DubboServiceStatusInput) (Too return ToolOutput{ ToolName: "dubbo_service_status", - Summary: fmt.Sprintf("服务 %s 的提供者和消费者状态查询完成", input.ServiceName), + Summary: fmt.Sprintf("Provider and consumer status query for service %s completed", input.ServiceName), Result: output, }, nil } @@ -526,7 +526,7 @@ func queryLogDatabase(ctx *ai.ToolContext, input QueryLogDatabaseInput) (ToolOut return ToolOutput{ ToolName: "query_log_database", - Summary: 
fmt.Sprintf("在过去%d分钟内,发现 152 条关于 '%s' 的日志条目", input.TimeRangeMinutes, input.Keyword), + Summary: fmt.Sprintf("Found 152 log entries about '%s' in the past %d minutes", input.Keyword, input.TimeRangeMinutes), Result: output, }, nil } @@ -579,7 +579,7 @@ func searchArchivedLogs(ctx *ai.ToolContext, input SearchArchivedLogsInput) (Too return ToolOutput{ ToolName: "search_archived_logs", - Summary: "在归档的日志文件中,发现了多条查询,搜索了 5 个文件,找到了 2 条匹配行", + Summary: "Found multiple queries in archived log files, searched 5 files and found 2 matching lines", Result: output, }, nil } @@ -624,7 +624,7 @@ func queryKnowledgeBase(ctx *ai.ToolContext, input QueryKnowledgeBaseInput) (Too return ToolOutput{ ToolName: "query_knowledge_base", - Summary: fmt.Sprintf("知识库查询 '%s' 完成", input.QueryText), + Summary: fmt.Sprintf("Knowledge base query '%s' completed", input.QueryText), Result: output, }, nil } @@ -640,19 +640,19 @@ type MockToolManager struct { func NewMockToolManager(g *genkit.Genkit) *MockToolManager { tools := []ai.Tool{ - genkit.DefineTool(g, "prometheus_query_service_latency", "查询指定服务在特定时间范围内的 P95/P99 延迟指标", prometheusQueryServiceLatency), - genkit.DefineTool(g, "prometheus_query_service_traffic", "查询指定服务在特定时间范围内的请求率 (QPS) 和错误率 (Error Rate)", prometheusQueryServiceTraffic), - genkit.DefineTool(g, "query_timeseries_database", "执行一条完整的 PromQL 查询语句,用于进行普罗米修斯历史数据的深度或自定义分析", queryTimeseriesDatabase), - genkit.DefineTool(g, "application_performance_profiling", "对指定服务的指定实例(Pod)进行性能剖析,以结构化文本格式返回消耗CPU最多的函数调用栈", applicationPerformanceProfiling), - genkit.DefineTool(g, "jvm_performance_analysis", "检查指定Java服务的JVM状态,特别是GC(垃圾回收)活动", jvmPerformanceAnalysis), - genkit.DefineTool(g, "trace_dependency_view", "基于链路追踪数据,查询指定服务的上下游依赖关系", traceDependencyView), - genkit.DefineTool(g, "trace_latency_analysis", "分析指定服务在某时间范围内的链路追踪数据,定位延迟最高的下游调用", traceLatencyAnalysis), - genkit.DefineTool(g, "database_connection_pool_analysis", "查询指定服务连接数据库的连接池状态", databaseConnectionPoolAnalysis), - genkit.DefineTool(g, 
"kubernetes_get_pod_resources", "使用类似 kubectl 的功能,获取指定服务所有Pod的CPU和内存的静态配置(Limits/Requests)和动态使用情况", kubernetesGetPodResources), - genkit.DefineTool(g, "dubbo_service_status", "使用类似 dubbo-admin 的命令,查询指定Dubbo服务的提供者和消费者列表及其状态", dubboServiceStatus), - genkit.DefineTool(g, "query_log_database", "查询已索引的日志数据库(如Elasticsearch, Loki),用于实时或近实时的日志分析", queryLogDatabase), - genkit.DefineTool(g, "search_archived_logs", "在归档的日志文件(如存储在S3或服务器文件系统的.log.gz文件)中进行文本搜索(类似grep)", searchArchivedLogs), - genkit.DefineTool(g, "query_knowledge_base", "在向量数据库中查询与问题相关的历史故障报告或解决方案文档", queryKnowledgeBase), + genkit.DefineTool(g, "prometheus_query_service_latency", "Query P95/P99 latency metrics for a specific service within a time range", prometheusQueryServiceLatency), + genkit.DefineTool(g, "prometheus_query_service_traffic", "Query request rate (QPS) and error rate for a specific service within a time range", prometheusQueryServiceTraffic), + genkit.DefineTool(g, "query_timeseries_database", "Execute a complete PromQL query for deep or custom analysis of Prometheus historical data", queryTimeseriesDatabase), + genkit.DefineTool(g, "application_performance_profiling", "Perform performance profiling on a specific service instance (Pod) and return the function call stack consuming the most CPU in structured text format", applicationPerformanceProfiling), + genkit.DefineTool(g, "jvm_performance_analysis", "Check the JVM status of a specific Java service, especially GC (garbage collection) activity", jvmPerformanceAnalysis), + genkit.DefineTool(g, "trace_dependency_view", "Query upstream and downstream dependencies of a specific service based on tracing data", traceDependencyView), + genkit.DefineTool(g, "trace_latency_analysis", "Analyze tracing data for a specific service within a time range to identify downstream calls with highest latency", traceLatencyAnalysis), + genkit.DefineTool(g, "database_connection_pool_analysis", "Query the connection pool status of a specific service's database 
connection", databaseConnectionPoolAnalysis), + genkit.DefineTool(g, "kubernetes_get_pod_resources", "Use kubectl-like functionality to get CPU and memory static configuration (limits/requests) and dynamic usage for all pods of a specific service", kubernetesGetPodResources), + genkit.DefineTool(g, "dubbo_service_status", "Use dubbo-admin-like commands to query the provider and consumer lists and their status for a specific Dubbo service", dubboServiceStatus), + genkit.DefineTool(g, "query_log_database", "Query indexed log databases (such as Elasticsearch, Loki) for real-time or near real-time log analysis", queryLogDatabase), + genkit.DefineTool(g, "search_archived_logs", "Perform text search (similar to grep) in archived log files (such as .log.gz files stored in S3 or server file system)", searchArchivedLogs), + genkit.DefineTool(g, "query_knowledge_base", "Query vector databases for historical failure reports or solution documents related to the question", queryKnowledgeBase), } return &MockToolManager{ diff --git a/ai/tools/tools.go b/ai/component/tools/engine/tools.go similarity index 88% rename from ai/tools/tools.go rename to ai/component/tools/engine/tools.go index 61ba8ab0..c12fe206 100644 --- a/ai/tools/tools.go +++ b/ai/component/tools/engine/tools.go @@ -1,9 +1,9 @@ -package tools +package engine import ( "context" - "dubbo-admin-ai/manager" - "dubbo-admin-ai/memory" + "dubbo-admin-ai/component/memory" + rt "dubbo-admin-ai/runtime" "fmt" "github.com/firebase/genkit/go/ai" @@ -41,7 +41,7 @@ func Call(g *genkit.Genkit, mcp *MCPToolManager, toolName string, input any) (to if rawToolOutput == nil { return toolOutput, fmt.Errorf("tool %s is unavailable", toolName) } - manager.GetLogger().Info("Tool output:", "output", rawToolOutput) + rt.GetLogger().Info("Tool output:", "output", rawToolOutput) if isMCPTool { toolOutput = ToolOutput{ @@ -89,9 +89,9 @@ type InternalToolManager struct { tools []ai.Tool } -func NewInternalToolManager(g *genkit.Genkit, history 
*memory.History) *InternalToolManager { +func NewInternalToolManager(g *genkit.Genkit, historyMem *memory.HistoryMemory) *InternalToolManager { var tools []ai.Tool - tools = append(tools, defineMemoryTools(g, history)...) + tools = append(tools, defineMemoryTools(g, historyMem)...) return &InternalToolManager{ registry: g, tools: tools, diff --git a/ai/component/tools/factory.go b/ai/component/tools/factory.go new file mode 100644 index 00000000..a48aa9ed --- /dev/null +++ b/ai/component/tools/factory.go @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package tools + +import ( + "dubbo-admin-ai/runtime" + "fmt" + + "gopkg.in/yaml.v3" +) + +// ToolsFactory creates a tools component (explicit registration, no init) +func ToolsFactory(spec *yaml.Node) (runtime.Component, error) { + var cfg ToolConfig + if err := spec.Decode(&cfg); err != nil { + return nil, fmt.Errorf("failed to decode tools spec: %w", err) + } + + return NewToolsComponent(cfg) +} diff --git a/ai/component/tools/test/tools_test.go b/ai/component/tools/test/tools_test.go new file mode 100644 index 00000000..265006b9 --- /dev/null +++ b/ai/component/tools/test/tools_test.go @@ -0,0 +1,23 @@ +package toolstest + +import ( + "strings" + "testing" + + compTools "dubbo-admin-ai/component/tools" +) + +func TestToolsComponent_Validate(t *testing.T) { + comp, err := compTools.NewToolsComponent(compTools.ToolConfig{ + EnableMCPTools: true, + MCPHostName: "", + MCPTimeout: 30, + MCPMaxRetries: 1, + }) + if err != nil { + t.Fatalf("NewToolsComponent() error: %v", err) + } + if err := comp.Validate(); err == nil || !strings.Contains(err.Error(), "mcp_host_name") { + t.Fatalf("expected mcp_host_name validation error, got %v", err) + } +} diff --git a/ai/component/tools/tools.yaml b/ai/component/tools/tools.yaml new file mode 100644 index 00000000..7758a553 --- /dev/null +++ b/ai/component/tools/tools.yaml @@ -0,0 +1,8 @@ +type: tools +spec: + enable_mock_tools: true + enable_internal_tools: true + enable_mcp_tools: false # 测试环境禁用MCP工具,避免依赖外部服务 + mcp_host_name: "mcp_host" + mcp_timeout: 30 + mcp_max_retries: 3 diff --git a/ai/config.yaml b/ai/config.yaml new file mode 100644 index 00000000..9f2cf244 --- /dev/null +++ b/ai/config.yaml @@ -0,0 +1,10 @@ +project: dubbo-admin-ai +version: 1.0.0 +components: + logger: component/logger/logger.yaml + models: component/models/models.yaml + server: component/server/server.yaml + memory: component/memory/memory.yaml + tools: component/tools/tools.yaml + rag: component/rag/rag.yaml + agent: 
component/agent/agent.yaml diff --git a/ai/config/config.go b/ai/config/config.go index ecce0788..06bb836e 100644 --- a/ai/config/config.go +++ b/ai/config/config.go @@ -1,40 +1,19 @@ package config import ( - "path/filepath" - "runtime" - - "dubbo-admin-ai/plugins/dashscope" - "dubbo-admin-ai/plugins/model" + "gopkg.in/yaml.v3" ) -var ( - // API keys - GEMINI_API_KEY string - SILICONFLOW_API_KEY string - DASHSCOPE_API_KEY string - PINECONE_API_KEY string - COHERE_API_KEY string - PROMETHEUS_URL string - - // Configuration - // Automatically get project root directory - _, b, _, _ = runtime.Caller(0) - PROJECT_ROOT = filepath.Join(filepath.Dir(b), "..") +// Config component configuration structure +type Config struct { + Type string `yaml:"type"` + Spec yaml.Node `yaml:"spec"` +} - PROMPT_DIR_PATH string = filepath.Join(PROJECT_ROOT, "prompts") - DEFAULT_MODEL *model.Model = dashscope.Qwen_max - EMBEDDING_MODEL *model.Embedder = dashscope.Qwen3_embedding -) +func (c *Config) GetType() string { + return c.Type +} -const ( - MAX_REACT_ITERATIONS int = 10 - STAGE_CHANNEL_BUFFER_SIZE int = 5 - PINECONE_INDEX_NAME string = "dubbot" - MCP_HOST_NAME string = "mcp_host" - K8S_RAG_INDEX string = "kube-docs" - RAG_TOP_K int = 10 - RERANK_TOP_N int = 2 - RERANK_MODEL string = "rerank-v3.5" - RERANK_ENABLE bool = true -) +func (c *Config) GetSpec() *yaml.Node { + return &c.Spec +} diff --git a/ai/config/jsonschema.go b/ai/config/jsonschema.go new file mode 100644 index 00000000..f440965d --- /dev/null +++ b/ai/config/jsonschema.go @@ -0,0 +1,175 @@ +package config + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/xeipuuv/gojsonschema" +) + +// schemaEngine manages the loading, caching, and compilation of JSON schemas. 
+type schemaEngine struct { + baseDir string + mu sync.RWMutex + cache map[string]*gojsonschema.Schema + rootObjs map[string]map[string]any +} + +// NewSchemaEngine creates a new schemaEngine with the specified base directory. +func NewSchemaEngine(baseDir string) *schemaEngine { + return &schemaEngine{ + baseDir: baseDir, + cache: make(map[string]*gojsonschema.Schema), + rootObjs: make(map[string]map[string]any), + } +} + +// ApplyDefaultsAndValidate applies default values and validates. +// This method modifies doc in-place by applying defaults, then validates it. +// +// Parameters: +// - doc: The configuration document (WILL BE MODIFIED) +// - schemaFile: The schema filename +// +// Returns: +// - The modified doc (same map as input) +// - Error if validation fails +func (e *schemaEngine) ApplyDefaultsAndValidate(doc map[string]any, schemaFile string) (map[string]any, error) { + compiled, rootObj, err := e.loadSchema(schemaFile) + if err != nil { + return nil, err + } + + // Apply defaults in-place (modifies doc) + applyDefaults(rootObj, rootObj, doc) + + // Validate the result + if err := validateJSONSchema(compiled, doc); err != nil { + return nil, err + } + + return doc, nil +} + +func (e *schemaEngine) loadSchema(fileName string) (*gojsonschema.Schema, map[string]any, error) { + e.mu.RLock() + compiled, hasCached := e.cache[fileName] + rootObj, hasRoot := e.rootObjs[fileName] + e.mu.RUnlock() + + if hasCached && hasRoot { + return compiled, rootObj, nil + } + + fullPath := filepath.Join(e.baseDir, fileName) + raw, err := os.ReadFile(fullPath) + if err != nil { + return nil, nil, fmt.Errorf("structural error: failed to read schema file %s: %w", fileName, err) + } + + var schemaObj map[string]any + if err := json.Unmarshal(raw, &schemaObj); err != nil { + return nil, nil, fmt.Errorf("structural error: failed to parse schema file %s: %w", fileName, err) + } + + compiled, err = gojsonschema.NewSchema(gojsonschema.NewBytesLoader(raw)) + if err != nil { + return 
nil, nil, fmt.Errorf("structural error: failed to compile schema file %s: %w", fileName, err) + } + + e.mu.Lock() + e.cache[fileName] = compiled + e.rootObjs[fileName] = schemaObj + e.mu.Unlock() + return compiled, schemaObj, nil +} + +func validateJSONSchema(compiled *gojsonschema.Schema, doc any) error { + docRaw, err := json.Marshal(doc) + if err != nil { + return fmt.Errorf("failed to marshal config for schema validation: %w", err) + } + + result, err := compiled.Validate(gojsonschema.NewBytesLoader(docRaw)) + if err != nil { + return fmt.Errorf("failed to validate schema: %w", err) + } + if result.Valid() { + return nil + } + + errMsgs := make([]string, 0, len(result.Errors())) + for _, e := range result.Errors() { + field := e.Field() + if field == "(root)" || field == "" { + field = "root" + } + errMsgs = append(errMsgs, fmt.Sprintf("%s: %s", field, e.Description())) + } + return fmt.Errorf("structural error: %s", strings.Join(errMsgs, "; ")) +} + +// applyDefaults recursively applies default values from schema to value +// Modifies value in-place +func applyDefaults(root map[string]any, schema map[string]any, value any) { + resolved := resolveSchemaRef(root, schema) + + switch v := value.(type) { + case map[string]any: + props, _ := resolved["properties"].(map[string]any) + for key, propVal := range props { + propSchema, ok := propVal.(map[string]any) + if !ok { + continue + } + propSchema = resolveSchemaRef(root, propSchema) + + // Apply default value if property is missing + if _, exists := v[key]; !exists { + if defVal, hasDefault := propSchema["default"]; hasDefault { + v[key] = defVal + } + } + + // Recursively apply defaults to nested properties + if child, exists := v[key]; exists { + applyDefaults(root, propSchema, child) + } + } + + case []any: + if items, ok := resolved["items"].(map[string]any); ok { + items = resolveSchemaRef(root, items) + for i := range v { + applyDefaults(root, items, v[i]) + } + } + } +} + +// resolveSchemaRef resolves JSON 
Pointer references ($ref) within a schema +func resolveSchemaRef(root map[string]any, schema map[string]any) map[string]any { + if ref, ok := schema["$ref"].(string); ok && strings.HasPrefix(ref, "#/") { + parts := strings.Split(strings.TrimPrefix(ref, "#/"), "/") + var cur any = root + for _, p := range parts { + obj, ok := cur.(map[string]any) + if !ok { + return schema + } + next, ok := obj[p] + if !ok { + return schema + } + cur = next + } + if resolved, ok := cur.(map[string]any); ok { + return resolveSchemaRef(root, resolved) + } + } + return schema +} diff --git a/ai/config/loader.go b/ai/config/loader.go new file mode 100644 index 00000000..fbc2228b --- /dev/null +++ b/ai/config/loader.go @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package config + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + + "github.com/joho/godotenv" + "gopkg.in/yaml.v3" +) + +// mainConfig defines the structure of the main configuration file (internal use only) +type mainConfig struct { + Project string `yaml:"project"` // Project name + Version string `yaml:"version"` // Project version + Components map[string]any `yaml:"components"` // Component configuration path mapping +} + +// LoadedConfig contains all loaded configurations +type LoadedConfig struct { + Project string + Version string + SchemaDir string // Schema directory path (from environment) + Components map[string]*Config // Component configurations (key: component name) +} + +// Loader handles configuration file loading +type Loader struct { + configFile string + configDir string // Directory containing configuration files + schemaDir string + schemaEngine *schemaEngine // Schema validation engine +} + +// NewLoader creates a new configuration loader +func NewLoader(configFile string) *Loader { + return &Loader{ + configFile: configFile, + configDir: filepath.Dir(configFile), + } +} + +// Load loads and parses all configurations (main entry point) +// Features: +// 1. Load .env file if exists +// 2. Initialize schema engine from SCHEMA_DIR (or default schema/json) +// 3. Read and validate main configuration file +// 4. Load all component configurations (ordering handled by caller) +func (l *Loader) Load() (*LoadedConfig, error) { + // 1. Initialize schema engine + if err := l.ensureSchemaEngine(); err != nil { + return nil, err + } + + // 2. Read and validate main configuration file + mainCfg, err := l.loadMainConfig() + if err != nil { + return nil, fmt.Errorf("failed to load main config: %w", err) + } + + // 3. 
Load all component configurations + components, err := l.loadAllComponents(mainCfg) + if err != nil { + return nil, fmt.Errorf("failed to load components: %w", err) + } + + return &LoadedConfig{ + Project: mainCfg.Project, + Version: mainCfg.Version, + SchemaDir: l.schemaDir, + Components: components, + }, nil +} + +// LoadComponent loads and validates a single component config file. +// It reuses the same structural pipeline as Load(): +// yaml.Unmarshal -> schema defaults+validation -> decodeYAMLStrict(KnownFields=true). +func (l *Loader) LoadComponent(configPath string) (*Config, error) { + if err := l.ensureSchemaEngine(); err != nil { + return nil, err + } + return l.loadComponent(configPath) +} + +// loadEnvFile loads .env file if exists +func (l *Loader) loadEnvFile() error { + // Load .env file if exists (missing .env is not an error) + if err := godotenv.Load(); err != nil { + // File not found is acceptable, only log for other errors + // os.IsNotExist(err) always returns false for godotenv + // So we check the error message instead + if err.Error() != "open .env: no such file or directory" && + err.Error() != "open .env: file does not exist" { + return err + } + // .env file not found is normal, continue execution + } + return nil +} + +func (l *Loader) ensureSchemaEngine() error { + if l.schemaEngine != nil { + return nil + } + + if err := l.loadEnvFile(); err != nil { + return fmt.Errorf("failed to load .env file: %w", err) + } + + schemaDir, err := l.resolveSchemaDir() + if err != nil { + return err + } + + l.schemaDir = schemaDir + l.schemaEngine = NewSchemaEngine(schemaDir) + return nil +} + +func (l *Loader) resolveSchemaDir() (string, error) { + schemaDir := os.Getenv("SCHEMA_DIR") + if schemaDir == "" { + schemaDir = filepath.Join(l.configDir, "schema", "json") + } + + if !filepath.IsAbs(schemaDir) { + schemaDir = filepath.Join(l.configDir, schemaDir) + } + + absDir, err := filepath.Abs(schemaDir) + if err != nil { + return "", 
fmt.Errorf("structural error: failed to resolve schema directory: %w", err) + } + + info, err := os.Stat(absDir) + if err != nil { + return "", fmt.Errorf("structural error: schema directory not found: %s", absDir) + } + if !info.IsDir() { + return "", fmt.Errorf("structural error: schema directory is not a directory: %s", absDir) + } + + return absDir, nil +} + +// loadMainConfig loads and validates the main configuration file. +func (l *Loader) loadMainConfig() (*mainConfig, error) { + data, err := os.ReadFile(l.configFile) + if err != nil { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + + var raw map[string]any + if err := yaml.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("parse error: %w", err) + } + + normalized, err := l.schemaEngine.ApplyDefaultsAndValidate(raw, "main.schema.json") + if err != nil { + return nil, err + } + + normalizedData, err := yaml.Marshal(normalized) + if err != nil { + return nil, fmt.Errorf("failed to marshal normalized main config: %w", err) + } + + var cfg mainConfig + if err := decodeYAMLStrict(normalizedData, &cfg); err != nil { + return nil, err + } + + return &cfg, nil +} + +// loadAllComponents loads all component configurations +func (l *Loader) loadAllComponents(mainCfg *mainConfig) (map[string]*Config, error) { + components := make(map[string]*Config) + + // Iterate through component declarations in main config + for name, path := range mainCfg.Components { + switch v := path.(type) { + case string: + // Single configuration file + cfg, err := l.loadComponent(v) + if err != nil { + return nil, fmt.Errorf("failed to load component %s: %w", name, err) + } + components[name] = cfg + + case []any: + // Multiple configuration files (e.g., agents) + for i, item := range v { + pathStr, ok := item.(string) + if !ok { + return nil, fmt.Errorf("structural error: components.%s[%d] must be string, got %T", name, i, item) + } + // Use index as name suffix for multiple configs + componentName := 
fmt.Sprintf("%s-%d", name, i) + cfg, err := l.loadComponent(pathStr) + if err != nil { + return nil, fmt.Errorf("failed to load component %s: %w", componentName, err) + } + components[componentName] = cfg + } + default: + return nil, fmt.Errorf("structural error: components.%s must be string or []string, got %T", name, path) + } + } + + return components, nil +} + +// loadComponent loads a single component configuration +func (l *Loader) loadComponent(configPath string) (*Config, error) { + // Resolve relative path + fullPath := configPath + if !filepath.IsAbs(configPath) { + fullPath = filepath.Join(l.configDir, configPath) + } + + // Read configuration file + data, err := os.ReadFile(fullPath) + if err != nil { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + + // Expand environment variables + expandedData := os.ExpandEnv(string(data)) + + var raw map[string]any + if err := yaml.Unmarshal([]byte(expandedData), &raw); err != nil { + return nil, fmt.Errorf("parse error: %w", err) + } + componentType, _ := raw["type"].(string) + if componentType == "" { + return nil, fmt.Errorf("structural error: type is required") + } + componentSchema, err := schemaFileForComponent(componentType) + if err != nil { + return nil, err + } + normalized, err := l.schemaEngine.ApplyDefaultsAndValidate(raw, componentSchema) + if err != nil { + return nil, err + } + normalizedData, err := yaml.Marshal(normalized) + if err != nil { + return nil, fmt.Errorf("failed to marshal normalized component config: %w", err) + } + + // Parse YAML + var cfg Config + if err := decodeYAMLStrict(normalizedData, &cfg); err != nil { + return nil, err + } + + return &cfg, nil +} + +func schemaFileForComponent(componentType string) (string, error) { + switch componentType { + case "logger": + return "logger.schema.json", nil + case "memory": + return "memory.schema.json", nil + case "models": + return "models.schema.json", nil + case "tools": + return "tools.schema.json", nil + case 
"server": + return "server.schema.json", nil + case "rag": + return "rag.schema.json", nil + case "agent": + return "agent.schema.json", nil + default: + return "", fmt.Errorf("structural error: unsupported component type: %s", componentType) + } +} + +func decodeYAMLStrict(data []byte, target any) error { + var parserCheck yaml.Node + if err := yaml.Unmarshal(data, &parserCheck); err != nil { + return fmt.Errorf("parse error: %w", err) + } + + dec := yaml.NewDecoder(bytes.NewReader(data)) + dec.KnownFields(true) + if err := dec.Decode(target); err != nil { + return fmt.Errorf("structural error: %w", err) + } + + return nil +} diff --git a/ai/config/test/loader_test.go b/ai/config/test/loader_test.go new file mode 100644 index 00000000..45349e8a --- /dev/null +++ b/ai/config/test/loader_test.go @@ -0,0 +1,297 @@ +package config_test + +import ( + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "dubbo-admin-ai/component/agent/react" + "dubbo-admin-ai/component/server" + "dubbo-admin-ai/config" +) + +func repoSchemaDir(t *testing.T) string { + t.Helper() + _, file, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("failed to get caller") + } + root := filepath.Dir(filepath.Dir(filepath.Dir(file))) + return filepath.Join(root, "schema", "json") +} + +func copySchemaDir(t *testing.T, dst string) { + t.Helper() + src := repoSchemaDir(t) + entries, err := os.ReadDir(src) + if err != nil { + t.Fatalf("read schema dir: %v", err) + } + if err := os.MkdirAll(dst, 0o755); err != nil { + t.Fatalf("mkdir schema dir: %v", err) + } + for _, e := range entries { + if e.IsDir() { + continue + } + b, err := os.ReadFile(filepath.Join(src, e.Name())) + if err != nil { + t.Fatalf("read schema file %s: %v", e.Name(), err) + } + if err := os.WriteFile(filepath.Join(dst, e.Name()), b, 0o644); err != nil { + t.Fatalf("write schema file %s: %v", e.Name(), err) + } + } +} + +func writeFile(t *testing.T, path, content string) { + t.Helper() + if err := 
os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", path, err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("write %s: %v", path, err) + } +} + +func mustContain(t *testing.T, err error, want string) { + t.Helper() + if err == nil || !strings.Contains(err.Error(), want) { + t.Fatalf("expected error containing %q, got: %v", want, err) + } +} + +func TestLoader_MainConfig_Parse(t *testing.T) { + tests := []struct { + name string + mainYAML string + expectLike string + }{ + {name: "parse_error", mainYAML: "project: p\nversion: v\ncomponents: [", expectLike: "parse error"}, + {name: "parse_error_priority", mainYAML: "project: p\nversion v\ncomponents:\n logger: logger.yaml\nunknown: true\n", expectLike: "parse error"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + t.Setenv("SCHEMA_DIR", repoSchemaDir(t)) + writeFile(t, filepath.Join(dir, "config.yaml"), tt.mainYAML) + loader := config.NewLoader(filepath.Join(dir, "config.yaml")) + _, err := loader.Load() + mustContain(t, err, tt.expectLike) + }) + } +} + +func TestLoader_Component_Parse(t *testing.T) { + tests := []struct { + name string + componentYML string + }{ + {name: "parse_error", componentYML: "type: logger\nspec: ["}, + {name: "parse_error_priority", componentYML: "type: logger\nspec: [\nextra: true\n"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + t.Setenv("SCHEMA_DIR", repoSchemaDir(t)) + writeFile(t, filepath.Join(dir, "config.yaml"), "project: p\nversion: v\ncomponents:\n logger: logger.yaml\n") + writeFile(t, filepath.Join(dir, "logger.yaml"), tt.componentYML) + loader := config.NewLoader(filepath.Join(dir, "config.yaml")) + _, err := loader.Load() + mustContain(t, err, "parse error") + }) + } +} + +func TestLoader_MainConfig_Structural(t *testing.T) { + tests := []struct { + name string + mainYAML string + expectLike string + }{ + {name: 
"unknown_field", mainYAML: "project: p\nversion: v\nunknown: true\ncomponents:\n logger: logger.yaml\n", expectLike: "structural error"}, + {name: "components_type_invalid", mainYAML: "project: p\nversion: v\ncomponents:\n logger:\n path: logger.yaml\n", expectLike: "structural error"}, + {name: "components_array_item_invalid", mainYAML: "project: p\nversion: v\ncomponents:\n agent:\n - a.yaml\n - 1\n", expectLike: "structural error"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + t.Setenv("SCHEMA_DIR", repoSchemaDir(t)) + writeFile(t, filepath.Join(dir, "config.yaml"), tt.mainYAML) + loader := config.NewLoader(filepath.Join(dir, "config.yaml")) + _, err := loader.Load() + mustContain(t, err, tt.expectLike) + }) + } +} + +func TestLoader_MainConfig_SchemaDir(t *testing.T) { + dir := t.TempDir() + schemaDst := filepath.Join(dir, "schema", "json") + copySchemaDir(t, schemaDst) + t.Setenv("SCHEMA_DIR", "") + + writeFile(t, filepath.Join(dir, "config.yaml"), "project: p\nversion: v\ncomponents:\n logger: logger.yaml\n") + writeFile(t, filepath.Join(dir, "logger.yaml"), "type: logger\nspec: {}\n") + + loader := config.NewLoader(filepath.Join(dir, "config.yaml")) + loaded, err := loader.Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + want, _ := filepath.Abs(schemaDst) + if loaded.SchemaDir != want { + t.Fatalf("SchemaDir = %q, want %q", loaded.SchemaDir, want) + } +} + +func TestLoader_Component_Structural(t *testing.T) { + tests := []struct { + name string + componentYML string + }{ + {name: "missing_type", componentYML: "spec: {}\n"}, + {name: "missing_spec", componentYML: "type: server\n"}, + {name: "unknown_top_field", componentYML: "type: server\nspec: {}\nextra: true\n"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + t.Setenv("SCHEMA_DIR", repoSchemaDir(t)) + writeFile(t, filepath.Join(dir, "comp.yaml"), tt.componentYML) + loader := 
config.NewLoader(filepath.Join(dir, "config.yaml")) + _, err := loader.LoadComponent("comp.yaml") + mustContain(t, err, "structural error") + }) + } +} + +func TestLoader_Component_DefaultInjection(t *testing.T) { + tests := []struct { + name string + fileName string + componentYML string + assertFn func(t *testing.T, cfg *config.Config) + }{ + { + name: "server", + fileName: "server.yaml", + componentYML: "type: server\nspec: {}\n", + assertFn: func(t *testing.T, cfg *config.Config) { + var spec server.ServerSpec + if err := cfg.Spec.Decode(&spec); err != nil { + t.Fatalf("decode server spec: %v", err) + } + if spec.Port != 8888 || spec.Host != "0.0.0.0" || spec.ReadTimeout != 30 || spec.WriteTimeout != 30 { + t.Fatalf("server defaults not injected: %+v", spec) + } + }, + }, + { + name: "agent", + fileName: "agent.yaml", + componentYML: `type: agent +spec: + default_model: qwen-max + prompt_base_path: ./prompts + stages: + - name: think-stage + flow_type: think + prompt_file: think.txt +`, + assertFn: func(t *testing.T, cfg *config.Config) { + var spec react.AgentSpec + if err := cfg.Spec.Decode(&spec); err != nil { + t.Fatalf("decode agent spec: %v", err) + } + if len(spec.Stages) != 1 { + t.Fatalf("stages len = %d, want 1", len(spec.Stages)) + } + stage := spec.Stages[0] + if stage.Temperature == 0 || stage.TopP == 0 || stage.MaxTokens == 0 || stage.Timeout == 0 { + t.Fatalf("agent stage defaults not injected: %+v", stage) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + t.Setenv("SCHEMA_DIR", repoSchemaDir(t)) + writeFile(t, filepath.Join(dir, tt.fileName), tt.componentYML) + loader := config.NewLoader(filepath.Join(dir, "config.yaml")) + cfg, err := loader.LoadComponent(tt.fileName) + if err != nil { + t.Fatalf("LoadComponent() error: %v", err) + } + tt.assertFn(t, cfg) + }) + } +} + +func TestLoader_Component_ConditionalRequired(t *testing.T) { + tests := []struct { + name string + fileName string + 
componentYML string + }{ + {name: "tools_mcp_enabled_require_host", fileName: "tools.yaml", componentYML: `type: tools +spec: + enable_mcp_tools: true + mcp_host_name: "" +`}, + {name: "rag_reranker_enabled_require_api_key", fileName: "rag.yaml", componentYML: `type: rag +spec: + embedder: + spec: + model: dashscope/qwen3-embedding + loader: + spec: {} + splitter: + spec: {} + indexer: + spec: {} + retriever: + spec: {} + reranker: + spec: + enabled: true +`}, + {name: "rag_splitter_oneof_branch_validation", fileName: "rag.yaml", componentYML: `type: rag +spec: + embedder: + spec: + model: dashscope/qwen3-embedding + loader: + spec: {} + splitter: + type: recursive + spec: + chunk_size: 10 + headers: + "#": h1 + indexer: + spec: {} + retriever: + spec: {} +`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + t.Setenv("SCHEMA_DIR", repoSchemaDir(t)) + writeFile(t, filepath.Join(dir, tt.fileName), tt.componentYML) + loader := config.NewLoader(filepath.Join(dir, "config.yaml")) + _, err := loader.LoadComponent(tt.fileName) + mustContain(t, err, "structural error") + }) + } +} diff --git a/ai/go.mod b/ai/go.mod index bb50c2b4..c8678ecc 100644 --- a/ai/go.mod +++ b/ai/go.mod @@ -4,18 +4,19 @@ go 1.24.1 toolchain go1.24.5 -replace github.com/firebase/genkit/go => github.com/stringl1l1l1l/genkit/go v0.0.0-20250926153048-97c88b0acd38 - require ( + github.com/cloudwego/eino v0.7.34 + github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260214075714-8f11ae8e65a2 + github.com/cloudwego/eino-ext/components/document/parser/pdf v0.0.0-20260214075714-8f11ae8e65a2 + github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260214075714-8f11ae8e65a2 + github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260214075714-8f11ae8e65a2 github.com/cohere-ai/cohere-go/v2 v2.15.3 - github.com/firebase/genkit/go v1.0.4 + github.com/firebase/genkit/go v1.2.0 
github.com/gin-gonic/gin v1.11.0 github.com/gomarkdown/markdown v0.0.0-20250810172220-2e2c11897d1a github.com/joho/godotenv v1.5.1 - github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 github.com/mitchellh/mapstructure v1.5.0 github.com/openai/openai-go v1.12.0 - github.com/tmc/langchaingo v0.1.13 ) require ( @@ -25,24 +26,31 @@ require ( github.com/bytedance/sonic v1.14.1 // indirect github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect - github.com/dlclark/regexp2 v1.11.5 // indirect + github.com/dslipak/pdf v0.0.2 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/eino-contrib/jsonschema v1.0.3 // indirect github.com/gabriel-vasile/mimetype v1.4.10 // indirect github.com/gin-contrib/sse v1.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.27.0 // indirect github.com/goccy/go-json v0.10.5 // indirect + github.com/goph/emperror v0.17.2 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mark3labs/mcp-go v0.40.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/nikolalohinski/gonja v1.5.3 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect - github.com/pkoukk/tiktoken-go v0.1.8 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.54.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect github.com/spf13/cast v1.10.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect 
github.com/tidwall/match v1.2.0 // indirect @@ -50,13 +58,10 @@ require ( github.com/tidwall/sjson v1.2.5 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.0 // indirect - gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181 // indirect - gitlab.com/golang-commonmark/linkify v0.0.0-20200225224916-64bca66f6ad3 // indirect - gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a // indirect - gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 // indirect - gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f // indirect + github.com/yargevad/filepathx v1.0.0 // indirect go.uber.org/mock v0.5.0 // indirect golang.org/x/arch v0.21.0 // indirect + golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect golang.org/x/mod v0.27.0 // indirect golang.org/x/sync v0.17.0 // indirect golang.org/x/tools v0.36.0 // indirect @@ -73,7 +78,7 @@ require ( github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/goccy/go-yaml v1.18.0 // indirect - github.com/google/dotprompt/go v0.0.0-20250923103342-a8a91d1dff59 // indirect + github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 @@ -98,9 +103,9 @@ require ( golang.org/x/net v0.44.0 // indirect golang.org/x/sys v0.36.0 // indirect golang.org/x/text v0.29.0 // indirect - google.golang.org/genai v1.26.0 // indirect + google.golang.org/genai v1.30.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250922171735-9219d122eba9 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.9 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + gopkg.in/yaml.v3 v3.0.1 ) diff --git a/ai/go.sum b/ai/go.sum index 33d7f6ed..8dbaed3c 100644 --- a/ai/go.sum +++ b/ai/go.sum @@ -4,42 +4,68 @@ 
cloud.google.com/go/auth v0.16.5 h1:mFWNQ2FEVWAliEQWpAdH80omXFokmrnbDhUS9cBywsI= cloud.google.com/go/auth v0.16.5/go.mod h1:utzRfHMP+Vv0mpOkTRQoWD2q3BatTOoWbA7gCc2dUhQ= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= github.com/aws/aws-sdk-go-v2 v1.39.1 h1:fWZhGAwVRK/fAN2tmt7ilH4PPAE11rDj7HytrmbZ2FE= github.com/aws/aws-sdk-go-v2 v1.39.1/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= +github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= 
github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/cloudwego/eino v0.7.34 h1:mL1l703kPRxG0tpBAnqpo8so5/4reXd9jt+VUwwFqes= +github.com/cloudwego/eino v0.7.34/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= +github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260214075714-8f11ae8e65a2 h1:dDyZP4dwf4DXCF2FQcwwU4FlnP7wliT57bK9yr5j1aI= +github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260214075714-8f11ae8e65a2/go.mod h1:HnxTQxmhuev6zaBl92EHUy/vEDWCuoE/OE4cTiF5JCg= +github.com/cloudwego/eino-ext/components/document/parser/pdf v0.0.0-20260214075714-8f11ae8e65a2 h1:W8+/PvKJmYHJSCUIGEax65XT0rRwI3unhsJfWyrV1GI= +github.com/cloudwego/eino-ext/components/document/parser/pdf v0.0.0-20260214075714-8f11ae8e65a2/go.mod h1:kHC3xkGM/gv3IHpOk33p75BfBaEIYATOs2XmYFKffcs= +github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260214075714-8f11ae8e65a2 h1:DrIq57NcdAClsBGUdUXbQo4r4UR0wLOzv5SMM1K1T7o= +github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260214075714-8f11ae8e65a2/go.mod h1:KVOVct4e2BQ7epDONW2QE1qU5+ccoh91FzJTs9vIJj0= +github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260214075714-8f11ae8e65a2 h1:Fc8bR5LbV+AN0ajzPMtU/8nICBnpB1re9pP03vvvUiM= +github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260214075714-8f11ae8e65a2/go.mod h1:9R0RQrQSpg1JaNnRtw7+RfRAAv0HgdE348YnrlZ6coo= github.com/cohere-ai/cohere-go/v2 v2.15.3 h1:d6m4mspLmviA5OcJzY4wRmugQhcWP1iOPjSkgyZImhs= github.com/cohere-ai/cohere-go/v2 v2.15.3/go.mod 
h1:MuiJkCxlR18BDV2qQPbz2Yb/OCVphT1y6nD2zYaKeR0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= -github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dslipak/pdf v0.0.2 h1:djAvcM5neg9Ush+zR6QXB+VMJzR6TdnX766HPIg1JmI= +github.com/dslipak/pdf v0.0.2/go.mod h1:2L3SnkI9cQwnAS9gfPz2iUoLC0rUZwbucpbKi5R1mUo= github.com/dusted-go/logging v1.3.0 h1:SL/EH1Rp27oJQIte+LjWvWACSnYDTqNx5gZULin0XRY= github.com/dusted-go/logging v1.3.0/go.mod h1:s58+s64zE5fxSWWZfp+b8ZV0CHyKHjamITGyuY1wzGg= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0= +github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/firebase/genkit/go v1.2.0 h1:C31p32vdMZhhSSQQvXouH/kkcleTH4jlgFmpqlJtBS4= +github.com/firebase/genkit/go v1.2.0/go.mod h1:ru1cIuxG1s3HeUjhnadVveDJ1yhinj+j+uUh0f0pyxE= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/gabriel-vasile/mimetype 
v1.4.10 h1:zyueNbySn/z8mJZHLt6IPw0KoZsiQNszIpU+bX4+ZK0= github.com/gabriel-vasile/mimetype v1.4.10/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk= github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -57,12 +83,14 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/gomarkdown/markdown v0.0.0-20250810172220-2e2c11897d1a h1:l7A0loSszR5zHd/qK53ZIHMO8b3bBSmENnQ6eKnUT0A= github.com/gomarkdown/markdown v0.0.0-20250810172220-2e2c11897d1a/go.mod 
h1:JDGcbDT52eL4fju3sZ4TeHGsQwhG9nbDV21aMyhwPoA= -github.com/google/dotprompt/go v0.0.0-20250923103342-a8a91d1dff59 h1:EywQhHXdzYlMKD7Gxl9Ho34c8dQ0meph6FuRN9iENEY= -github.com/google/dotprompt/go v0.0.0-20250923103342-a8a91d1dff59/go.mod h1:k8cjJAQWc//ac/bMnzItyOFbfT01tgRTZGgxELCuxEQ= +github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 h1:okN800+zMJOGHLJCgry+OGzhhtH6YrjQh1rluHmOacE= +github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254/go.mod h1:k8cjJAQWc//ac/bMnzItyOFbfT01tgRTZGgxELCuxEQ= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -74,32 +102,47 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= +github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= +github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/joho/godotenv v1.5.1 
h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+ExRDqGQltzXqN/xypdKP86niVn8= -github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= github.com/mailru/easyjson v0.9.1/go.mod 
h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mark3labs/mcp-go v0.40.0 h1:M0oqK412OHBKut9JwXSsj4KanSmEKpzoW8TcxoPOkAU= github.com/mark3labs/mcp-go v0.40.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a h1:v2cBA3xWKv2cIOVhnzX/gNgkNXqiHfUgJtA3r61Hf7A= github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a/go.mod h1:Y6ghKH+ZijXn5d9E7qGGZBmjitx7iitZdQiIW97EpTU= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -107,12 +150,18 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja v1.5.3/go.mod 
h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= -github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo= -github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -122,21 +171,30 @@ github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQB github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= -github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 
+github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod 
h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/stringl1l1l1l/genkit/go v0.0.0-20250926153048-97c88b0acd38 h1:TWcF/Q/fUWdlV5uYSCvav7hcq3yT0JvqzvFIkWPc/zc= -github.com/stringl1l1l1l/genkit/go v0.0.0-20250926153048-97c88b0acd38/go.mod h1:t7g2u7wrkC83kBeYHXhgutFmEe1mMaBDsHZM5WJWYQw= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -148,14 +206,14 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA= -github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA= github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod 
h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= @@ -163,21 +221,10 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= +github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= -gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181 h1:K+bMSIx9A7mLES1rtG+qKduLIXq40DAzYHtb0XuCukA= -gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181/go.mod h1:dzYhVIwWCtzPAa4QP98wfB9+mzt33MSmM8wsKiMi2ow= -gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82/go.mod h1:Gn+LZmCrhPECMD3SOKlE+BOHwhOYD9j7WT9NUtkCrC8= -gitlab.com/golang-commonmark/linkify v0.0.0-20200225224916-64bca66f6ad3 h1:1Coh5BsUBlXoEJmIEaNzVAWrtg9k7/eJzailMQr1grw= -gitlab.com/golang-commonmark/linkify v0.0.0-20200225224916-64bca66f6ad3/go.mod h1:Gn+LZmCrhPECMD3SOKlE+BOHwhOYD9j7WT9NUtkCrC8= -gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a h1:O85GKETcmnCNAfv4Aym9tepU8OE0NmcZNqPlXcsBKBs= -gitlab.com/golang-commonmark/markdown 
v0.0.0-20211110145824-bf3e522c626a/go.mod h1:LaSIs30YPGs1H5jwGgPhLzc8vkNc/k0rDX/fEZqiU/M= -gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 h1:qqjvoVXdWIcZCLPMlzgA7P9FZWdPGPvP/l3ef8GzV6o= -gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84/go.mod h1:IJZ+fdMvbW2qW6htJx7sLJ04FEs4Ldl/MDsJtMKywfw= -gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f h1:Wku8eEdeJqIOFHtrfkYUByc4bCaTeA6fL0UJgfEiFMI= -gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f/go.mod h1:Tiuhl+njh/JIg0uS/sOJVYi0x2HEa5rc1OAaVsb5tAs= -gitlab.com/opennota/wd v0.0.0-20180912061657-c5d65f63c638 h1:uPZaMiz6Sz0PZs3IZJWpU5qHKGNy///1pacZC9txiUI= -gitlab.com/opennota/wd v0.0.0-20180912061657-c5d65f63c638/go.mod h1:EGRJaqe2eO9XGmFtQCvV3Lm9NLico3UhFwUpCG/+mVU= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= @@ -198,27 +245,37 @@ go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= golang.org/x/arch v0.21.0 h1:iTC9o7+wP6cPWpDWkivCvQFGAHDQ59SrSxsLPcnkArw= golang.org/x/arch v0.21.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= +golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= +golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= golang.org/x/mod 
v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= +golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= 
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/genai v1.26.0 h1:r4HGL54kFv/WCRMTAbZg05Ct+vXfhAbTRlXhFyBkEQo= -google.golang.org/genai v1.26.0/go.mod h1:OClfdf+r5aaD+sCd4aUSkPzJItmg2wD/WON9lQnRPaY= +google.golang.org/genai v1.30.0 h1:7021aneIvl24nEBLbtQFEWleHsMbjzpcQvkT4WcJ1dc= +google.golang.org/genai v1.30.0/go.mod h1:7pAilaICJlQBonjKKJNhftDFv3SREhZcTe9F6nRcjbg= google.golang.org/genproto/googleapis/rpc v0.0.0-20250922171735-9219d122eba9 h1:V1jCN2HBa8sySkR5vLcCSqJSTMv093Rw9EJefhQGP7M= google.golang.org/genproto/googleapis/rpc v0.0.0-20250922171735-9219d122eba9/go.mod h1:HSkG/KdJWusxU1F6CNrwNDjBMgisKxGnc5dAZfT0mjQ= google.golang.org/grpc v1.75.1 h1:/ODCNEuf9VghjgO3rqLcfg8fiOP0nSluljWFlDxELLI= @@ -228,10 +285,9 @@ google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXn gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -sigs.k8s.io/yaml v1.3.0 
h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= -sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= diff --git a/ai/main.go b/ai/main.go index 62a3fa26..621bb7b5 100644 --- a/ai/main.go +++ b/ai/main.go @@ -5,75 +5,94 @@ import ( "flag" "fmt" "log" - "log/slog" - "net/http" "os" "os/signal" "syscall" "time" - "dubbo-admin-ai/agent/react" - "dubbo-admin-ai/manager" - "dubbo-admin-ai/plugins/dashscope" - "dubbo-admin-ai/server" - - "github.com/gin-gonic/gin" + "dubbo-admin-ai/component/agent/react" + "dubbo-admin-ai/component/logger" + "dubbo-admin-ai/component/memory" + "dubbo-admin-ai/component/models" + compRag "dubbo-admin-ai/component/rag" + "dubbo-admin-ai/component/server" + "dubbo-admin-ai/component/tools" + "dubbo-admin-ai/runtime" ) -func main() { - port := flag.Int("port", 8880, "Port for the AI agent server") - mode := flag.String("mode", "release", "Server mode: dev or prod") - envPath := flag.String("env", "./.env", "Path to the .env file") - flag.Parse() +// registerFactorys explicitly registers all component factories +// Registration order determines component initialization order +func registerFactorys(rt *runtime.Runtime) { + // Core components (no dependencies) + rt.RegisterFactory("logger", logger.LoggerFactory) + rt.RegisterFactory("memory", memory.MemoryFactory) - var logger *slog.Logger - switch *mode { - case "release": - logger = manager.ProductionLogger() - gin.SetMode(gin.ReleaseMode) - case "dev": - logger = manager.DevLogger() - gin.SetMode(gin.DebugMode) - } + // Model components (depend on logger) + rt.RegisterFactory("models", models.ModelsFactory) - reActAgent, err := react.Create(manager.Registry(dashscope.Qwen3_coder.Key(), *envPath, logger)) - if err != nil { - logger.Error("Failed to create ReAct agent", "error", err) - return - } + // Tools components (depend on models, memory) + rt.RegisterFactory("tools", tools.ToolsFactory) - apiRouter := server.NewRouter(reActAgent) + // RAG components (depend on 
models) + rt.RegisterFactory("rag", compRag.RAGFactory) - server := &http.Server{ - Addr: fmt.Sprintf(":%d", *port), - Handler: apiRouter.GetEngine(), - } + // Server components (depend on all other components) + rt.RegisterFactory("server", server.ServerFactory) - // 启动服务器 - go func() { - fmt.Printf("🤖 Dubbo Admin AI Agent Server starting on port %d...\n", *port) - fmt.Printf("📖 API Documentation: http://localhost:%d/docs\n", *port) - fmt.Printf("🔍 Health Check: http://localhost:%d/api/v1/ai/health\n", *port) + // Agent components (depend on tools, rag) + rt.RegisterFactory("agent", react.AgentFactory) +} - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Failed to start server: %v", err) - } - }() +func main() { + configPath := flag.String("config", "config.yaml", "Path to the AI configuration file") + flag.Parse() + + rt, err := runtime.Bootstrap(*configPath, registerFactorys) + if err != nil { + log.Fatalf("Failed to initialize runtime: %v", err) + } + rt.GetLogger().Info("🤖 Dubbo Admin AI Agent Server initialized successfully") - // 等待中断信号以优雅关闭服务器 + // Wait for interrupt signal to gracefully shutdown server quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit fmt.Println("🛑 Shutting down server...") - // 5秒超时的优雅关闭 - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - if err := server.Shutdown(ctx); err != nil { - log.Fatalf("Server forced to shutdown: %v", err) + // Stop all components + if err := stopComponents(rt); err != nil { + log.Printf("Warning: Error stopping components: %v", err) } fmt.Println("✅ Server exited") } + +func stopComponents(rt *runtime.Runtime) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + done := make(chan error) + + go func() { + rt.Components.Range(func(key, value interface{}) bool { + comp := value.(runtime.Component) + + if err := comp.Stop(); err != nil { 
+ rt.GetLogger().Error("Failed to stop component", + "name", comp.Name(), + "error", err) + } + + return true + }) + done <- nil + }() + + select { + case <-ctx.Done(): + return fmt.Errorf("shutdown timeout") + case <-done: + return nil + } +} diff --git a/ai/manager/manager.go b/ai/manager/manager.go deleted file mode 100644 index 6c1efc17..00000000 --- a/ai/manager/manager.go +++ /dev/null @@ -1,142 +0,0 @@ -package manager - -import ( - "context" - "log/slog" - "os" - "sync" - - "dubbo-admin-ai/config" - "dubbo-admin-ai/plugins/dashscope" - "dubbo-admin-ai/plugins/siliconflow" - - "github.com/dusted-go/logging/prettylog" - "github.com/firebase/genkit/go/core/api" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/googlegenai" - "github.com/firebase/genkit/go/plugins/pinecone" - "github.com/joho/godotenv" -) - -var ( - gloRegistry *genkit.Genkit - gloLogger *slog.Logger - once sync.Once -) - -func Registry(modelName string, envPath string, logger *slog.Logger) (registry *genkit.Genkit) { - once.Do(func() { - gloLogger = logger - if logger == nil { - gloLogger = ProductionLogger() - } - LoadEnvVars2Config(envPath) - gloRegistry = defaultRegistry(modelName) - }) - if gloRegistry == nil { - panic("Failed to get global registry") - } - return gloRegistry -} - -// Load environment variables from .env file -func LoadEnvVars2Config(envPath string) { - // Check if the .env file exists, if not, try to find in the current directory - if _, err := os.Stat(envPath); os.IsNotExist(err) { - if _, err := os.Stat("./.env"); err == nil { - envPath = "./.env" - } - } - - // Load environment variables - if err := godotenv.Load(envPath); err != nil { - GetLogger().Warn("No .env file found at " + envPath + ", proceeding with existing environment variables") - } - - // config.GEMINI_API_KEY = os.Getenv("GEMINI_API_KEY") - // config.SILICONFLOW_API_KEY = os.Getenv("SILICONFLOW_API_KEY") - config.DASHSCOPE_API_KEY = os.Getenv("DASHSCOPE_API_KEY") - 
config.PINECONE_API_KEY = os.Getenv("PINECONE_API_KEY") - config.COHERE_API_KEY = os.Getenv("COHERE_API_KEY") - - if config.COHERE_API_KEY == "" { - GetLogger().Warn("COHERE_API_KEY missing in the environment variables. Please check out.") - } - // if config.GEMINI_API_KEY == "" { - // GetLogger().Warn("GEMINI_API_KEY missing in the environment variables. Please check out.") - // } - // if config.SILICONFLOW_API_KEY == "" { - // GetLogger().Warn("SILICONFLOW_API_KEY missing in the environment variables. Please check out.") - // } - if config.DASHSCOPE_API_KEY == "" { - GetLogger().Warn("DASHSCOPE_API_KEY missing in the environment variables. Please check out.") - } - if config.PINECONE_API_KEY == "" { - GetLogger().Warn("PINECONE_API_KEY missing in the environment variables. Please check out.") - } -} - -func defaultRegistry(modelName string) *genkit.Genkit { - ctx := context.Background() - plugins := []api.Plugin{} - if config.SILICONFLOW_API_KEY != "" { - plugins = append(plugins, &siliconflow.SiliconFlow{ - APIKey: config.SILICONFLOW_API_KEY, - }) - } - if config.GEMINI_API_KEY != "" { - plugins = append(plugins, &googlegenai.GoogleAI{ - APIKey: config.GEMINI_API_KEY, - }) - } - if config.DASHSCOPE_API_KEY != "" { - plugins = append(plugins, &dashscope.DashScope{ - APIKey: config.DASHSCOPE_API_KEY, - }) - } - if config.PINECONE_API_KEY != "" { - plugins = append(plugins, &pinecone.Pinecone{ - APIKey: config.PINECONE_API_KEY, - }) - } - - registry := genkit.Init(ctx, - genkit.WithPlugins(plugins...), - genkit.WithDefaultModel(modelName), - genkit.WithPromptDir(config.PROMPT_DIR_PATH), - ) - return registry -} - -func DevLogger() *slog.Logger { - slog.SetDefault( - slog.New( - prettylog.NewHandler(&slog.HandlerOptions{ - Level: slog.LevelDebug, - AddSource: true, - ReplaceAttr: nil, - }), - ), - ) - return slog.Default() -} - -func ProductionLogger() *slog.Logger { - slog.SetDefault( - slog.New( - prettylog.NewHandler(&slog.HandlerOptions{ - Level: slog.LevelInfo, 
- AddSource: false, - ReplaceAttr: nil, - }), - ), - ) - return slog.Default() -} - -func GetLogger() *slog.Logger { - if gloLogger == nil { - gloLogger = ProductionLogger() - } - return gloLogger -} diff --git a/ai/plugins/dashscope/dashscope.go b/ai/plugins/dashscope/dashscope.go deleted file mode 100644 index 3c3b693a..00000000 --- a/ai/plugins/dashscope/dashscope.go +++ /dev/null @@ -1,145 +0,0 @@ -package dashscope - -import ( - "context" - "os" - - "dubbo-admin-ai/plugins/model" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/core/api" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/compat_oai" - openaiGo "github.com/openai/openai-go" - "github.com/openai/openai-go/option" -) - -const provider = "dashscope" -const baseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" - -const ( - qwen3_235b_a22b = "qwen3-235b-a22b-instruct-2507" - qwen_max = "qwen-max" - qwen_plus = "qwen-plus" - qwen_flash = "qwen-flash" - qwen3_coder = "qwen3-coder-plus" - - text_embedding_v4 = "text-embedding-v4" -) - -type TextEmbeddingConfig struct { - Dimensions int `json:"dimensions,omitempty"` - EncodingFormat openaiGo.EmbeddingNewParamsEncodingFormat `json:"encodingFormat,omitempty"` -} - -var ( - Qwen3 = model.NewModel(provider, qwen3_235b_a22b, &compat_oai.BasicText) - Qwen_plus = model.NewModel(provider, qwen_plus, &compat_oai.BasicText) - Qwen_max = model.NewModel(provider, qwen_max, &compat_oai.BasicText) - Qwen3_coder = model.NewModel(provider, qwen3_coder, &compat_oai.BasicText) - Qwen_flash = model.NewModel(provider, qwen_flash, &compat_oai.BasicText) - - Qwen3_embedding = model.NewEmbedder( - provider, - text_embedding_v4, - 1024, - &ai.EmbedderSupports{ - Input: []string{"text"}, - }, - TextEmbeddingConfig{}, - ) - - supportedModels = []*model.Model{ - Qwen3, - Qwen_plus, - Qwen_max, - Qwen3_coder, - Qwen_flash, - } - - supportedEmbeddingModels = []*model.Embedder{ - Qwen3_embedding, - } -) - -type DashScope 
struct { - APIKey string - - Opts []option.RequestOption - - openAICompatible *compat_oai.OpenAICompatible -} - -// Name implements genkit.Plugin. -func (o *DashScope) Name() string { - return provider -} - -// Init implements genkit.Plugin. -func (o *DashScope) Init(ctx context.Context) []api.Action { - apiKey := o.APIKey - - // if api key is not set, get it from environment variable - if apiKey == "" { - apiKey = os.Getenv("DASHSCOPE_API_KEY") - } - - if apiKey == "" { - panic("DashScope plugin initialization failed: apiKey is required") - } - - if o.openAICompatible == nil { - o.openAICompatible = &compat_oai.OpenAICompatible{} - } - - // set the options - o.openAICompatible.Opts = []option.RequestOption{ - option.WithAPIKey(apiKey), - option.WithBaseURL(baseURL), - } - - if len(o.Opts) > 0 { - o.openAICompatible.Opts = append(o.openAICompatible.Opts, o.Opts...) - } - - o.openAICompatible.Provider = provider - compatActions := o.openAICompatible.Init(ctx) - - var actions []api.Action - actions = append(actions, compatActions...) 
- - // define default models - for _, model := range supportedModels { - actions = append(actions, o.DefineModel(model.InternalKey(), model.Options()).(api.Action)) - } - - // define default embedders - for _, embedder := range supportedEmbeddingModels { - actions = append(actions, o.DefineEmbedder(embedder.InternalKey(), embedder.Options()).(api.Action)) - } - return actions -} - -func (o *DashScope) Model(g *genkit.Genkit, name string) ai.Model { - return o.openAICompatible.Model(g, api.NewName(provider, name)) -} - -func (o *DashScope) DefineModel(id string, opts ai.ModelOptions) ai.Model { - return o.openAICompatible.DefineModel(provider, id, opts) -} - -func (o *DashScope) DefineEmbedder(id string, opts *ai.EmbedderOptions) ai.Embedder { - return o.openAICompatible.DefineEmbedder(provider, id, opts) -} - -func (o *DashScope) Embedder(g *genkit.Genkit, name string) ai.Embedder { - return o.openAICompatible.Embedder(g, api.NewName(provider, name)) -} - -func (o *DashScope) ListActions(ctx context.Context) []api.ActionDesc { - return o.openAICompatible.ListActions(ctx) -} - -func (o *DashScope) ResolveAction(atype api.ActionType, name string) api.Action { - return o.openAICompatible.ResolveAction(atype, name) -} diff --git a/ai/plugins/model/model.go b/ai/plugins/model/model.go deleted file mode 100644 index 6cd6d765..00000000 --- a/ai/plugins/model/model.go +++ /dev/null @@ -1,73 +0,0 @@ -package model - -import ( - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/core" - "github.com/firebase/genkit/go/core/api" -) - -type Model struct { - provider string - internalKey string // internalKey is the internal model representation of different providers. - supports *ai.ModelSupports -} - -func NewModel(provider string, internalKey string, supports *ai.ModelSupports) *Model { - return &Model{ - provider: provider, - internalKey: internalKey, - supports: supports, - } -} - -// Key is the model query string of genkit registry. 
-func (m *Model) Key() string { - return api.NewName(m.provider, m.internalKey) -} - -func (m *Model) InternalKey() string { - return m.internalKey -} - -func (m *Model) Options() ai.ModelOptions { - return ai.ModelOptions{ - Label: m.internalKey, - Supports: m.supports, - Versions: []string{m.internalKey}, - } -} - -type Embedder struct { - config map[string]any - provider string - internalKey string - dimensions int - supports *ai.EmbedderSupports -} - -func NewEmbedder(provider string, internalKey string, dimensions int, supports *ai.EmbedderSupports, config any) *Embedder { - return &Embedder{ - config: core.InferSchemaMap(config), - provider: provider, - internalKey: internalKey, - dimensions: dimensions, - supports: supports, - } -} - -func (m *Embedder) Key() string { - return api.NewName(m.provider, m.internalKey) -} - -func (m *Embedder) InternalKey() string { - return m.internalKey -} - -func (m *Embedder) Options() *ai.EmbedderOptions { - return &ai.EmbedderOptions{ - ConfigSchema: m.config, - Label: m.internalKey, - Supports: m.supports, - Dimensions: m.dimensions, - } -} diff --git a/ai/plugins/siliconflow/siliconflow.go b/ai/plugins/siliconflow/siliconflow.go deleted file mode 100644 index 67579e70..00000000 --- a/ai/plugins/siliconflow/siliconflow.go +++ /dev/null @@ -1,118 +0,0 @@ -package siliconflow - -import ( - "context" - "os" - - "dubbo-admin-ai/plugins/model" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/core/api" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/compat_oai" - "github.com/openai/openai-go/option" -) - -const provider = "siliconflow" -const baseURL = "https://api.siliconflow.cn/v1" - -const ( - deepseekV3 = "deepseek-ai/DeepSeek-V3" - deepseekR1 = "deepseek-ai/DeepSeek-R1" - qwenQwQ32B = "Qwen/QwQ-32B" - qwen3Coder = "Qwen/Qwen3-Coder-480B-A35B-Instruct" -) - -var ( - DeepSeekV3 = model.NewModel(provider, deepseekV3, &compat_oai.BasicText) - QwenQwQ32B = 
model.NewModel(provider, qwenQwQ32B, &compat_oai.BasicText) - Qwen3Coder = model.NewModel(provider, qwen3Coder, &compat_oai.Multimodal) - DeepSeekR1 = model.NewModel(provider, deepseekR1, &compat_oai.BasicText) - - supportedModels = []*model.Model{ - DeepSeekV3, - QwenQwQ32B, - Qwen3Coder, - DeepSeekR1, - } - - // supportedEmbeddingModels = []string{} -) - -type SiliconFlow struct { - APIKey string - - Opts []option.RequestOption - - openAICompatible *compat_oai.OpenAICompatible -} - -// Name implements genkit.Plugin. -func (o *SiliconFlow) Name() string { - return provider -} - -func (o *SiliconFlow) Init(ctx context.Context) []api.Action { - apiKey := o.APIKey - - // if api key is not set, get it from environment variable - if apiKey == "" { - apiKey = os.Getenv("SILICONFLOW_API_KEY") - } - - if apiKey == "" { - panic("SiliconFlow plugin initialization failed: apiKey is required") - } - - if o.openAICompatible == nil { - o.openAICompatible = &compat_oai.OpenAICompatible{} - } - - // set the options - o.openAICompatible.Opts = []option.RequestOption{ - option.WithAPIKey(apiKey), - option.WithBaseURL(baseURL), - } - - if len(o.Opts) > 0 { - o.openAICompatible.Opts = append(o.openAICompatible.Opts, o.Opts...) - } - - o.openAICompatible.Provider = provider - compatActions := o.openAICompatible.Init(ctx) - - var actions []api.Action - actions = append(actions, compatActions...) 
- - // define default models - for _, model := range supportedModels { - actions = append(actions, o.DefineModel(model.InternalKey(), model.Options()).(api.Action)) - } - //TODO: define default embedders - - return actions -} - -func (o *SiliconFlow) Model(g *genkit.Genkit, name string) ai.Model { - return o.openAICompatible.Model(g, api.NewName(provider, name)) -} - -func (o *SiliconFlow) DefineModel(id string, opts ai.ModelOptions) ai.Model { - return o.openAICompatible.DefineModel(provider, id, opts) -} - -func (o *SiliconFlow) DefineEmbedder(id string, opts *ai.EmbedderOptions) ai.Embedder { - return o.openAICompatible.DefineEmbedder(provider, id, opts) -} - -func (o *SiliconFlow) Embedder(g *genkit.Genkit, name string) ai.Embedder { - return o.openAICompatible.Embedder(g, api.NewName(provider, name)) -} - -func (o *SiliconFlow) ListActions(ctx context.Context) []api.ActionDesc { - return o.openAICompatible.ListActions(ctx) -} - -func (o *SiliconFlow) ResolveAction(atype api.ActionType, name string) api.Action { - return o.openAICompatible.ResolveAction(atype, name) -} diff --git a/ai/prompts/agentAct.txt b/ai/prompts/agentAct.txt new file mode 100644 index 00000000..64e0dbc6 --- /dev/null +++ b/ai/prompts/agentAct.txt @@ -0,0 +1 @@ +You are a helpful assistant. 
diff --git a/ai/runtime/runtime.go b/ai/runtime/runtime.go new file mode 100644 index 00000000..53137166 --- /dev/null +++ b/ai/runtime/runtime.go @@ -0,0 +1,237 @@ +package runtime + +import ( + "context" + "dubbo-admin-ai/config" + "fmt" + "log/slog" + "sync" + + "github.com/firebase/genkit/go/genkit" + "gopkg.in/yaml.v3" +) + +// Component defines the interface for all components +type Component interface { + Name() string + Validate() error + Init(*Runtime) error + Start() error + Stop() error +} + +// ComponentFactory is the function type for creating components +type ComponentFactory func(config *yaml.Node) (Component, error) + +var ( + gloRuntime *Runtime = nil +) + +func NewRuntime() *Runtime { + return &Runtime{ + factories: make(map[string]ComponentFactory), + factoryOrder: make([]string, 0), + genkitOptions: make([]genkit.GenkitOption, 0), + } +} + +func Bootstrap(configFile string, registerFn func(rt *Runtime)) (*Runtime, error) { + gloRuntime = NewRuntime() + + // Register component factories + if registerFn != nil { + registerFn(gloRuntime) + } + + // Create config loader and load all configurations + loader := config.NewLoader(configFile) + loadedCfg, err := loader.Load() + if err != nil { + return nil, fmt.Errorf("failed to load config: %w", err) + } + + // Create component instances + instances, err := gloRuntime.createComponents(loadedCfg) + if err != nil { + return nil, fmt.Errorf("failed to create components: %w", err) + } + + // Initialize components in dependency order, which is the order of factory registration. 
+ for _, comp := range instances { + if err := comp.Validate(); err != nil { + return nil, fmt.Errorf("failed to validate %s: %w", comp.Name(), err) + } + + if err := comp.Init(gloRuntime); err != nil { + return nil, fmt.Errorf("failed to init %s: %w", comp.Name(), err) + } + gloRuntime.Components.Store(comp.Name(), comp) + } + + // Start all loaded components + gloRuntime.Components.Range(func(key, value any) bool { + comp := value.(Component) + if err := comp.Start(); err != nil { + return false + } + return true + }) + + return gloRuntime, nil +} + +type Runtime struct { + mu sync.RWMutex + configFile string + genkitRegistry *genkit.Genkit + genkitOptions []genkit.GenkitOption + factories map[string]ComponentFactory + factoryOrder []string + + Components sync.Map +} + +// RegisterFactory registers a component factory function +func (r *Runtime) RegisterFactory(componentType string, factory ComponentFactory) { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.factories[componentType]; exists { + fmt.Printf("Warning: component type '%s' is already registered, overwriting\n", componentType) + } else { + // Only record order on first registration + r.factoryOrder = append(r.factoryOrder, componentType) + } + + r.factories[componentType] = factory +} + +// RegisterGenkitOption registers Genkit initialization options +func (r *Runtime) RegisterGenkitOption(opts ...genkit.GenkitOption) { + r.mu.Lock() + defer r.mu.Unlock() + r.genkitOptions = append(r.genkitOptions, opts...) 
+} + +func (r *Runtime) SetGenkitRegistry(registry *genkit.Genkit) { + r.mu.Lock() + defer r.mu.Unlock() + r.genkitRegistry = registry +} + +func (r *Runtime) GetFactoryFn(componentType string) (ComponentFactory, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + factory, exists := r.factories[componentType] + if !exists { + return nil, fmt.Errorf("component type '%s' not registered", componentType) + } + + return factory, nil +} + +// Creates component instances based on loaded configuration and factory registration order +func (r *Runtime) createComponents(loadedCfg *config.LoadedConfig) ([]Component, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + var instances []Component + processed := make(map[string]bool) + + // Create components following factory registration order + for _, componentType := range r.factoryOrder { + // Find and create the component with this type + for name, cfg := range loadedCfg.Components { + if processed[name] { + continue + } + + if cfg.Type == componentType { + comp, err := r.createComponent(cfg) + if err != nil { + return nil, fmt.Errorf("failed to create %s: %w", name, err) + } + + instances = append(instances, comp) + processed[name] = true + } + } + } + + // Fail fast when configuration contains component types without registered factories. 
+ for name, cfg := range loadedCfg.Components { + if processed[name] { + continue + } + if _, exists := r.factories[cfg.Type]; !exists { + return nil, fmt.Errorf("no factory for %s", cfg.Type) + } + } + + return instances, nil +} + +// createComponent get factory by component type and create component instance +func (r *Runtime) createComponent(cfg *config.Config) (Component, error) { + factoryFn, err := r.GetFactoryFn(cfg.Type) + if err != nil { + return nil, fmt.Errorf("no factory for %s: %w", cfg.Type, err) + } + + comp, err := factoryFn(&cfg.Spec) + if err != nil { + return nil, fmt.Errorf("factory failed for %s: %w", cfg.Type, err) + } + + return comp, nil +} + +func GetLogger() *slog.Logger { + if gloRuntime == nil { + return slog.Default() + } + return gloRuntime.GetLogger() +} + +func (rt *Runtime) GetLogger() *slog.Logger { + return slog.Default() +} + +func (rt *Runtime) GetContext() context.Context { + return context.Background() +} + +func (rt *Runtime) GetRegistry() *genkit.Genkit { + if rt.genkitRegistry == nil { + panic("Genkit registry not initialized") + } + return rt.genkitRegistry +} + +func (rt *Runtime) GetGenkitRegistry() *genkit.Genkit { + // Returns nil before genkit.Init() is called, does not panic + // Components must check if return value is nil + return rt.genkitRegistry +} + +// GetComponent retrieves a component instance by name +func (rt *Runtime) GetComponent(name string) (Component, error) { + v, ok := rt.Components.Load(name) + if !ok { + return nil, fmt.Errorf("component not found: %s", name) + } + + return v.(Component), nil +} + +func (rt *Runtime) RegisterComponent(comp Component) { + rt.Components.Store(comp.Name(), comp) +} + +func GetRuntime() *Runtime { + if gloRuntime == nil { + panic("Runtime not initialized, call Bootstrap() first") + } + return gloRuntime +} diff --git a/ai/runtime/test/runtime_test.go b/ai/runtime/test/runtime_test.go new file mode 100644 index 00000000..4c634a05 --- /dev/null +++ 
b/ai/runtime/test/runtime_test.go @@ -0,0 +1,252 @@ +package runtimetest + +import ( + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync" + "testing" + + appruntime "dubbo-admin-ai/runtime" + + "gopkg.in/yaml.v3" +) + +type stubComponent struct { + name string + calls *[]string + callsMu *sync.Mutex + validateErr error + initCalled *bool +} + +func (s *stubComponent) Name() string { return s.name } +func (s *stubComponent) Validate() error { + if s.calls != nil && s.callsMu != nil { + s.callsMu.Lock() + *s.calls = append(*s.calls, "validate:"+s.name) + s.callsMu.Unlock() + } + return s.validateErr +} +func (s *stubComponent) Init(*appruntime.Runtime) error { + if s.calls != nil && s.callsMu != nil { + s.callsMu.Lock() + *s.calls = append(*s.calls, "init:"+s.name) + s.callsMu.Unlock() + } + if s.initCalled != nil { + *s.initCalled = true + } + return nil +} +func (s *stubComponent) Start() error { return nil } +func (s *stubComponent) Stop() error { return nil } + +func schemaDirFromRepo(t *testing.T) string { + t.Helper() + _, file, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("failed to get caller") + } + root := filepath.Dir(filepath.Dir(filepath.Dir(file))) + return filepath.Join(root, "schema", "json") +} + +func repoRoot(t *testing.T) string { + t.Helper() + _, file, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("failed to get caller") + } + return filepath.Dir(filepath.Dir(filepath.Dir(file))) +} + +func writeRuntimeFixture(t *testing.T, dir, content string) { + t.Helper() + if err := os.WriteFile(filepath.Join(dir, "config.yaml"), []byte(content), 0o644); err != nil { + t.Fatalf("write config.yaml: %v", err) + } +} + +func writeComponentFile(t *testing.T, dir, name, content string) { + t.Helper() + if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644); err != nil { + t.Fatalf("write %s: %v", name, err) + } +} + +func TestRuntime_RegisterFactory(t *testing.T) { + t.Run("duplicate", func(t *testing.T) { + 
rt := appruntime.NewRuntime() + + origStdout := os.Stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("os.Pipe(): %v", err) + } + os.Stdout = w + + f1 := func(*yaml.Node) (appruntime.Component, error) { return &stubComponent{name: "c1"}, nil } + f2 := func(*yaml.Node) (appruntime.Component, error) { return &stubComponent{name: "c2"}, nil } + rt.RegisterFactory("dup", f1) + rt.RegisterFactory("dup", f2) + + _ = w.Close() + os.Stdout = origStdout + out, _ := io.ReadAll(r) + _ = r.Close() + + gotFactory, err := rt.GetFactoryFn("dup") + if err != nil { + t.Fatalf("GetFactoryFn() error: %v", err) + } + comp, err := gotFactory(&yaml.Node{}) + if err != nil { + t.Fatalf("factory error: %v", err) + } + if comp.Name() != "c2" { + t.Fatalf("expected second registration to overwrite, got %s", comp.Name()) + } + if !strings.Contains(string(out), "already registered") { + t.Fatalf("expected duplicate warning, got %q", string(out)) + } + }) + + t.Run("concurrent", func(t *testing.T) { + rt := appruntime.NewRuntime() + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + typeName := fmt.Sprintf("t-%d", i) + rt.RegisterFactory(typeName, func(*yaml.Node) (appruntime.Component, error) { + return &stubComponent{name: fmt.Sprintf("c-%d", i)}, nil + }) + }() + } + wg.Wait() + + for i := 0; i < 100; i++ { + if _, err := rt.GetFactoryFn(fmt.Sprintf("t-%d", i)); err != nil { + t.Fatalf("missing registered factory t-%d: %v", i, err) + } + } + }) +} + +func TestRuntime_Get(t *testing.T) { + tests := []struct { + name string + runFn func(*appruntime.Runtime) error + errLike string + }{ + {name: "factory_not_found", runFn: func(rt *appruntime.Runtime) error { _, err := rt.GetFactoryFn("test"); return err }, errLike: "not registered"}, + {name: "component_not_found", runFn: func(rt *appruntime.Runtime) error { _, err := rt.GetComponent("agent"); return err }, errLike: "component not found"}, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + rt := appruntime.NewRuntime() + err := tt.runFn(rt) + if err == nil || !strings.Contains(err.Error(), tt.errLike) { + t.Fatalf("expected %q error, got %v", tt.errLike, err) + } + }) + } +} + +func TestRuntime_ComponentInitOrder(t *testing.T) { + dir := t.TempDir() + t.Setenv("SCHEMA_DIR", schemaDirFromRepo(t)) + writeRuntimeFixture(t, dir, "project: p\nversion: v\ncomponents:\n logger: logger.yaml\n memory: memory.yaml\n") + writeComponentFile(t, dir, "logger.yaml", "type: logger\nspec: {}\n") + writeComponentFile(t, dir, "memory.yaml", "type: memory\nspec: {}\n") + + calls := make([]string, 0) + var mu sync.Mutex + _, err := appruntime.Bootstrap(filepath.Join(dir, "config.yaml"), func(rt *appruntime.Runtime) { + rt.RegisterFactory("logger", func(*yaml.Node) (appruntime.Component, error) { + return &stubComponent{name: "logger", calls: &calls, callsMu: &mu}, nil + }) + rt.RegisterFactory("memory", func(*yaml.Node) (appruntime.Component, error) { + return &stubComponent{name: "memory", calls: &calls, callsMu: &mu}, nil + }) + }) + if err != nil { + t.Fatalf("Bootstrap() error: %v", err) + } + + got := strings.Join(calls, ",") + want := "validate:logger,init:logger,validate:memory,init:memory" + if got != want { + t.Fatalf("call order = %q, want %q", got, want) + } +} + +func TestBootstrap(t *testing.T) { + t.Run("validate_fail_stops_init", func(t *testing.T) { + dir := t.TempDir() + t.Setenv("SCHEMA_DIR", schemaDirFromRepo(t)) + writeRuntimeFixture(t, dir, "project: p\nversion: v\ncomponents:\n logger: logger.yaml\n server: server.yaml\n") + writeComponentFile(t, dir, "logger.yaml", "type: logger\nspec: {}\n") + writeComponentFile(t, dir, "server.yaml", "type: server\nspec: {}\n") + + serverInitCalled := false + _, err := appruntime.Bootstrap(filepath.Join(dir, "config.yaml"), func(rt *appruntime.Runtime) { + rt.RegisterFactory("logger", func(*yaml.Node) (appruntime.Component, error) { + return &stubComponent{name: "logger", 
validateErr: fmt.Errorf("boom")}, nil + }) + rt.RegisterFactory("server", func(*yaml.Node) (appruntime.Component, error) { + return &stubComponent{name: "server", initCalled: &serverInitCalled}, nil + }) + }) + + if err == nil || !strings.Contains(err.Error(), "failed to validate logger") { + t.Fatalf("expected validate fail error, got %v", err) + } + if serverInitCalled { + t.Fatalf("server should not be initialized after validate failure") + } + }) + + t.Run("missing_factory_for_configured_type", func(t *testing.T) { + dir := t.TempDir() + t.Setenv("SCHEMA_DIR", schemaDirFromRepo(t)) + writeRuntimeFixture(t, dir, "project: p\nversion: v\ncomponents:\n logger: logger.yaml\n") + writeComponentFile(t, dir, "logger.yaml", "type: logger\nspec: {}\n") + + _, err := appruntime.Bootstrap(filepath.Join(dir, "config.yaml"), func(rt *appruntime.Runtime) {}) + if err == nil || !strings.Contains(err.Error(), "no factory for") { + t.Fatalf("expected no factory error, got %v", err) + } + }) +} + +func TestRuntime_GetRuntime(t *testing.T) { + dir := t.TempDir() + src := `package main +import "dubbo-admin-ai/runtime" +func main() { _ = runtime.GetRuntime() }` + mainFile := filepath.Join(dir, "main.go") + if err := os.WriteFile(mainFile, []byte(src), 0o644); err != nil { + t.Fatalf("write main.go: %v", err) + } + + cmd := exec.Command("go", "run", mainFile) + cmd.Dir = repoRoot(t) + cmd.Env = append(os.Environ(), "GOCACHE="+filepath.Join(dir, "gocache")) + out, err := cmd.CombinedOutput() + if err == nil { + t.Fatalf("expected go run to fail with panic, got success") + } + if !strings.Contains(string(out), "Runtime not initialized") { + t.Fatalf("expected panic output, got: %s", string(out)) + } +} diff --git a/ai/schema/json/README.md b/ai/schema/json/README.md new file mode 100644 index 00000000..373ff7fc --- /dev/null +++ b/ai/schema/json/README.md @@ -0,0 +1,19 @@ +# JSON Schema Index + +This directory contains JSON Schema definitions for YAML configs. 
+ +- `main.schema.json`: root [`config.yaml`](../../config.yaml) +- `logger.schema.json`: [`component/logger/logger.yaml`](../../component/logger/logger.yaml) +- `memory.schema.json`: [`component/memory/memory.yaml`](../../component/memory/memory.yaml) +- `models.schema.json`: [`component/models/models.yaml`](../../component/models/models.yaml) +- `tools.schema.json`: [`component/tools/tools.yaml`](../../component/tools/tools.yaml) +- `server.schema.json`: [`component/server/server.yaml`](../../component/server/server.yaml) +- `rag.schema.json`: [`component/rag/rag.yaml`](../../component/rag/rag.yaml) +- `agent.schema.json`: [`component/agent/agent.yaml`](../../component/agent/agent.yaml) + +Notes: +- Schema draft: `2020-12` +- `additionalProperties: false` is used to enforce unknown-field errors at the structural layer. +- Loader is the only structural layer (`yaml.Unmarshal` -> schema defaults+validation -> strict decode with KnownFields). +- Defaults are declared in schema and injected only by Loader/schema engine. +- Required-field policy is documented in [`REQUIRED_FIELDS.md`](REQUIRED_FIELDS.md). diff --git a/ai/schema/json/REQUIRED_FIELDS.md b/ai/schema/json/REQUIRED_FIELDS.md new file mode 100644 index 00000000..c0d862a9 --- /dev/null +++ b/ai/schema/json/REQUIRED_FIELDS.md @@ -0,0 +1,119 @@ +# Config Required Fields Matrix + +This document defines required-field policy for configuration schemas. + +## Main Config (`config.yaml`) + +| Field | Required | Notes | +|---|---|---| +| `project` | yes | Non-empty string. 
| +| `version` | yes | Non-empty string. | +| `components` | yes | Object with at least one entry. | +| `components.<name>` | yes | Must be `string` or `array[string]`. | + +## Logger Component (`component/logger/logger.yaml`) + +| Field | Required | Notes | +|---|---|---| +| `type` | yes | Must be `logger`. | +| `spec` | yes | Object. | +| `spec.level` | no | Default: `info`. | + +## Memory Component (`component/memory/memory.yaml`) + +| Field | Required | Notes | +|---|---|---| +| `type` | yes | Must be `memory`. | +| `spec` | yes | Object. | +| `spec.history_key` | no | Default: `chat_history`. | +| `spec.max_turns` | no | Default: `100`, must be `>= 1` when set. | + +## Models Component (`component/models/models.yaml`) + +| Field | Required | Notes | +|---|---|---| +| `type` | yes | Must be `models`. | +| `spec` | yes | Object. | +| `spec.default_model` | yes | Non-empty string. | +| `spec.default_embedding` | yes | Non-empty string. | +| `spec.providers` | yes | Object with at least one provider. | +| `spec.providers.<provider>.base_url` | yes | Non-empty string. | +| `spec.providers.<provider>.api_key` | no | May be empty in some environments. | +| `spec.providers.<provider>.models[]` | no | Defaults to empty array. | +| `spec.providers.<provider>.embedders[]` | no | Defaults to empty array. | +| `spec.providers.<provider>.models[].name` | yes | Non-empty string. | +| `spec.providers.<provider>.models[].key` | yes | Non-empty string. | +| `spec.providers.<provider>.models[].type` | no | Default: `chat`. | +| `spec.providers.<provider>.embedders[].name` | yes | Non-empty string. | +| `spec.providers.<provider>.embedders[].key` | yes | Non-empty string. | +| `spec.providers.<provider>.embedders[].dimensions` | yes | Integer `>= 1`. | +| `spec.providers.<provider>.embedders[].type` | no | Default: `text`. | + +## Tools Component (`component/tools/tools.yaml`) + +| Field | Required | Notes | +|---|---|---| +| `type` | yes | Must be `tools`. | +| `spec` | yes | Object. | +| `spec.enable_mock_tools` | no | Default: `true`. 
| +| `spec.enable_internal_tools` | no | Default: `true`. | +| `spec.enable_mcp_tools` | no | Default: `true`. | +| `spec.mcp_host_name` | conditional | Required when `spec.enable_mcp_tools=true`. | +| `spec.mcp_timeout` | no | Default: `30`, integer `>= 1`. | +| `spec.mcp_max_retries` | no | Default: `3`, integer `>= 0`. | + +## Server Component (`component/server/server.yaml`) + +| Field | Required | Notes | +|---|---|---| +| `type` | yes | Must be `server`. | +| `spec` | yes | Object. | +| `spec.port` | no | Default: `8888`, range `1..65535`. | +| `spec.host` | no | Default: `0.0.0.0`. | +| `spec.debug` | no | Default: `false`. | +| `spec.cors_origins` | no | Default: `[*]`. | +| `spec.read_timeout` | no | Default: `30`, integer `>= 1`. | +| `spec.write_timeout` | no | Default: `30`, integer `>= 1`. | + +## RAG Component (`component/rag/rag.yaml`) + +| Field | Required | Notes | +|---|---|---| +| `type` | yes | Must be `rag`. | +| `spec` | yes | Object. | +| `spec.embedder` | yes | Object. | +| `spec.loader` | yes | Object. | +| `spec.splitter` | yes | Object. | +| `spec.indexer` | yes | Object. | +| `spec.retriever` | yes | Object. | +| `spec.reranker` | no | Optional object. | +| `spec.embedder.spec` | yes | Object. | +| `spec.embedder.spec.model` | yes | Non-empty string. | +| `spec.loader.spec` | yes | Object (may be empty). | +| `spec.splitter.spec` | yes | Object. | +| `spec.indexer.spec` | yes | Object. | +| `spec.retriever.spec` | yes | Object. | +| `spec.reranker.spec` | conditional | Required when `spec.reranker` exists. | +| `spec.reranker.spec.api_key` | conditional | Required when `spec.reranker.spec.enabled=true`. | + +## Agent Component (`component/agent/agent.yaml`) + +| Field | Required | Notes | +|---|---|---| +| `type` | yes | Must be `agent`. | +| `spec` | yes | Object. | +| `spec.default_model` | yes | Non-empty string. | +| `spec.prompt_base_path` | yes | Non-empty string. | +| `spec.stages` | yes | Array with at least one stage. 
| +| `spec.agent_type` | no | Default: `react` (current supported value). | +| `spec.max_iterations` | no | Default: `10`, integer `>= 1`. | +| `spec.stage_channel_buffer_size` | no | Default: `5`, integer `>= 1`. | +| `spec.mcp_host_name` | no | Default: `mcp_host`. | +| `spec.stages[].name` | yes | Non-empty string. | +| `spec.stages[].flow_type` | yes | Enum: `think|act|observe|feedback`. | +| `spec.stages[].prompt_file` | yes | Non-empty string. | +| `spec.stages[].temperature` | no | Default: `0.7`, `(0,2]`. | +| `spec.stages[].top_p` | no | Default: `0.9`, `(0,1]`. | +| `spec.stages[].max_tokens` | no | Default: `4096`, integer `>= 1`. | +| `spec.stages[].timeout` | no | Default: `30`, integer `>= 1`. | +| `spec.stages[].enable_tools` | no | Default: `false`. | diff --git a/ai/schema/json/agent.schema.json b/ai/schema/json/agent.schema.json new file mode 100644 index 00000000..9d55492f --- /dev/null +++ b/ai/schema/json/agent.schema.json @@ -0,0 +1,111 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://dubbo-admin.ai/schema/agent.schema.json", + "title": "Agent Component Config", + "type": "object", + "additionalProperties": false, + "required": ["type", "spec"], + "properties": { + "type": { + "const": "agent" + }, + "spec": { + "type": "object", + "additionalProperties": false, + "required": ["default_model", "prompt_base_path", "stages"], + "properties": { + "agent_type": { + "type": "string", + "default": "react", + "enum": ["react"] + }, + "default_model": { + "type": "string", + "minLength": 1, + "default": "qwen-max" + }, + "prompt_base_path": { + "type": "string", + "minLength": 1, + "default": "./prompts" + }, + "max_iterations": { + "type": "integer", + "minimum": 1, + "default": 10 + }, + "stage_channel_buffer_size": { + "type": "integer", + "minimum": 1, + "default": 5 + }, + "mcp_host_name": { + "type": "string", + "minLength": 1, + "default": "mcp_host" + }, + "stages": { + "type": "array", + "minItems": 1, + 
"items": { + "$ref": "#/$defs/stage" + } + } + } + } + }, + "$defs": { + "stage": { + "type": "object", + "additionalProperties": false, + "required": ["name", "flow_type", "prompt_file"], + "properties": { + "name": { + "type": "string", + "minLength": 1 + }, + "flow_type": { + "type": "string", + "enum": ["think", "act", "observe", "feedback"], + "default": "think" + }, + "model": { + "type": "string" + }, + "prompt_file": { + "type": "string", + "minLength": 1 + }, + "temperature": { + "type": "number", + "exclusiveMinimum": 0, + "maximum": 2, + "default": 0.7 + }, + "top_p": { + "type": "number", + "exclusiveMinimum": 0, + "maximum": 1, + "default": 0.9 + }, + "max_tokens": { + "type": "integer", + "minimum": 1, + "default": 4096 + }, + "timeout": { + "type": "integer", + "minimum": 1, + "default": 30 + }, + "enable_tools": { + "type": "boolean", + "default": false + }, + "extra_prompt": { + "type": "string" + } + } + } + } +} diff --git a/ai/schema/json/logger.schema.json b/ai/schema/json/logger.schema.json new file mode 100644 index 00000000..b2d7bf36 --- /dev/null +++ b/ai/schema/json/logger.schema.json @@ -0,0 +1,24 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://dubbo-admin.ai/schema/logger.schema.json", + "title": "Logger Component Config", + "type": "object", + "additionalProperties": false, + "required": ["type", "spec"], + "properties": { + "type": { + "const": "logger" + }, + "spec": { + "type": "object", + "additionalProperties": false, + "properties": { + "level": { + "type": "string", + "enum": ["debug", "info", "warn", "error"], + "default": "info" + } + } + } + } +} diff --git a/ai/schema/json/main.schema.json b/ai/schema/json/main.schema.json new file mode 100644 index 00000000..8411e1ce --- /dev/null +++ b/ai/schema/json/main.schema.json @@ -0,0 +1,37 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://dubbo-admin.ai/schema/main.schema.json", + "title": "Dubbo Admin AI Main 
Config", + "type": "object", + "additionalProperties": false, + "required": ["project", "version", "components"], + "properties": { + "project": { + "type": "string", + "minLength": 1 + }, + "version": { + "type": "string", + "minLength": 1 + }, + "components": { + "type": "object", + "minProperties": 1, + "additionalProperties": { + "oneOf": [ + { + "type": "string", + "minLength": 1 + }, + { + "type": "array", + "items": { + "type": "string", + "minLength": 1 + } + } + ] + } + } + } +} diff --git a/ai/schema/json/memory.schema.json b/ai/schema/json/memory.schema.json new file mode 100644 index 00000000..aef88f7e --- /dev/null +++ b/ai/schema/json/memory.schema.json @@ -0,0 +1,29 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://dubbo-admin.ai/schema/memory.schema.json", + "title": "Memory Component Config", + "type": "object", + "additionalProperties": false, + "required": ["type", "spec"], + "properties": { + "type": { + "const": "memory" + }, + "spec": { + "type": "object", + "additionalProperties": false, + "properties": { + "history_key": { + "type": "string", + "enum": ["chat_history", "system_memory", "core_memory"], + "default": "chat_history" + }, + "max_turns": { + "type": "integer", + "minimum": 1, + "default": 100 + } + } + } + } +} diff --git a/ai/schema/json/models.schema.json b/ai/schema/json/models.schema.json new file mode 100644 index 00000000..8f507a0e --- /dev/null +++ b/ai/schema/json/models.schema.json @@ -0,0 +1,127 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://dubbo-admin.ai/schema/models.schema.json", + "title": "Models Component Config", + "type": "object", + "additionalProperties": false, + "required": ["type", "spec"], + "properties": { + "type": { + "const": "models" + }, + "spec": { + "type": "object", + "additionalProperties": false, + "required": ["default_model", "default_embedding", "providers"], + "properties": { + "default_model": { + "type": "string", + 
"minLength": 1, + "default": "dashscope/qwen-max" + }, + "default_embedding": { + "type": "string", + "minLength": 1, + "default": "dashscope/qwen3-embedding" + }, + "providers": { + "type": "object", + "minProperties": 1, + "additionalProperties": { + "$ref": "#/$defs/provider" + } + } + } + } + }, + "$defs": { + "provider": { + "type": "object", + "additionalProperties": false, + "required": ["base_url"], + "properties": { + "api_key": { + "type": "string", + "default": "" + }, + "base_url": { + "type": "string", + "minLength": 1 + }, + "models": { + "type": "array", + "default": [], + "items": { + "$ref": "#/$defs/model_info" + } + }, + "embedders": { + "type": "array", + "default": [], + "items": { + "$ref": "#/$defs/embedder_info" + } + }, + "config": { + "type": "object", + "default": {}, + "additionalProperties": true + } + } + }, + "model_info": { + "type": "object", + "additionalProperties": false, + "required": ["name", "key"], + "properties": { + "name": { + "type": "string", + "minLength": 1 + }, + "key": { + "type": "string", + "minLength": 1 + }, + "type": { + "type": "string", + "enum": ["chat", "multimodal", "code"], + "default": "chat" + }, + "config": { + "type": "object", + "default": {}, + "additionalProperties": true + } + } + }, + "embedder_info": { + "type": "object", + "additionalProperties": false, + "required": ["name", "key", "dimensions"], + "properties": { + "name": { + "type": "string", + "minLength": 1 + }, + "key": { + "type": "string", + "minLength": 1 + }, + "type": { + "type": "string", + "enum": ["text", "image", "audio", "multimodal"], + "default": "text" + }, + "dimensions": { + "type": "integer", + "minimum": 1 + }, + "config": { + "type": "object", + "default": {}, + "additionalProperties": true + } + } + } + } +} diff --git a/ai/schema/json/rag.schema.json b/ai/schema/json/rag.schema.json new file mode 100644 index 00000000..402c4b37 --- /dev/null +++ b/ai/schema/json/rag.schema.json @@ -0,0 +1,259 @@ +{ + "$schema": 
"https://json-schema.org/draft/2020-12/schema", + "$id": "https://dubbo-admin.ai/schema/rag.schema.json", + "title": "RAG Component Config", + "type": "object", + "additionalProperties": false, + "required": ["type", "spec"], + "properties": { + "type": { + "const": "rag" + }, + "spec": { + "type": "object", + "additionalProperties": false, + "required": ["embedder", "loader", "splitter", "indexer", "retriever"], + "properties": { + "embedder": { + "$ref": "#/$defs/embedder" + }, + "loader": { + "$ref": "#/$defs/loader" + }, + "splitter": { + "$ref": "#/$defs/splitter" + }, + "indexer": { + "$ref": "#/$defs/indexer" + }, + "retriever": { + "$ref": "#/$defs/retriever" + }, + "reranker": { + "$ref": "#/$defs/reranker" + } + } + } + }, + "$defs": { + "embedder": { + "type": "object", + "additionalProperties": false, + "required": ["spec"], + "properties": { + "type": { + "type": "string", + "default": "genkit" + }, + "spec": { + "type": "object", + "additionalProperties": false, + "required": ["model"], + "properties": { + "model": { + "type": "string", + "minLength": 1, + "default": "dashscope/qwen3-embedding" + } + } + } + } + }, + "loader": { + "type": "object", + "additionalProperties": false, + "required": ["spec"], + "properties": { + "type": { + "type": "string", + "default": "local", + "enum": ["local"] + }, + "spec": { + "type": "object", + "additionalProperties": false + } + } + }, + "splitter": { + "type": "object", + "additionalProperties": false, + "required": ["spec"], + "properties": { + "type": { + "type": "string", + "default": "recursive", + "enum": ["recursive", "markdown_header"] + }, + "spec": { + "oneOf": [ + { + "$ref": "#/$defs/splitter_recursive_spec" + }, + { + "$ref": "#/$defs/splitter_markdown_spec" + } + ] + } + } + }, + "splitter_recursive_spec": { + "type": "object", + "additionalProperties": false, + "properties": { + "chunk_size": { + "type": "integer", + "minimum": 1, + "default": 1000 + }, + "overlap_size": { + "type": "integer", + 
"minimum": 0, + "default": 100 + } + } + }, + "splitter_markdown_spec": { + "type": "object", + "additionalProperties": false, + "properties": { + "headers": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "default": { + "#": "h1", + "##": "h2", + "###": "h3", + "####": "h4" + } + }, + "trim_headers": { + "type": "boolean", + "default": false + } + } + }, + "indexer": { + "type": "object", + "additionalProperties": false, + "required": ["spec"], + "properties": { + "type": { + "type": "string", + "default": "dev", + "enum": ["dev", "pinecone"] + }, + "spec": { + "$ref": "#/$defs/index_storage_spec" + } + } + }, + "retriever": { + "type": "object", + "additionalProperties": false, + "required": ["spec"], + "properties": { + "type": { + "type": "string", + "default": "dev", + "enum": ["dev", "pinecone"] + }, + "spec": { + "$ref": "#/$defs/index_storage_spec" + } + } + }, + "index_storage_spec": { + "type": "object", + "additionalProperties": false, + "properties": { + "storage_path": { + "type": "string", + "default": "../../data/ai/index" + }, + "index_format": { + "type": "string", + "default": "sqlite" + }, + "dimension": { + "type": "integer", + "minimum": 1, + "default": 1536 + } + } + }, + "reranker": { + "type": "object", + "additionalProperties": false, + "required": ["spec"], + "properties": { + "type": { + "type": "string", + "default": "cohere", + "enum": ["cohere"] + }, + "spec": { + "type": "object", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": false + }, + "model": { + "type": "string", + "default": "rerank-english-v3.0" + }, + "api_key": { + "type": "string" + } + } + } + } + } + }, + "allOf": [ + { + "if": { + "properties": { + "spec": { + "properties": { + "reranker": { + "type": "object", + "required": ["spec"], + "properties": { + "spec": { + "type": "object", + "required": ["enabled"], + "properties": { + "enabled": { + "const": true + } + } + } + } + } + } + } + } 
+ }, + "then": { + "properties": { + "spec": { + "properties": { + "reranker": { + "properties": { + "spec": { + "required": ["api_key"] + } + } + } + } + } + } + } + } + ] +} diff --git a/ai/schema/json/server.schema.json b/ai/schema/json/server.schema.json new file mode 100644 index 00000000..de37ee45 --- /dev/null +++ b/ai/schema/json/server.schema.json @@ -0,0 +1,51 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://dubbo-admin.ai/schema/server.schema.json", + "title": "Server Component Config", + "type": "object", + "additionalProperties": false, + "required": ["type", "spec"], + "properties": { + "type": { + "const": "server" + }, + "spec": { + "type": "object", + "additionalProperties": false, + "properties": { + "port": { + "type": "integer", + "minimum": 1, + "maximum": 65535, + "default": 8888 + }, + "host": { + "type": "string", + "minLength": 1, + "default": "0.0.0.0" + }, + "debug": { + "type": "boolean", + "default": false + }, + "cors_origins": { + "type": "array", + "default": ["*"], + "items": { + "type": "string" + } + }, + "read_timeout": { + "type": "integer", + "minimum": 1, + "default": 30 + }, + "write_timeout": { + "type": "integer", + "minimum": 1, + "default": 30 + } + } + } + } +} diff --git a/ai/schema/json/tools.schema.json b/ai/schema/json/tools.schema.json new file mode 100644 index 00000000..1c3b70d2 --- /dev/null +++ b/ai/schema/json/tools.schema.json @@ -0,0 +1,70 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://dubbo-admin.ai/schema/tools.schema.json", + "title": "Tools Component Config", + "type": "object", + "additionalProperties": false, + "required": ["type", "spec"], + "properties": { + "type": { + "const": "tools" + }, + "spec": { + "type": "object", + "additionalProperties": false, + "properties": { + "enable_mock_tools": { + "type": "boolean", + "default": true + }, + "enable_internal_tools": { + "type": "boolean", + "default": true + }, + 
"enable_mcp_tools": { + "type": "boolean", + "default": true + }, + "mcp_host_name": { + "type": "string", + "minLength": 1, + "default": "mcp_host" + }, + "mcp_timeout": { + "type": "integer", + "minimum": 1, + "default": 30 + }, + "mcp_max_retries": { + "type": "integer", + "minimum": 0, + "default": 3 + } + } + } + }, + "allOf": [ + { + "if": { + "properties": { + "spec": { + "type": "object", + "required": ["enable_mcp_tools"], + "properties": { + "enable_mcp_tools": { + "const": true + } + } + } + } + }, + "then": { + "properties": { + "spec": { + "required": ["mcp_host_name"] + } + } + } + } + ] +} diff --git a/ai/schema/react.go b/ai/schema/react.go index 0e99eac0..b8de2cec 100644 --- a/ai/schema/react.go +++ b/ai/schema/react.go @@ -5,7 +5,7 @@ import ( "fmt" "reflect" - "dubbo-admin-ai/tools" + toolEngine "dubbo-admin-ai/component/tools/engine" "github.com/firebase/genkit/go/ai" ) @@ -22,11 +22,11 @@ type UserInput struct { } type ThinkInput struct { - UserInput *UserInput `json:"user_input,omitempty"` - SessionID string `json:"session_id"` - ToolResponses []tools.ToolOutput `json:"tool_responses,omitempty"` - Observation *Observation `json:"observation,omitempty"` - UsageInfo *ai.GenerationUsage `json:"usage,omitempty" jsonschema_description:"DO NOT USE THIS FIELD, IT IS FOR INTERNAL USAGE ONLY"` + UserInput *UserInput `json:"user_input,omitempty"` + SessionID string `json:"session_id"` + ToolResponses []toolEngine.ToolOutput `json:"tool_responses,omitempty"` + Observation *Observation `json:"observation,omitempty"` + UsageInfo *ai.GenerationUsage `json:"usage,omitempty" jsonschema_description:"DO NOT USE THIS FIELD, IT IS FOR INTERNAL USAGE ONLY"` } func (i ThinkInput) Validate(T reflect.Type) error { @@ -76,9 +76,9 @@ func (ta ThinkOutput) String() string { } type ToolOutputs struct { - Outputs []tools.ToolOutput `json:"tool_responses"` - Thought string `json:"thought,omitempty"` - UsageInfo *ai.GenerationUsage `json:"usage,omitempty" 
jsonschema_description:"DO NOT USE THIS FIELD, IT IS FOR INTERNAL USAGE ONLY"` + Outputs []toolEngine.ToolOutput `json:"tool_responses"` + Thought string `json:"thought,omitempty"` + UsageInfo *ai.GenerationUsage `json:"usage,omitempty" jsonschema_description:"DO NOT USE THIS FIELD, IT IS FOR INTERNAL USAGE ONLY"` } func (to ToolOutputs) Validate(T reflect.Type) error { @@ -92,7 +92,7 @@ func (to ToolOutputs) Usage() *ai.GenerationUsage { return to.UsageInfo } -func (to *ToolOutputs) Add(output *tools.ToolOutput) { +func (to *ToolOutputs) Add(output *toolEngine.ToolOutput) { to.Outputs = append(to.Outputs, *output) } diff --git a/ai/server/models.go b/ai/server/models.go deleted file mode 100644 index 5c52a1f2..00000000 --- a/ai/server/models.go +++ /dev/null @@ -1,46 +0,0 @@ -package server - -import ( - "time" - - "github.com/google/uuid" -) - -// Response 统一API响应格式 -type Response struct { - Message string `json:"message"` // 响应消息 - Data any `json:"data,omitempty"` // 响应数据 - RequestID string `json:"request_id"` // 请求ID,用于追踪 - Timestamp int64 `json:"timestamp"` // 响应时间戳 -} - -// NewResponse 创建响应 -func NewResponse(message string, data any) *Response { - return &Response{ - Message: message, - Data: data, - RequestID: generateRequestID(), - Timestamp: time.Now().Unix(), - } -} - -// NewSuccessResponse 创建成功响应 -func NewSuccessResponse(data any) *Response { - return NewResponse("success", data) -} - -// NewErrorResponse 创建错误响应 -func NewErrorResponse(message string) *Response { - return NewResponse(message, nil) -} - -// ChatRequest 流式聊天请求 -type ChatRequest struct { - Message string `json:"message" binding:"required"` // 用户消息 - SessionID string `json:"sessionID" binding:"required"` // 会话ID -} - -// generateRequestID 生成请求ID -func generateRequestID() string { - return "req_" + uuid.New().String() -} diff --git a/ai/test/llm_test.go b/ai/test/llm_test.go deleted file mode 100644 index 5ab872cb..00000000 --- a/ai/test/llm_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package test 
- -import ( - "context" - "fmt" - "log" - "testing" - - "dubbo-admin-ai/agent/react" - "dubbo-admin-ai/config" - "dubbo-admin-ai/manager" - "dubbo-admin-ai/plugins/dashscope" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/core" - "github.com/firebase/genkit/go/genkit" -) - -type WeatherInput struct { - Location string `json:"location" jsonschema_description:"Location to get weather for"` -} - -func defineWeatherFlow(g *genkit.Genkit) *core.Flow[WeatherInput, string, struct{}] { - getWeatherTool := genkit.DefineTool(g, "getWeather", "Gets the current weather in a given location", - func(ctx *ai.ToolContext, input WeatherInput) (string, error) { - // Here, we would typically make an API call or database query. For this - // example, we just return a fixed value. - log.Printf("Tool 'getWeather' called for location: %s", input.Location) - return fmt.Sprintf("The current weather in %s is 63°F and sunny.", input.Location), nil - }) - - return genkit.DefineFlow(g, "getWeatherFlow", - func(ctx context.Context, location WeatherInput) (string, error) { - resp, err := genkit.Generate(ctx, g, - ai.WithTools(getWeatherTool), - ai.WithPrompt("What's the weather in %s?", location.Location), - ) - if err != nil { - return "", err - } - return resp.Text(), nil - }) -} - -func TestTextGeneration(t *testing.T) { - g := manager.Registry(dashscope.Qwen3.Key(), config.PROJECT_ROOT+"/.env", nil) - _, _ = react.Create(g) - ctx := context.Background() - - resp, err := genkit.GenerateText(ctx, g, ai.WithPrompt("Hello, Who are you?")) - if err != nil { - t.Fatalf("failed to generate text: %v", err) - } - t.Logf("Generated text: %s", resp) - - fmt.Printf("%s", resp) -} - -func TestWeatherFlowRun(t *testing.T) { - g := manager.Registry(dashscope.Qwen3.Key(), config.PROJECT_ROOT+"/.env", nil) - _, _ = react.Create(g) - ctx := context.Background() - - flow := defineWeatherFlow(g) - flow.Run(ctx, WeatherInput{Location: "San Francisco"}) -} \ No newline at end of file diff 
--git a/ai/test/mcp_test.go b/ai/test/mcp_test.go deleted file mode 100644 index 2dfdb1f6..00000000 --- a/ai/test/mcp_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package test - -import ( - "context" - "fmt" - "os" - "testing" - - "dubbo-admin-ai/config" - "dubbo-admin-ai/manager" - "dubbo-admin-ai/plugins/dashscope" - "dubbo-admin-ai/tools" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" -) - -func TestMCP(t *testing.T) { - ctx := context.Background() - g := manager.Registry(dashscope.Qwen_max.Key(), config.PROJECT_ROOT+"/.env", manager.DevLogger()) - - mcpToolManager, err := tools.NewMCPToolManager(g, "mcpHost") - if err != nil { - t.Fatalf("failed to create MCP tool manager: %v", err) - } - - toolRefs := mcpToolManager.ToolRefs() - prompt, err := os.ReadFile(config.PROMPT_DIR_PATH + "/agentTool.txt") - if err != nil { - t.Fatalf("failed to read prompt file: %v", err) - } - - resp, err := genkit.Generate(ctx, g, - ai.WithSystem(string(prompt)), - ai.WithPrompt("What are the existing namespaces?"), - ai.WithTools(toolRefs...), - ai.WithOutputType(tools.ToolOutput{}), - ) - - if err != nil { - t.Fatalf("failed to generate text: %v", err) - } - - manager.GetLogger().Info("Generated response:", "text", resp.Text()) -} - -func TestMCPFlow(t *testing.T) { - g := manager.Registry(dashscope.Qwen3.Key(), config.PROJECT_ROOT+"/.env", manager.DevLogger()) - flow := genkit.DefineFlow(g, "mcpTest", - func(ctx context.Context, userPrompt string) (string, error) { - mcpToolManager, err := tools.NewMCPToolManager(g, "mcpHost") - if err != nil { - return "", fmt.Errorf("failed to create MCP tool manager: %v", err) - } - - toolRefs := mcpToolManager.ToolRefs() - prompt, err := os.ReadFile(config.PROMPT_DIR_PATH + "/agentSystem.txt") - if err != nil { - return "", fmt.Errorf("failed to read prompt file: %v", err) - } - - resp, err := genkit.Generate(ctx, g, - ai.WithSystem(string(prompt)), - ai.WithPrompt(userPrompt), - ai.WithTools(toolRefs...), - 
 ai.WithReturnToolRequests(true), + ) + + if err != nil { + return "", fmt.Errorf("failed to generate text: %v", err) + } + + return resp.Text(), nil + }) + + resp, err := flow.Run(context.Background(), "List all namespaces in the Kubernetes cluster") + if err != nil { + t.Fatalf("failed to run MCP flow: %v", err) + } + manager.GetLogger().Info("MCP Flow response:", "response", resp) -} diff --git a/ai/test/models.md b/ai/test/models.md new file mode 100644 index 00000000..001b3b81 --- /dev/null +++ b/ai/test/models.md @@ -0,0 +1,147 @@ +# 模型可用性测试指南 + +本指南说明如何测试 `config/models.yaml` 中配置的所有模型是否可用。 + +## 测试方法 + +### 方法1:使用现有测试(推荐) + +AI模块已经有一个简单的文本生成测试,可以快速验证默认模型: + +```bash +cd <repo-root>/ai + +# 设置API密钥 +export DASHSCOPE_API_KEY="your_qwen_api_key" + +# 运行测试 +go test -v ./test/ -run TestTextGeneration +``` + +### 方法2:手动测试单个模型 + +使用简单的Go程序测试特定模型: + +```go +package main + +import ( + "context" + "fmt" + "log" + "os" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/compat_oai" + "github.com/openai/openai-go/option" +) + +func main() { + ctx := context.Background() + + // 初始化 genkit + g := genkit.Init(ctx) + + // 注册模型提供商(以Dashscope为例) + _ = compat_oai.Init(g, "dashscope", + compat_oai.WithConfig(option.WithBaseURL("https://dashscope.aliyuncs.com/compatible-mode/v1")), + compat_oai.WithAPIKey(os.Getenv("DASHSCOPE_API_KEY")), + ) + + // 测试模型 + resp, err := genkit.GenerateText(ctx, g, + ai.WithPrompt("Hello, who are you?"), + ) + if err != nil { + log.Fatalf("Failed to generate: %v", err) + } + + fmt.Printf("Response: %s\n", resp) +} +``` + +### 方法3:使用服务器API测试 + +启动AI服务器后,通过HTTP API测试: + +```bash +# 1. 启动服务器 +cd <repo-root>/ai +go run main.go --config config.yaml + +# 2. 
在另一个终端测试 +curl -X POST http://localhost:8888/api/v1/ai/chat/stream \ + -H "Content-Type: application/json" \ + -d '{ + "sessionId": "test-session", + "message": "Hello, who are you?", + "stream": false + }' +``` + +## 配置的模型列表 + +根据 `config/models.yaml`,当前配置了以下模型: + +### Dashscope(通义千问) +- `qwen-max` - 聊天模型(默认) +- `qwen-plus` - 聊天模型 +- `qwen-flash` - 快速聊天模型 +- `qwen3-coder` - 代码生成模型 +- `qwen3-embedding` - 文本嵌入模型 + +### Gemini(Google) +- `gemini-pro` - 聊天模型 +- `gemini-pro-vision` - 多模态模型 +- `text-embedding-004` - 文本嵌入模型 + +### SiliconFlow +- `gpt-3.5-turbo` - 聊天模型 +- `gpt-4` - 聊天模型 +- `text-embedding-ada-002` - 文本嵌入模型 + +## 测试所有模型 + +要测试所有配置的模型,可以: + +1. **设置所有API密钥**: +```bash +export DASHSCOPE_API_KEY="your_qwen_api_key" +export GEMINI_API_KEY="your_gemini_api_key" +export SILICONFLOW_API_KEY="your_siliconflow_key" +``` + +2. **修改测试代码**,遍历所有模型 + +3. **或者使用脚本**: +```bash +for provider in dashscope gemini siliconflow; do + for model in qwen-max gemini-pro gpt-3.5-turbo; do + echo "Testing $provider/$model..." + # 调用API测试 + done +done +``` + +## 常见问题 + +### Q: 模型调用失败怎么办? +A: 检查: +1. API密钥是否正确设置 +2. 网络连接是否正常 +3. API是否有效(额度和限制) +4. 模型名称是否正确 + +### Q: 如何添加新模型? +A: 编辑 `config/models.yaml`,按照现有格式添加。 + +### Q: 测试太慢怎么办? 
+A: 使用 `go test -short` 跳过耗时测试,或者只测试默认模型。 + +## 相关文件 + +- `config/models.yaml` - 模型配置文件 +- `ai/test/llm_test.go` - 简单模型测试 +- `ai/main.go` - 服务器入口 +- `ai/component/models/component.go` - Models组件实现 diff --git a/ai/test/rag_test.go b/ai/test/rag_test.go deleted file mode 100644 index 2b1dcbca..00000000 --- a/ai/test/rag_test.go +++ /dev/null @@ -1,141 +0,0 @@ -package test - -import ( - "context" - "fmt" - "testing" - - "dubbo-admin-ai/config" - "dubbo-admin-ai/manager" - "dubbo-admin-ai/plugins/dashscope" - "dubbo-admin-ai/utils" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/core" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/pinecone" -) - -func pdf2Docs(pdfPath string, chunkSize, chunkOverlap int) ([]*ai.Document, error) { - chunks, err := utils.SplitPDFWithClean(pdfPath, chunkSize, chunkOverlap) - if err != nil { - return nil, fmt.Errorf("Failed to split PDF into chunks: %w", err) - } - - metadata := map[string]any{ - "description": "A collection of classic cocktail recipes.", - } - - docs := make([]*ai.Document, len(chunks)) - for i, chunk := range chunks { - docs[i] = ai.DocumentFromText(chunk, metadata) - } - - return docs, nil -} - -func docs2Index(g *genkit.Genkit, docs []*ai.Document, indexName, namespace string) error { - ctx := context.Background() - embedder := genkit.LookupEmbedder(g, dashscope.Qwen3_embedding.Key()) - if embedder == nil { - return fmt.Errorf("failed to find embedder %s", dashscope.Qwen3_embedding.Key()) - } - docstore, _, err := pinecone.DefineRetriever(ctx, g, - pinecone.Config{ - IndexID: indexName, - Embedder: embedder, - }, - &ai.RetrieverOptions{ - Label: "cocktail-retriever", - ConfigSchema: core.InferSchemaMap(pinecone.PineconeRetrieverOptions{}), - }) - - if err != nil { - return fmt.Errorf("failed to setup retriever: %w", err) - } - - // 分批索引文档,每批最多10个 - batchSize := 10 - for i := 0; i < len(docs); i += batchSize { - end := min(i+batchSize, len(docs)) - batch := 
docs[i:end] - manager.GetLogger().Info("正在索引文档", "start", i+1, "end", end, "total", len(docs)) - if err := pinecone.Index(ctx, batch, docstore, namespace); err != nil { - return fmt.Errorf("failed to index documents batch %d-%d: %w", i+1, end, err) - } - manager.GetLogger().Info("成功索引文档", "count", len(batch)) - } - - return nil -} - -func TestChunks(t *testing.T) { - pdfPath := config.PROJECT_ROOT + "/reference/Classic-Cocktails.pdf" - chunks, err := utils.SplitPDFWithClean(pdfPath, 50, 10) - if err != nil { - t.Fatal("Failed to split PDF into chunks:", err) - } - manager.ProductionLogger().Info("成功分割 PDF", "length", len(chunks), "chunks", chunks) -} - -// TestCreateIndex - 创建索引 -func TestCreateIndex(t *testing.T) { - namespace := "cocktails" - pdfPath := config.PROJECT_ROOT + "/reference/Classic-Cocktails.pdf" - docs, err := pdf2Docs(pdfPath, 100, 20) - if err != nil { - t.Fatal("Failed to convert PDF to documents:", err) - } - - g := manager.Registry(config.DEFAULT_MODEL.Key(), config.PROJECT_ROOT+"/.env", manager.ProductionLogger()) - err = docs2Index(g, docs, config.PINECONE_INDEX_NAME, namespace) - if err != nil { - t.Fatal("Failed to create index:", err) - } -} - -// func TestReRank(t *testing.T) { -// queries := []string{ -// "请给我玛格丽特鸡尾酒的配方", -// "Please give me the recipe of Margarita cocktail", -// "What are the ingredients of whiskey sour?", -// "How to make a vodka martini?", -// "What are tropical fruit cocktails?", -// } - -// } - -// TestSearch - 搜索测试(可以运行多次) -func TestSearch(t *testing.T) { - g := manager.Registry(config.DEFAULT_MODEL.Key(), config.PROJECT_ROOT+"/.env", manager.ProductionLogger()) - indexName := config.PINECONE_INDEX_NAME - queries := []string{ - "请给我玛格丽特鸡尾酒的配方", - "Please give me the recipe of Margarita cocktail", - "What are the ingredients of whiskey sour?", - "How to make a vodka martini?", - "What are tropical fruit cocktails?", - } - results, err := utils.RetrieveFromPinecone(g, dashscope.Qwen3_embedding.Key(), indexName, 
"cocktails", queries, 10, true, 5) - if err != nil { - t.Fatalf("search in pinecone failed: %v", err) - } - - for query, docs := range results { - manager.GetLogger().Info("搜索结果", "query", query, "result", docs) - } - -} - -func TestRerank(t *testing.T) { - g := manager.Registry(config.DEFAULT_MODEL.Key(), config.PROJECT_ROOT+"/.env", manager.ProductionLogger()) - indexName := config.PINECONE_INDEX_NAME - queries := []string{ - "What are the ingredients of whiskey sour?", - } - results, err := utils.RetrieveFromPinecone(g, dashscope.Qwen3_embedding.Key(), indexName, "cocktails", queries, 10, true, 5) - if err != nil { - t.Fatalf("search in pinecone failed: %v", err) - } - manager.GetLogger().Info("重排序结果", "result", results) -} diff --git a/ai/testutils/fixtures.go b/ai/testutils/fixtures.go new file mode 100644 index 00000000..dc3c3f46 --- /dev/null +++ b/ai/testutils/fixtures.go @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package testutils + +import ( + "gopkg.in/yaml.v3" +) + +// ConfigFixture 提供各组件的测试配置 +type ConfigFixture struct{} + +// ValidLoggerConfig 返回有效的Logger配置 +func (f *ConfigFixture) ValidLoggerConfig(level string) *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "level": level, + }) + return &node +} + +// InvalidLoggerConfig 返回无效的Logger配置(空配置) +func (f *ConfigFixture) InvalidLoggerConfig() *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{}) + return &node +} + +// ValidMemoryConfig 返回有效的Memory配置 +func (f *ConfigFixture) ValidMemoryConfig() *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "history_key": "chat_history", + "max_turns": 100, + }) + return &node +} + +// InvalidMemoryConfig 返回无效的Memory配置 +func (f *ConfigFixture) InvalidMemoryConfig() *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "max_turns": "invalid", // 错误的类型 + }) + return &node +} + +// ValidModelsConfig 返回有效的Models配置 +func (f *ConfigFixture) ValidModelsConfig() *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "default_provider": "dashscope", + "default_model": "dashscope/qwen-max", + "default_embedding": "dashscope/qwen3-embedding", + "providers": map[string]any{ + "dashscope": map[string]any{ + "api_key": "test-key-${DASHSCOPE_API_KEY}", + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "models": []map[string]any{ + { + "name": "qwen-max", + "key": "qwen-max", + "type": "chat", + }, + }, + "embedders": []map[string]any{ + { + "name": "qwen3-embedding", + "key": "qwen3-embedding", + "type": "text", + "dimensions": 1024, + }, + }, + }, + }, + }) + return &node +} + +// ModelsConfigWithEmptyAPIKey 返回包含空API Key的Models配置 +func (f *ConfigFixture) ModelsConfigWithEmptyAPIKey() *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "default_model": "dashscope/qwen-max", + "default_embedding": "dashscope/qwen3-embedding", + "providers": map[string]any{ + "dashscope": map[string]any{ + 
"api_key": "", // 空API Key + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "models": []map[string]any{ + { + "name": "qwen-max", + "key": "qwen-max", + "type": "chat", + }, + }, + }, + }, + }) + return &node +} + +// ModelsConfigWithMultipleProviders 返回包含多个Provider的Models配置 +func (f *ConfigFixture) ModelsConfigWithMultipleProviders() *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "default_model": "dashscope/qwen-max", + "default_embedding": "dashscope/qwen3-embedding", + "providers": map[string]any{ + "dashscope": map[string]any{ + "api_key": "test-dashscope-key", + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "models": []map[string]any{ + {"name": "qwen-max", "key": "qwen-max", "type": "chat"}, + }, + "embedders": []map[string]any{ + {"name": "qwen3-embedding", "key": "qwen3-embedding", "dimensions": 1024}, + }, + }, + "gemini": map[string]any{ + "api_key": "", // Gemini需要特殊处理 + "base_url": "https://generativelanguage.googleapis.com", + "models": []map[string]any{ + {"name": "gemini-pro", "key": "gemini-pro", "type": "chat"}, + }, + }, + }, + }) + return &node +} + +// ValidToolsConfig 返回有效的Tools配置 +func (f *ConfigFixture) ValidToolsConfig() *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "enable_mock_tools": true, + "enable_internal_tools": false, + "enable_mcp_tools": false, + }) + return &node +} + +// ToolsConfigWithAllEnabled 返回所有工具都启用的配置 +func (f *ConfigFixture) ToolsConfigWithAllEnabled() *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "enable_mock_tools": true, + "enable_internal_tools": true, + "enable_mcp_tools": true, + }) + return &node +} + +// ValidRAGConfig 返回有效的RAG配置 +func (f *ConfigFixture) ValidRAGConfig() *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "loader": map[string]any{ + "type": "pdf", + }, + "splitter": map[string]any{ + "type": "recursive", + "chunk_size": 1000, + "chunk_overlap": 200, + }, + "indexer": map[string]any{ + 
"type": "pinecone", + "index_name": "test-index", + "dimension": 1024, + }, + "retriever": map[string]any{ + "type": "vector", + "top_k": 10, + }, + }) + return &node +} + +// ValidAgentConfig 返回有效的Agent配置 +func (f *ConfigFixture) ValidAgentConfig() *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "agent_type": "react", + "default_model": "qwen-max", + "prompt_base_path": "./prompts", + "max_iterations": 10, + "stages": []map[string]any{ + { + "name": "agentThinking", + "flow_type": "think", + "prompt_file": "agentThink.txt", + "temperature": 0.7, + "enable_tools": true, + }, + { + "name": "agentTool", + "flow_type": "act", + "prompt_file": "agentTool.txt", + "temperature": 0.7, + "enable_tools": true, + }, + }, + }) + return &node +} + +// ValidServerConfig 返回有效的Server配置 +func (f *ConfigFixture) ValidServerConfig() *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "host": "localhost", + "port": 8080, + }) + return &node +} + +// ServerConfigWithCustomPort 返回自定义端口的Server配置 +func (f *ConfigFixture) ServerConfigWithCustomPort(port int) *yaml.Node { + var node yaml.Node + node.Encode(map[string]any{ + "host": "0.0.0.0", + "port": port, + }) + return &node +} diff --git a/ai/testutils/helpers.go b/ai/testutils/helpers.go new file mode 100644 index 00000000..5f3bf90f --- /dev/null +++ b/ai/testutils/helpers.go @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package testutils + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "dubbo-admin-ai/runtime" + + "github.com/firebase/genkit/go/genkit" +) + +// SetupTestEnvironment 设置测试环境 +// 返回清理函数 +func SetupTestEnvironment(t *testing.T) func() { + // 保存当前工作目录 + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get current directory: %v", err) + } + + // 切换到项目根目录 + if err := os.Chdir("../.."); err != nil { + t.Fatalf("Failed to change directory: %v", err) + } + + // 返回清理函数 + return func() { + _ = os.Chdir(origDir) + } +} + +// AssertNoError 辅助函数:断言没有错误 +func AssertNoError(t *testing.T, err error, msg string) { + t.Helper() + if err != nil { + t.Fatalf("%s: %v", msg, err) + } +} + +// AssertError 辅助函数:断言有错误 +func AssertError(t *testing.T, err error, msg string) { + t.Helper() + if err == nil { + t.Fatalf("%s: expected error but got nil", msg) + } +} + +// AssertEqual 辅助函数:断言相等 +func AssertEqual[T comparable](t *testing.T, got, want T, msg string) { + t.Helper() + if got != want { + t.Fatalf("%s: got %v, want %v", msg, got, want) + } +} + +// AssertNotEqual 辅助函数:断言不相等 +func AssertNotEqual[T comparable](t *testing.T, got, notWant T, msg string) { + t.Helper() + if got == notWant { + t.Fatalf("%s: got %v, should not equal %v", msg, got, notWant) + } +} + +// AssertNotNil 辅助函数:断言非空 +func AssertNotNil(t *testing.T, v interface{}, msg string) { + t.Helper() + if v == nil { + t.Fatalf("%s: value is nil", msg) + } +} + +// AssertNil 辅助函数:断言为空 +func AssertNil(t *testing.T, v interface{}, msg string) { + t.Helper() + if v != nil { + 
t.Fatalf("%s: expected nil but got %v", msg, v) + } +} + +// RunComponentLifecycleTest 运行组件生命周期测试 +func RunComponentLifecycleTest(t *testing.T, comp runtime.Component) { + t.Helper() + + // 测试Name + name := comp.Name() + if name == "" { + t.Error("Component name should not be empty") + } + + // 创建真实的Runtime用于Init测试 + rt := runtime.NewRuntime() + rt.SetGenkitRegistry(CreateMockGenkitRegistry(t)) + + // 测试Init + err := comp.Init(rt) + if err != nil { + t.Errorf("Component.Init() failed: %v", err) + } + + // 测试Start + err = comp.Start() + if err != nil { + t.Errorf("Component.Start() failed: %v", err) + } + + // 测试Stop + err = comp.Stop() + if err != nil { + t.Errorf("Component.Stop() failed: %v", err) + } +} + +// CreateMockGenkitRegistry 创建Mock Genkit Registry +func CreateMockGenkitRegistry(t *testing.T) *genkit.Genkit { + t.Helper() + ctx := context.Background() + g := genkit.Init(ctx) + return g +} + +// SetupComponentWithRuntime 使用Runtime设置组件 +func SetupComponentWithRuntime(t *testing.T, comp runtime.Component) *runtime.Runtime { + t.Helper() + rt := runtime.NewRuntime() + rt.SetGenkitRegistry(CreateMockGenkitRegistry(t)) + rt.RegisterComponent(comp) + err := comp.Init(rt) + if err != nil { + t.Fatalf("Failed to initialize component: %v", err) + } + return rt +} + +// CleanupComponent 清理组件资源 +func CleanupComponent(t *testing.T, comp runtime.Component) { + t.Helper() + if err := comp.Stop(); err != nil { + t.Errorf("Failed to stop component: %v", err) + } +} + +// GetTestAPIKey 获取测试用的API Key +func GetTestAPIKey(provider string) string { + // 首先检查环境变量 + envKey := fmt.Sprintf("%s_API_KEY", provider) + if key := os.Getenv(envKey); key != "" { + return key + } + + // 返回测试用的假Key + return fmt.Sprintf("test-%s-api-key", provider) +} + +// SkipIfMissingAPIKey 如果缺少API Key则跳过测试 +func SkipIfMissingAPIKey(t *testing.T, provider string) { + t.Helper() + key := GetTestAPIKey(provider) + if key == "" || key == fmt.Sprintf("test-%s-api-key", provider) { + t.Skipf("Skipping test: 
%s_API_KEY not set", provider) + } +} + +// MeasureTime 测量函数执行时间 +func MeasureTime(fn func()) time.Duration { + start := time.Now() + fn() + return time.Since(start) +} + +// AssertExecutionTime 断言执行时间在指定范围内 +func AssertExecutionTime(t *testing.T, fn func(), min, max time.Duration, msg string) { + t.Helper() + duration := MeasureTime(fn) + if duration < min { + t.Errorf("%s: execution time %v is less than minimum %v", msg, duration, min) + } + if duration > max { + t.Errorf("%s: execution time %v exceeds maximum %v", msg, duration, max) + } +} diff --git a/ai/testutils/mocks.go b/ai/testutils/mocks.go new file mode 100644 index 00000000..46c96945 --- /dev/null +++ b/ai/testutils/mocks.go @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package testutils + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "dubbo-admin-ai/runtime" + + "github.com/firebase/genkit/go/genkit" +) + +// MockRuntime 提供测试用的Mock Runtime +// 它实现了runtime.Runtime接口,但简化了某些方法 +type MockRuntime struct { + components map[string]runtime.Component + mu sync.RWMutex + logger *slog.Logger + genkit *genkit.Genkit + factories map[string]runtime.ComponentFactory + factoryOrder []string +} + +// NewMockRuntime 创建一个新的MockRuntime实例 +func NewMockRuntime() *MockRuntime { + ctx := context.Background() + g := genkit.Init(ctx) + + return &MockRuntime{ + components: make(map[string]runtime.Component), + logger: slog.Default(), + genkit: g, + factories: make(map[string]runtime.ComponentFactory), + factoryOrder: make([]string, 0), + } +} + +// RegisterComponent 注册一个组件 +func (m *MockRuntime) RegisterComponent(comp runtime.Component) { + m.mu.Lock() + defer m.mu.Unlock() + m.components[comp.Name()] = comp +} + +// GetComponent 根据名称获取组件 +func (m *MockRuntime) GetComponent(name string) (runtime.Component, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + comp, ok := m.components[name] + if !ok { + return nil, fmt.Errorf("component not found: %s", name) + } + return comp, nil +} + +// GetLogger 返回logger +func (m *MockRuntime) GetLogger() *slog.Logger { + return m.logger +} + +// GetContext 返回context +func (m *MockRuntime) GetContext() context.Context { + return context.Background() +} + +// GetGenkitRegistry 返回genkit registry +func (m *MockRuntime) GetGenkitRegistry() *genkit.Genkit { + return m.genkit +} + +// SetGenkitRegistry 设置genkit registry +func (m *MockRuntime) SetGenkitRegistry(registry *genkit.Genkit) { + m.mu.Lock() + defer m.mu.Unlock() + m.genkit = registry +} + +// RegisterFactory 注册组件工厂 +func (m *MockRuntime) RegisterFactory(componentType string, factory runtime.ComponentFactory) { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.factories[componentType]; !exists { + m.factoryOrder = append(m.factoryOrder, 
componentType) + } + + m.factories[componentType] = factory +} + +// GetFactoryFn 获取工厂函数 +func (m *MockRuntime) GetFactoryFn(componentType string) (runtime.ComponentFactory, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + factory, exists := m.factories[componentType] + if !exists { + return nil, fmt.Errorf("component type '%s' not registered", componentType) + } + + return factory, nil +} diff --git a/ai/tools/memory.go b/ai/tools/memory.go deleted file mode 100644 index 5bb46c50..00000000 --- a/ai/tools/memory.go +++ /dev/null @@ -1,82 +0,0 @@ -package tools - -import ( - "dubbo-admin-ai/config" - "dubbo-admin-ai/memory" - "dubbo-admin-ai/utils" - "fmt" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" -) - -const ( - GetAllMemoryTool string = "memory_all_by_session_id" - RetrieveBasicConceptFromK8SDocTool string = "retrieve_basic_concept_from_k8s_doc" -) - -type MemoryToolInput struct { - SessionID string `json:"session_id"` -} - -func defineMemoryTools(g *genkit.Genkit, history *memory.History) []ai.Tool { - tools := []ai.Tool{ - getAllMemoryBySession(g, history), - RetrieveBasicConceptFromK8SDoc(g, config.EMBEDDING_MODEL.Key(), config.K8S_RAG_INDEX, config.RAG_TOP_K, config.RERANK_TOP_N), - } - return tools -} - -func getAllMemoryBySession(g *genkit.Genkit, history *memory.History) ai.Tool { - return genkit.DefineTool( - g, GetAllMemoryTool, "Get all history memory messages of a session by input `session_id`", - func(ctx *ai.ToolContext, input MemoryToolInput) (ToolOutput, error) { - if input.SessionID == "" { - return ToolOutput{}, fmt.Errorf("sessionID is required") - } - - if history.IsEmpty(input.SessionID) { - return ToolOutput{ - ToolName: GetAllMemoryTool, - Summary: "No memory available", - }, nil - } - - return ToolOutput{ - ToolName: GetAllMemoryTool, - Result: history.AllMemory(input.SessionID), - Summary: "", - }, nil - }, - ) -} - -type K8SRAGQueryInput struct { - Querys []string `json:"query"` -} - -const ( - 
K8S_CONCEPTS_NAMESPACE string = "concepts" -) - -func RetrieveBasicConceptFromK8SDoc(g *genkit.Genkit, embedder, indexName string, topK, topN int) ai.Tool { - return genkit.DefineTool( - g, RetrieveBasicConceptFromK8SDocTool, "Retrieve the basic kubernetes concepts from RAG", - func(ctx *ai.ToolContext, input K8SRAGQueryInput) (ToolOutput, error) { - if input.Querys == nil { - return ToolOutput{}, fmt.Errorf("query is required") - } - - results, err := utils.RetrieveFromPinecone(g, embedder, indexName, K8S_CONCEPTS_NAMESPACE, input.Querys, topK, config.RERANK_ENABLE, topN) - if err != nil { - return ToolOutput{}, fmt.Errorf("failed to retrieve from RAG: %w", err) - } - - return ToolOutput{ - ToolName: RetrieveBasicConceptFromK8SDocTool, - Result: results, - Summary: fmt.Sprintf("Retrieved %d results", len(results)), - }, nil - }, - ) -} diff --git a/ai/utils/rag.go b/ai/utils/rag.go deleted file mode 100644 index d09a740e..00000000 --- a/ai/utils/rag.go +++ /dev/null @@ -1,785 +0,0 @@ -package utils - -import ( - "context" - "fmt" - "io" - "os" - "path/filepath" - "regexp" - "strings" - "unicode" - "unicode/utf8" - - "dubbo-admin-ai/config" - "dubbo-admin-ai/manager" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/core" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/pinecone" - "github.com/gomarkdown/markdown/ast" - "github.com/gomarkdown/markdown/parser" - "github.com/ledongthuc/pdf" - "github.com/tmc/langchaingo/textsplitter" - - cohere "github.com/cohere-ai/cohere-go/v2" - cohereClient "github.com/cohere-ai/cohere-go/v2/client" -) - -type PineconeResult struct { - Content string - RelevanceScore float64 -} - -func IndexInPinecone(g *genkit.Genkit, indexName string, namespace string, embedderName string, metadata map[string]any, chunks []string) error { - docs := make([]*ai.Document, len(chunks)) - for i, chunk := range chunks { - docs[i] = ai.DocumentFromText(chunk, metadata) - } - - ctx := 
context.Background() - embedder := genkit.LookupEmbedder(g, embedderName) - if embedder == nil { - return fmt.Errorf("failed to find embedder %s", embedderName) - } - docstore, _, err := pinecone.DefineRetriever(ctx, g, - pinecone.Config{ - IndexID: indexName, - Embedder: embedder, - }, - &ai.RetrieverOptions{ - Label: indexName, - ConfigSchema: core.InferSchemaMap(pinecone.PineconeRetrieverOptions{}), - }) - - if err != nil { - return fmt.Errorf("failed to setup retriever: %w", err) - } - // 分批索引文档,每批最多10个 - batchSize := 10 - for i := 0; i < len(docs); i += batchSize { - end := min(i+batchSize, len(docs)) - batch := docs[i:end] - manager.GetLogger().Info("正在索引文档", "start", i+1, "end", end, "total", len(docs)) - if err := pinecone.Index(ctx, batch, docstore, namespace); err != nil { - return fmt.Errorf("failed to index documents batch %d-%d: %w", i+1, end, err) - } - manager.GetLogger().Info("成功索引文档", "count", len(batch)) - } - return nil -} - -func RetrieveFromPinecone(g *genkit.Genkit, embedderName, indexName, namespace string, queries []string, topK int, rerank bool, topN int) (resp map[string][]*PineconeResult, err error) { - ctx := context.Background() - embedder := genkit.LookupEmbedder(g, embedderName) - if embedder == nil { - return nil, fmt.Errorf("failed to find embedder %s", embedderName) - } - - // Define retriever with embedder - var retriever ai.Retriever - if !pinecone.IsDefinedRetriever(g, indexName) { - _, retriever, err = pinecone.DefineRetriever(ctx, g, - pinecone.Config{ - IndexID: indexName, - Embedder: embedder, - }, - &ai.RetrieverOptions{ - Label: indexName, - ConfigSchema: core.InferSchemaMap(pinecone.PineconeRetrieverOptions{}), - }) - } else { - retriever = pinecone.Retriever(g, indexName) - } - - if err != nil { - return nil, fmt.Errorf("failed to define retriever: %w", err) - } - - // Search for each query - option := &pinecone.PineconeRetrieverOptions{ - K: topK, - Namespace: namespace, - } - resp = make(map[string][]*PineconeResult, 
len(queries)) - for _, query := range queries { - response, err := retriever.Retrieve(ctx, &ai.RetrieverRequest{ - Query: ai.DocumentFromText(query, nil), - Options: option, - }) - if err != nil { - return nil, fmt.Errorf("failed to search for query '%s': %v", query, err) - } - - resp[query] = make([]*PineconeResult, 0, len(response.Documents)) - for _, doc := range response.Documents { - // 初始化 PineconeResult,暂时不设置 RelevanceScore(将在 rerank 后设置) - result := &PineconeResult{ - Content: doc.Content[0].Text, - RelevanceScore: 0, - } - resp[query] = append(resp[query], result) - } - } - - results := make(map[string][]*PineconeResult, len(queries)) - if rerank && len(queries) > 0 { - for query, docs := range resp { - // 提取文档内容用于 rerank - docTexts := make([]*string, len(docs)) - for i, doc := range docs { - docTexts[i] = &doc.Content - } - - rerankRes, err := Rerank(config.COHERE_API_KEY, config.RERANK_MODEL, query, docTexts, topN) - if err != nil { - return nil, err - } - - // 根据 rerank 结果构建最终结果,包含 RelevanceScore - for _, res := range rerankRes { - originalDoc := docs[res.Index] - resultDoc := &PineconeResult{ - Content: originalDoc.Content, - RelevanceScore: res.RelevanceScore, - } - results[query] = append(results[query], resultDoc) - } - } - } - - return results, nil -} - -func Rerank(apiKey, model, query string, documents []*string, topN int) ([]*cohere.RerankResponseResultsItem, error) { - client := cohereClient.NewClient(cohereClient.WithToken(apiKey)) - - var rerankDocs []*cohere.RerankRequestDocumentsItem - for _, doc := range documents { - rerankDoc := &cohere.RerankRequestDocumentsItem{} - rerankDoc.String = *doc - rerankDocs = append(rerankDocs, rerankDoc) - } - - rerankResponse, err := client.Rerank( - context.Background(), - &cohere.RerankRequest{ - Query: query, - Documents: rerankDocs, - TopN: &topN, - Model: &model, - }, - ) - if err != nil { - return nil, fmt.Errorf("failed to call rerank API: %w", err) - } - - return rerankResponse.Results, nil -} - 
-// Helper function to extract plain text from a PDF. -func ReadPDF(path string) (string, error) { - f, r, err := pdf.Open(path) - if f != nil { - defer f.Close() - } - if err != nil { - return "", err - } - - reader, err := r.GetPlainText() - if err != nil { - return "", err - } - - bytes, err := io.ReadAll(reader) - if err != nil { - return "", err - } - - return string(bytes), nil -} - -// pdfTextCleaner - 清洗从PDF中提取的文本数据 -func pdfTextCleaner(text string) string { - // 1. 移除控制字符和不可打印字符(保留换行符、制表符和普通空格) - cleaned := "" - for _, r := range text { - if r == '\n' || r == '\t' || r == ' ' || (r >= 32 && r < 127) || r > 127 { - // 保留换行符、制表符、空格、可打印ASCII字符和非ASCII字符(如中文) - cleaned += string(r) - } - } - - // 2. 移除多余的空白字符和换行符 - cleaned = strings.ReplaceAll(cleaned, "\n \n", "\n") - cleaned = strings.ReplaceAll(cleaned, " \n", "\n") - cleaned = strings.ReplaceAll(cleaned, "\n ", "\n") - - // 3. 将多个连续的换行符合并为单个换行符 - multipleNewlines := regexp.MustCompile(`\n{3,}`) - cleaned = multipleNewlines.ReplaceAllString(cleaned, "\n\n") - - // 4. 移除单独的字符行(可能是PDF解析错误) - lines := strings.Split(cleaned, "\n") - var cleanedLines []string - - for _, line := range lines { - line = strings.TrimSpace(line) - // 跳过空行 - if line == "" { - continue - } - // 跳过只有1个字符的行(通常是PDF解析错误) - if len(line) <= 1 { - continue - } - // 跳过只包含特殊字符的行 - if regexp.MustCompile(`^[^\w\s]+$`).MatchString(line) { - continue - } - cleanedLines = append(cleanedLines, line) - } - - // 5. 重新组合文本 - result := strings.Join(cleanedLines, "\n") - - // 6. 清理常见的PDF解析问题 - // 移除单独的数字(可能是页码) - result = regexp.MustCompile(`(?m)^\d+$`).ReplaceAllString(result, "") - - // 移除多余的空格 - result = regexp.MustCompile(`\s+`).ReplaceAllString(result, " ") - - // 恢复合理的换行 - result = strings.ReplaceAll(result, " \n", "\n") - result = strings.ReplaceAll(result, "\n ", "\n") - - // 7. 
最后的清理 - result = strings.TrimSpace(result) - - return result -} - -func SplitPDFWithClean(pdfPath string, chunkSize, chunkOverlap int) ([]string, error) { - pdfText, err := ReadPDF(pdfPath) - if err != nil { - return nil, fmt.Errorf("failed to read PDF: %w", err) - } - cleanedText := pdfTextCleaner(pdfText) - - splitter := textsplitter.NewRecursiveCharacter( - textsplitter.WithChunkSize(chunkSize), - textsplitter.WithChunkOverlap(chunkOverlap), - ) - chunks, err := splitter.SplitText(cleanedText) - if err != nil { - return nil, fmt.Errorf("failed to split text: %w", err) - } - - return chunks, nil -} - -// MarkdownCleaner 用于清洗 Markdown 文档为 RAG 友好的纯文本 -type MarkdownCleaner struct { - // 配置选项 - preserveCodeContent bool // 是否保留代码块内容 - preserveListStructure bool // 是否保留列表结构 - preserveTableContent bool // 是否保留表格内容 - maxLineLength int // 最大行长度 - - // 内部状态 - result strings.Builder - inList bool - listDepth int - inTable bool -} - -// NewMarkdownCleaner 创建新的清洗器实例 -func NewMarkdownCleaner() *MarkdownCleaner { - return &MarkdownCleaner{ - preserveCodeContent: true, - preserveListStructure: true, - preserveTableContent: true, - maxLineLength: 500, - } -} - -// SetOptions 设置清洗器选项 -func (c *MarkdownCleaner) SetOptions(preserveCode, preserveList, preserveTable bool, maxLineLen int) { - c.preserveCodeContent = preserveCode - c.preserveListStructure = preserveList - c.preserveTableContent = preserveTable - c.maxLineLength = maxLineLen -} - -// Clean 清洗 Markdown 文本 -func (c *MarkdownCleaner) Clean(markdown string) string { - // 重置前置元数据和内部状态 - c.result.Reset() - c.inList = false - c.listDepth = 0 - c.inTable = false - - // 预处理:删除 frontmatter 和 Hugo shortcodes - markdown = c.removeFrontmatter(markdown) - // 移除 HugoShortcodes - hugoRe := regexp.MustCompile(`{{<[^>]+>}}|{{%[^%]+%}}`) - markdown = hugoRe.ReplaceAllString(markdown, "") - - // 创建解析器,包含 Frontmatter 扩展 - extensions := parser.CommonExtensions | parser.Mmark | parser.Footnotes - p := parser.NewWithExtensions(extensions) - - 
// 解析 markdown 为 AST - doc := p.Parse([]byte(markdown)) - - // 遍历 AST 并提取内容 - c.walkAST(doc) - - // 后处理 - cleaned := c.postProcess(c.result.String()) - - return cleaned -} - -// walkAST 遍历 AST 节点 -func (c *MarkdownCleaner) walkAST(node ast.Node) { - if node == nil { - return - } - - switch n := node.(type) { - case *ast.Document: - c.processChildren(n) - - case *ast.Heading: - c.result.WriteString("\r\n") - c.processHeading(n) - - case *ast.Paragraph: - c.processParagraph(n) - c.result.WriteString("\r\n") - - case *ast.List: - c.processList(n) - c.result.WriteString("\n") - - case *ast.ListItem: - c.processListItem(n) - - case *ast.CodeBlock: - c.processCodeBlock(n) - c.result.WriteString("\n") - - case *ast.Table: - c.result.WriteString("\r\n") - c.processTable(n) - c.result.WriteString("\r\n") - - case *ast.TableRow: - c.processTableRow(n) - c.result.WriteString("\n") - - case *ast.TableCell: - c.processTableCell(n) - - case *ast.Text: - c.processText(n) - - case *ast.Emph, *ast.Strong: - c.processChildren(n) - - case *ast.Link: - c.processLink(n) - - case *ast.Image: - c.processImage(n) - - case *ast.Code: - c.processInlineCode(n) - - case *ast.Softbreak, *ast.Hardbreak: - c.result.WriteString(" ") - - case *ast.BlockQuote: - c.processBlockQuote(n) - - case *ast.HorizontalRule: - c.result.WriteString("\n") - - default: - c.result.WriteString("") - } -} - -// processChildren 处理子节点 -func (c *MarkdownCleaner) processChildren(node ast.Node) { - children := node.GetChildren() - for _, child := range children { - c.walkAST(child) - } -} - -// processHeading 处理标题 -func (c *MarkdownCleaner) processHeading(h *ast.Heading) { - c.result.WriteString("# ") - c.processChildren(h) - c.result.WriteString(": ") -} - -// processParagraph 处理段落 -func (c *MarkdownCleaner) processParagraph(p *ast.Paragraph) { - c.processChildren(p) -} - -// processList 处理列表 -func (c *MarkdownCleaner) processList(l *ast.List) { - if !c.preserveListStructure { - c.processChildren(l) - return - } - - 
wasInList := c.inList - c.inList = true - c.listDepth++ - - c.processChildren(l) - - c.listDepth-- - if c.listDepth == 0 { - c.inList = wasInList - } -} - -// processListItem 处理列表项 -func (c *MarkdownCleaner) processListItem(li *ast.ListItem) { - if c.preserveListStructure { - indent := strings.Repeat(" ", c.listDepth-1) - c.result.WriteString(indent + "- ") - } - c.processChildren(li) -} - -// processCodeBlock 处理代码块 -func (c *MarkdownCleaner) processCodeBlock(cb *ast.CodeBlock) { - if !c.preserveCodeContent { - return - } - code := string(cb.Literal) - if string(cb.Info) != "" { - c.result.WriteString(string(cb.Info) + ":") - } - cleanCode := c.cleanCodeContent(code) - c.result.WriteString(cleanCode) -} - -// processTable 处理表格 -func (c *MarkdownCleaner) processTable(t *ast.Table) { - if !c.preserveTableContent { - return - } - - c.inTable = true - c.processChildren(t) - c.inTable = false -} - -// processTableRow 处理表格行 -func (c *MarkdownCleaner) processTableRow(tr *ast.TableRow) { - if !c.preserveTableContent { - return - } - - var cellContents []string - - // 收集单元格内容 - children := tr.GetChildren() - for _, cell := range children { - if tableCell, ok := cell.(*ast.TableCell); ok { - var cellBuilder strings.Builder - tempResult := c.result - c.result = cellBuilder - c.processChildren(tableCell) - c.result = tempResult - - content := strings.TrimSpace(cellBuilder.String()) - if content != "" { - cellContents = append(cellContents, content) - } - } - } - - // 输出行内容 - if len(cellContents) > 0 { - c.result.WriteString(strings.Join(cellContents, " | ")) - } -} - -// processTableCell 处理表格单元格 -func (c *MarkdownCleaner) processTableCell(tc *ast.TableCell) { - c.processChildren(tc) -} - -// processText 处理纯文本 -func (c *MarkdownCleaner) processText(t *ast.Text) { - text := string(t.Literal) - cleanText := c.cleanText(text) - c.result.WriteString(cleanText) -} - -// processLink 处理链接 -func (c *MarkdownCleaner) processLink(l *ast.Link) { - c.result.WriteString(" [") - 
c.processChildren(l) - c.result.WriteString("] ") -} - -// processImage 处理图片 -func (c *MarkdownCleaner) processImage(img *ast.Image) { - c.processChildren(img) -} - -// processInlineCode 处理行内代码 -func (c *MarkdownCleaner) processInlineCode(code *ast.Code) { - if c.preserveCodeContent { - cleanCode := c.cleanText(string(code.Literal)) - c.result.WriteString(cleanCode) - } -} - -// processBlockQuote 处理引用块 -func (c *MarkdownCleaner) processBlockQuote(bq *ast.BlockQuote) { - c.processChildren(bq) -} - -// cleanText 清洗文本内容 -func (c *MarkdownCleaner) cleanText(text string) string { - // 移除HTML标签 - htmlRe := regexp.MustCompile(`<[^>]*>`) - text = htmlRe.ReplaceAllString(text, "") - - // 移除多余的空白字符 - spaceRe := regexp.MustCompile(`\s+`) - text = spaceRe.ReplaceAllString(text, " ") - - // 移除控制字符 - text = strings.Map(func(r rune) rune { - if unicode.IsControl(r) && r != '\n' && r != '\r' && r != '\t' { - return -1 - } - return r - }, text) - - return strings.TrimSpace(text) -} - -func (c *MarkdownCleaner) cleanCodeContent(code string) string { - lines := strings.Split(code, "\n") - var cleanLines []string - for _, line := range lines { - line = strings.TrimSpace(line) - if len(line) == 0 { - continue - } - cleanLines = append(cleanLines, line) - } - return strings.Join(cleanLines, " ") -} - -// removeFrontmatter 删除 Markdown 文本开头的 frontmatter -func (c *MarkdownCleaner) removeFrontmatter(markdown string) string { - // 检查是否以 --- 开头 - if !strings.HasPrefix(markdown, "---") { - return markdown - } - - lines := strings.Split(markdown, "\n") - if len(lines) < 3 { - return markdown - } - - // 寻找第二个 --- - endIndex := -1 - for i := 1; i < len(lines); i++ { - if strings.TrimSpace(lines[i]) == "---" { - endIndex = i - break - } - } - - // 如果找到结束标记,删除整个 frontmatter 块 - if endIndex > 0 { - // 返回 frontmatter 之后的内容,跳过空行 - remainingLines := lines[endIndex+1:] - for len(remainingLines) > 0 && strings.TrimSpace(remainingLines[0]) == "" { - remainingLines = remainingLines[1:] - } - return 
strings.Join(remainingLines, "\n") - } - - return markdown -} - -// postProcess 后处理清洗后的文本 -func (c *MarkdownCleaner) postProcess(text string) string { - // 移除多余的空行 - multiNewlineRe := regexp.MustCompile(`\n{3,}`) - text = multiNewlineRe.ReplaceAllString(text, "\n\n") - - // 移除首尾空白 - text = strings.TrimSpace(text) - - // 确保UTF-8编码有效 - if !utf8.ValidString(text) { - text = strings.ToValidUTF8(text, "") - } - - return text -} - -// CleanMarkdownFile 从文件路径读取并清洗 Markdown 文件 -func CleanMarkdownFile(mdPath string) (string, error) { - // 检查文件是否存在 - if _, err := os.Stat(mdPath); os.IsNotExist(err) { - return "", fmt.Errorf("文件不存在: %s", mdPath) - } - - // 读取文件内容 - content, err := os.ReadFile(mdPath) - if err != nil { - return "", fmt.Errorf("读取文件失败: %w", err) - } - - // 检查文件是否为空 - if len(content) == 0 { - return "", fmt.Errorf("文件为空: %s", mdPath) - } - - // 创建清洗器并进行清洗 - cleaner := NewMarkdownCleaner() - - // 可以根据需要调整配置 - // cleaner.SetOptions(preserveCode, preserveList, preserveTable, extractFrontMatter, maxLineLength) - - // 清洗内容 - cleaned := cleaner.Clean(string(content)) - - return cleaned, nil -} - -// BatchCleanMarkdownFiles 批量处理多个 Markdown 文件 -func BatchCleanMarkdownFiles(mdPaths []string) (map[string]string, error) { - results := make(map[string]string) - var errs []string - - for _, path := range mdPaths { - cleaned, err := CleanMarkdownFile(path) - if err != nil { - errs = append(errs, fmt.Sprintf("%s: %v", path, err)) - continue - } - results[path] = cleaned - } - - if len(errs) > 0 { - return results, fmt.Errorf("部分文件处理失败: %s", strings.Join(errs, "; ")) - } - - return results, nil -} - -func MDSplitter() textsplitter.TextSplitter { - return textsplitter.NewRecursiveCharacter( - textsplitter.WithChunkSize(1000), - textsplitter.WithChunkOverlap(100), - textsplitter.WithSeparators([]string{"\r\n\r\n", "\n\n", "\r\n", "\n", " ", ""}), - ) -} - -// ProcessMarkdownDirectory 处理目录下的所有 Markdown 文件,清洗并分块 -func ProcessMarkdownDirectory(dirPath string) ([]string, error) { - 
// 获取目录下的所有 .md 文件 - mdFiles, err := getMDFiles(dirPath) - if err != nil { - return nil, fmt.Errorf("获取 MD 文件失败: %w", err) - } - - if len(mdFiles) == 0 { - return nil, fmt.Errorf("目录中未找到 .md 文件: %s", dirPath) - } - - // 创建清洗器和分割器 - cleaner := NewMarkdownCleaner() - splitter := MDSplitter() - - var allChunks []string - - // 处理每个 MD 文件 - for _, filePath := range mdFiles { - chunks, err := processMarkdownFile(filePath, cleaner, splitter) - if err != nil { - // 记录错误但继续处理其他文件 - fmt.Printf("处理文件 %s 时出错: %v\n", filePath, err) - continue - } - - allChunks = append(allChunks, chunks...) - } - - return allChunks, nil -} - -// getMDFiles 递归获取目录下的所有 .md 文件 -func getMDFiles(dirPath string) ([]string, error) { - var mdFiles []string - - err := filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - // 检查是否为 .md 文件 - if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".md") { - mdFiles = append(mdFiles, path) - } - - return nil - }) - - if err != nil { - return nil, err - } - - return mdFiles, nil -} - -// processMarkdownFile 处理单个 Markdown 文件,清洗并分块 -func processMarkdownFile(filePath string, cleaner *MarkdownCleaner, splitter textsplitter.TextSplitter) ([]string, error) { - // 读取文件内容 - content, err := os.ReadFile(filePath) - if err != nil { - return nil, fmt.Errorf("读取文件失败: %w", err) - } - - // 检查文件是否为空 - if len(content) == 0 { - return nil, fmt.Errorf("文件为空: %s", filePath) - } - - // 清洗 Markdown 内容 - cleaned := cleaner.Clean(string(content)) - - // 如果清洗后内容为空,跳过 - if strings.TrimSpace(cleaned) == "" { - return nil, fmt.Errorf("清洗后内容为空: %s", filePath) - } - - // 使用分割器进行分块 - chunks, err := splitter.SplitText(cleaned) - if err != nil { - return nil, fmt.Errorf("分块失败: %w", err) - } - return chunks, nil -} diff --git a/ai/utils/utils.go b/ai/utils/utils.go index 682996e6..b1acce57 100644 --- a/ai/utils/utils.go +++ b/ai/utils/utils.go @@ -3,10 +3,90 @@ package utils import ( "fmt" "io" + "maps" "os" 
"path/filepath" + "strings" + + "github.com/cloudwego/eino/schema" + "github.com/firebase/genkit/go/ai" ) +// ToGenkitDocument converts an Eino schema.Document to a Genkit ai.Document +func ToGenkitDocument(doc *schema.Document) *ai.Document { + if doc == nil { + return nil + } + // Create text part + part := ai.NewTextPart(doc.Content) + + // Copy metadata + meta := make(map[string]any) + maps.Copy(meta, doc.MetaData) + // Store ID in metadata if present + if doc.ID != "" { + meta["_id"] = doc.ID + } + + return &ai.Document{ + Content: []*ai.Part{part}, + Metadata: meta, + } +} + +// ToEinoDocument converts a Genkit ai.Document to an Eino schema.Document +func ToEinoDocument(doc *ai.Document) *schema.Document { + if doc == nil { + return nil + } + + // Extract text content + var contentBuilder strings.Builder + for _, part := range doc.Content { + if part.IsText() { + contentBuilder.WriteString(part.Text) + } + } + + // Copy metadata and extract ID + meta := make(map[string]any) + var id string + maps.Copy(meta, doc.Metadata) + if strID, ok := doc.Metadata["_id"].(string); ok { + id = strID + } + + return &schema.Document{ + ID: id, + Content: contentBuilder.String(), + MetaData: meta, + } +} + +// ToGenkitDocuments converts a slice of Eino schema.Documents to Genkit ai.Documents +func ToGenkitDocuments(docs []*schema.Document) []*ai.Document { + if docs == nil { + return nil + } + res := make([]*ai.Document, len(docs)) + for i, doc := range docs { + res[i] = ToGenkitDocument(doc) + } + return res +} + +// ToEinoDocuments converts a slice of Genkit ai.Documents to Eino schema.Documents +func ToEinoDocuments(docs []*ai.Document) []*schema.Document { + if docs == nil { + return nil + } + res := make([]*schema.Document, len(docs)) + for i, doc := range docs { + res[i] = ToEinoDocument(doc) + } + return res +} + // CopyFile copies source file content to target file, creates the file if target doesn't exist // srcPath: source file path // dstPath: target file path diff 
--git a/ai/utils/utils_test.go b/ai/utils/utils_test.go deleted file mode 100644 index 90ef9298..00000000 --- a/ai/utils/utils_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package utils_test - -import ( - "context" - "dubbo-admin-ai/config" - "dubbo-admin-ai/manager" - "dubbo-admin-ai/plugins/dashscope" - "dubbo-admin-ai/tools" - "dubbo-admin-ai/utils" - "fmt" - "testing" - - "github.com/tmc/langchaingo/textsplitter" -) - -func TestMdCleaner(t *testing.T) { - mdPath := "/Users/liwener/programming/ospp/dubbo-admin/ai/reference/k8s_docs/concepts/overview/kubernetes-api.md" - cleaned, err := utils.CleanMarkdownFile(mdPath) - if err != nil { - t.Fatalf("err: %v\n", err) - } - fmt.Printf("清洗后的内容:\n%s\n", cleaned) -} - -func TestMdChunks(t *testing.T) { - mdPath := "/Users/liwener/programming/ospp/dubbo-admin/ai/reference/k8s_docs/concepts/overview/kubernetes-api.md" - cleaned, err := utils.CleanMarkdownFile(mdPath) - if err != nil { - t.Fatalf("err: %v\n", err) - } - splitter := textsplitter.NewRecursiveCharacter( - textsplitter.WithChunkSize(1000), - textsplitter.WithChunkOverlap(100), - textsplitter.WithSeparators([]string{"\r\n\r\n", "\n\n", "\r\n", "\n", " ", ""}), - ) - - chunks, err := splitter.SplitText(cleaned) - if err != nil { - t.Fatalf("err: %v\n", err) - } - for i, chunk := range chunks { - fmt.Printf("第 %d 个chunk:\n%s\n\n", i+1, chunk) - } -} - -func TestMdChunksInDir(t *testing.T) { - mdDir := "/Users/liwener/programming/ospp/dubbo-admin/ai/reference/k8s_docs/concepts/overview" - - chunks, err := utils.ProcessMarkdownDirectory(mdDir) - if err != nil { - t.Fatalf("err: %v\n", err) - } - g := manager.Registry(dashscope.Qwen3.Key(), config.PROJECT_ROOT+"/.env", manager.DevLogger()) - err = utils.IndexInPinecone(g, "kubernetes", "concepts", dashscope.Qwen3_embedding.Key(), nil, chunks) - if err != nil { - t.Fatalf("err: %v\n", err) - } -} - -func TestMdRetrive(t *testing.T) { - g := manager.Registry(dashscope.Qwen3.Key(), config.PROJECT_ROOT+"/.env", 
manager.ProductionLogger()) - query := []string{ - "什么是 Pod?", - "什么是 Deployment?", - "Kubernetes网络模式", - } - results, err := utils.RetrieveFromPinecone(g, dashscope.Qwen3_embedding.Key(), "kubernetes", "concepts", query, 10, true, 3) - - if err != nil { - t.Fatalf("err: %v\n", err) - } - - for q, docs := range results { - fmt.Printf("查询: %s\n", q) - for i, doc := range docs { - fmt.Printf("结果 %d: %v\n\n", i+1, doc) - } - } -} - -func TestRAGTool(t *testing.T) { - g := manager.Registry(dashscope.Qwen3.Key(), config.PROJECT_ROOT+"/.env", manager.ProductionLogger()) - ragTool := tools.RetrieveBasicConceptFromK8SDoc(g, dashscope.Qwen3_embedding.Key(), "kube-docs", 10, 3) - toolOutput, err := ragTool.RunRaw(context.Background(), tools.K8SRAGQueryInput{Querys: []string{"什么是 Deployment?"}}) - if err != nil { - t.Fatalf("err: %v\n", err) - } - fmt.Printf("RAG Tool Output: \n%v+\n\n", toolOutput) -}