diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 9f9f146aa..ac1960b90 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -205,21 +205,165 @@ async def text_chat_stream( model=None, **kwargs, ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response + if image_urls is None: + image_urls = [] + session_id = session_id or kwargs.get("user") or "unknown" + conversation_id = self.conversation_ids.get(session_id, "") + + files_payload = [] + for image_url in image_urls: + image_path = ( + await download_image_by_url(image_url) + if image_url.startswith("http") + else image_url + ) + file_response = await self.api_client.file_upload( + image_path, + user=session_id, + ) + logger.debug(f"Dify 上传图片响应:{file_response}") + if "id" not in file_response: + logger.warning( + f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。", + ) + continue + files_payload.append( + { + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response["id"], + }, + ) + + # 获得会话变量 + payload_vars = self.variables.copy() + # 动态变量 + session_var = await sp.session_get(session_id, "session_variables", default={}) + payload_vars.update(session_var) + payload_vars["system_prompt"] = system_prompt + + try: + match self.api_type: + case "chat" | "agent" | "chatflow": + if not prompt: + prompt = "请描述这张图片。" + + accumulated_text = "" + async for chunk in self.api_client.chat_messages( + inputs={ + **payload_vars, + }, + query=prompt, + user=session_id, + conversation_id=conversation_id, + files=files_payload, + timeout=self.timeout, + ): + logger.debug(f"dify resp chunk: {chunk}") + if ( + chunk["event"] == "message" + or chunk["event"] == "agent_message" + ): + accumulated_text += chunk["answer"] + if not conversation_id: + self.conversation_ids[session_id] = chunk[ + "conversation_id" + ] + conversation_id = chunk["conversation_id"] + + # Yield streaming chunk + llm_response = LLMResponse( + role="assistant", + result_chain=MessageChain( + chain=[Comp.Plain(chunk["answer"])] + ), + is_chunk=True, + ) + yield llm_response + + elif chunk["event"] == "message_end": + logger.debug("Dify message end") + break + elif chunk["event"] == "error": + logger.error(f"Dify 出现错误:{chunk}") + yield LLMResponse( + role="err", + completion_text=f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}", + ) + return + + # Yield final complete result + chain = MessageChain(chain=[Comp.Plain(accumulated_text)]) + yield LLMResponse( + role="assistant", result_chain=chain, is_chunk=False + ) + + case "workflow": + workflow_result = None + async for chunk in self.api_client.workflow_run( + inputs={ + self.dify_query_input_key: prompt, + "astrbot_session_id": session_id, + **payload_vars, + }, + user=session_id, + files=files_payload, + timeout=self.timeout, + ): + match chunk["event"]: + case "workflow_started": + logger.info( + f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。", + ) + case "node_finished": + logger.debug( + f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。", + ) + case "workflow_finished": + logger.info( + f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束", + ) + logger.debug(f"Dify 工作流结果:{chunk}") + if chunk["data"]["error"]: + logger.error( + f"Dify 工作流出现错误:{chunk['data']['error']}", + ) + yield LLMResponse( + role="err", + completion_text=f"Dify 工作流出现错误:{chunk['data']['error']}", + ) + return + if ( + self.workflow_output_key + not in chunk["data"]["outputs"] + ): + yield LLMResponse( + role="err", + completion_text=f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}", + ) + return + workflow_result = chunk + + if workflow_result: + chain = await self.parse_dify_result(workflow_result) + yield LLMResponse( + role="assistant", result_chain=chain, is_chunk=False + ) + else: + logger.warning("Dify 工作流请求结果为空,请查看 Debug 日志。") + yield LLMResponse( + role="err", completion_text="Dify 工作流请求结果为空" + ) + + case _: + yield LLMResponse( + role="err", + completion_text=f"未知的 Dify API 类型:{self.api_type}", + ) + + except Exception as e: + logger.error(f"Dify 请求失败:{e!s}") + yield LLMResponse(role="err", completion_text=f"Dify 请求失败:{e!s}") async def parse_dify_result(self, chunk: dict | str) -> MessageChain: if isinstance(chunk, str):