Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 159 additions & 15 deletions astrbot/core/provider/sources/dify_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,21 +205,165 @@ async def text_chat_stream(
model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
if image_urls is None:
image_urls = []
session_id = session_id or kwargs.get("user") or "unknown"
conversation_id = self.conversation_ids.get(session_id, "")

files_payload = []
for image_url in image_urls:
image_path = (
await download_image_by_url(image_url)
if image_url.startswith("http")
else image_url
)
file_response = await self.api_client.file_upload(
image_path,
user=session_id,
)
logger.debug(f"Dify 上传图片响应:{file_response}")
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。",
)
continue
files_payload.append(
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
},
)

# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_var = await sp.session_get(session_id, "session_variables", default={})
payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt

try:
match self.api_type:
case "chat" | "agent" | "chatflow":
if not prompt:
prompt = "请描述这张图片。"

accumulated_text = ""
async for chunk in self.api_client.chat_messages(
inputs={
**payload_vars,
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload,
timeout=self.timeout,
):
logger.debug(f"dify resp chunk: {chunk}")
if (
chunk["event"] == "message"
or chunk["event"] == "agent_message"
):
accumulated_text += chunk["answer"]
if not conversation_id:
self.conversation_ids[session_id] = chunk[
"conversation_id"
]
conversation_id = chunk["conversation_id"]

# Yield streaming chunk
llm_response = LLMResponse(
role="assistant",
result_chain=MessageChain(
chain=[Comp.Plain(chunk["answer"])]
),
is_chunk=True,
)
yield llm_response

elif chunk["event"] == "message_end":
logger.debug("Dify message end")
break
elif chunk["event"] == "error":
logger.error(f"Dify 出现错误:{chunk}")
yield LLMResponse(
role="err",
completion_text=f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}",
)
return

# Yield final complete result
chain = MessageChain(chain=[Comp.Plain(accumulated_text)])
yield LLMResponse(
role="assistant", result_chain=chain, is_chunk=False
)

case "workflow":
workflow_result = None
async for chunk in self.api_client.workflow_run(
inputs={
self.dify_query_input_key: prompt,
"astrbot_session_id": session_id,
**payload_vars,
},
user=session_id,
files=files_payload,
timeout=self.timeout,
):
match chunk["event"]:
case "workflow_started":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。",
)
case "node_finished":
logger.debug(
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。",
)
case "workflow_finished":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束",
)
logger.debug(f"Dify 工作流结果:{chunk}")
if chunk["data"]["error"]:
logger.error(
f"Dify 工作流出现错误:{chunk['data']['error']}",
)
yield LLMResponse(
role="err",
completion_text=f"Dify 工作流出现错误:{chunk['data']['error']}",
)
return
if (
self.workflow_output_key
not in chunk["data"]["outputs"]
):
yield LLMResponse(
role="err",
completion_text=f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}",
)
return
workflow_result = chunk

if workflow_result:
chain = await self.parse_dify_result(workflow_result)
yield LLMResponse(
role="assistant", result_chain=chain, is_chunk=False
)
else:
logger.warning("Dify 工作流请求结果为空,请查看 Debug 日志。")
yield LLMResponse(
role="err", completion_text="Dify 工作流请求结果为空"
)

case _:
yield LLMResponse(
role="err",
completion_text=f"未知的 Dify API 类型:{self.api_type}",
)

except Exception as e:
logger.error(f"Dify 请求失败:{e!s}")
yield LLMResponse(role="err", completion_text=f"Dify 请求失败:{e!s}")

async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
if isinstance(chunk, str):
Expand Down