RAGflow/api/db/services/dialog_service.py

561 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import binascii
import time
from functools import partial
import re
from copy import deepcopy
from timeit import default_timer as timer
from agentic_reasoning import DeepResearcher
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import Dialog, DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from api import settings
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.tavily_conn import Tavily
class DialogService(CommonService):
model = Dialog
@classmethod
@DB.connection_context()
def get_list(cls, tenant_id,
page_number, items_per_page, orderby, desc, id, name):
chats = cls.model.select()
if id:
chats = chats.where(cls.model.id == id)
if name:
chats = chats.where(cls.model.name == name)
chats = chats.where(
(cls.model.tenant_id == tenant_id)
& (cls.model.status == StatusEnum.VALID.value)
)
if desc:
chats = chats.order_by(cls.model.getter_by(orderby).desc())
else:
chats = chats.order_by(cls.model.getter_by(orderby).asc())
chats = chats.paginate(page_number, items_per_page)
return list(chats.dicts())
def chat_solo(dialog, messages, stream=True):
if llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
prompt_config = dialog.prompt_config
tts_mdl = None
if prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
for m in messages if m["role"] != "system"]
if stream:
last_ans = ""
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
continue
last_ans = answer
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt":"", "created_at": time.time()}
else:
answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
def chat(dialog, messages, stream=True, **kwargs):
# 确保最后一条消息来自用户,否则抛出异常
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
# 如果对话没有关联知识库则调用chat_solo函数进行简单对话处理
if not dialog.kb_ids:
for ans in chat_solo(dialog, messages, stream):
yield ans
return
# 记录聊天开始时间,用于后续性能分析
chat_start_ts = timer()
# 根据对话配置的LLM类型获取模型配置
if llm_id2llm_type(dialog.llm_id) == "image2text":
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
# 获取模型支持的最大token数默认为8192
max_tokens = llm_model_config.get("max_tokens", 8192)
# 记录检查LLM配置的时间点
check_llm_ts = timer()
# 获取对话关联的所有知识库
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
# 提取所有知识库使用的嵌入模型ID并确保它们使用相同的嵌入模型
embedding_list = list(set([kb.embd_id for kb in kbs]))
if len(embedding_list) != 1:
# 如果知识库使用了不同的嵌入模型,返回错误信息
yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
# 获取嵌入模型名称
embedding_model_name = embedding_list[0]
# 获取检索器实例
retriever = settings.retrievaler
# 提取最近3条用户消息作为问题上下文
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
# 处理附件文档ID
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
if "doc_ids" in messages[-1]:
attachments = messages[-1]["doc_ids"]
# 记录创建检索器的时间点
create_retriever_ts = timer()
# 初始化嵌入模型
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
# 如果嵌入模型不存在,抛出异常
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embedding_model_name)
# 记录绑定嵌入模型的时间点
bind_embedding_ts = timer()
# 初始化聊天模型
if llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
# 记录绑定LLM模型的时间点
bind_llm_ts = timer()
# 获取提示词配置
prompt_config = dialog.prompt_config
# 获取字段映射用于SQL检索
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
tts_mdl = None
# 初始化文本转语音模型(如果配置中启用)
if prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
# try to use sql if field mapping is good to go
# 尝试使用SQL检索如果有字段映射
if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
if ans:
# 如果SQL检索成功直接返回结果
yield ans
return
# 处理提示词配置中的参数
for p in prompt_config["parameters"]:
# 跳过knowledge参数它将在后面处理
if p["key"] == "knowledge":
continue
# 检查必需参数是否提供
if p["key"] not in kwargs and not p["optional"]:
raise KeyError("Miss parameter: " + p["key"])
# 如果参数未提供且为可选,则在提示词中替换为空格
if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace(
"{%s}" % p["key"], " ")
# 处理多轮对话开启并且问题数量大于1
if len(questions) > 1 and prompt_config.get("refine_multiturn"):
# 使用full_question函数将多轮对话合并为一个完整问题
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
else:
# 否则只使用最后一个问题
questions = questions[-1:]
# 记录问题精炼的时间点
refine_question_ts = timer()
# 初始化重排序模型(如果配置)
rerank_mdl = None
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
# 记录绑定重排序模型的时间点
bind_reranker_ts = timer()
generate_keyword_ts = bind_reranker_ts
# 初始化思考过程和知识库信息
thought = ""
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
# 检查是否需要知识库检索
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
knowledges = []
else:
# 如果启用了关键词提取,则增强问题
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
generate_keyword_ts = timer()
# 获取所有知识库的表ID
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
knowledges = []
# 如果启用了推理功能使用DeepResearcher进行深度推理
if prompt_config.get("reasoning", False):
reasoner = DeepResearcher(chat_mdl,
prompt_config,
partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))
# 执行推理过程
for think in reasoner.thinking(kbinfos, " ".join(questions)):
if isinstance(think, str):
thought = think
knowledges = [t for t in think.split("\n") if t]
elif stream:
yield think
else:
# 使用标准检索方法
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs)
)
# 如果配置了Tavily API使用外部搜索增强检索结果
if prompt_config.get("tavily_api_key"):
tav = Tavily(prompt_config["tavily_api_key"])
tav_res = tav.retrieve_chunks(" ".join(questions))
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
# 如果启用了知识图谱,使用知识图谱检索结果
if prompt_config.get("use_kg"):
ck = settings.kg_retrievaler.retrieval(" ".join(questions),
tenant_ids,
dialog.kb_ids,
embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
# 将检索到的知识格式化为提示词
knowledges = kb_prompt(kbinfos, max_tokens)
# 记录检索到的知识
logging.debug(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
# 记录检索完成的时间点
retrieval_ts = timer()
# 如果没有检索到知识且配置了空响应,则返回空响应
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
# 将检索到的知识添加到kwargs中用于格式化系统提示词
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
# 获取LLM生成配置
gen_conf = dialog.llm_setting
# 构建消息列表,包括系统提示词和用户消息
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
for m in messages if m["role"] != "system"])
# 确保消息列表不超过模型最大token限制
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
prompt = msg[0]["content"]
# 调整生成配置中的max_tokens确保不超过模型限制
if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(
gen_conf["max_tokens"],
max_tokens - used_token_count)
def decorate_answer(answer):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
refs = []
# 处理思考过程,如果答案包含</think>标签
ans = answer.split("</think>")
think = ""
if len(ans) == 2:
think = ans[0] + "</think>"
answer = ans[1]
# 如果有知识且启用了引用功能,添加引用
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer, idx = retriever.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs:
recall_docs = kbinfos["doc_aggs"]
kbinfos["doc_aggs"] = recall_docs
refs = deepcopy(kbinfos)
for c in refs["chunks"]:
if c.get("vector"):
del c["vector"]
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
# 记录完成时间
finish_chat_ts = timer()
# 计算各阶段耗时
total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
prompt += "\n\n### Query:\n%s" % " ".join(questions)
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
return {"answer": think+answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
# 根据stream参数决定是流式输出还是批量输出
if stream:
# 流式输出模式
last_ans = ""
answer = ""
# 调用LLM的流式生成接口
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
# 如果有思考过程,从回答中移除思考部分
if thought:
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
answer = ans
# 计算增量回答(与上次输出相比新增的内容)
delta_ans = ans[len(last_ans):]
# 如果增量内容太少,跳过本次输出
if num_tokens_from_string(delta_ans) < 16:
continue
last_ans = answer
# 生成包含回答、引用和音频的输出
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
delta_ans = answer[len(last_ans):]
# 处理最后一部分增量内容
if delta_ans:
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# 最后输出装饰后的完整回答(包含引用和性能指标)
yield decorate_answer(thought+answer)
else:
# 非流式输出模式
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
res = decorate_answer(answer)
res["audio_binary"] = tts(tts_mdl, answer)
yield res
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(
index_name(tenant_id),
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
question
)
tried_times = 0
def get_table():
nonlocal sys_prompt, user_prompt, question, tried_times
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
"temperature": 0.06})
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
sql = re.sub(r".*select ", "select ", sql.lower())
sql = re.sub(r" +", " ", sql)
sql = re.sub(r"([;]|```).*", "", sql)
if sql[:len("select ")] != "select ":
return None, None
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
if sql[:len("select *")] != "select *":
sql = "select doc_id,docnm_kwd," + sql[6:]
else:
flds = []
for k in field_map.keys():
if k in forbidden_select_fields4resume:
continue
if len(flds) > 11:
break
flds.append(k)
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
tbl, sql = get_table()
if tbl is None:
return None
if tbl.get("error") and tried_times <= 2:
user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
Error issued by database as follows:
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(
index_name(tenant_id),
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
question, sql, tbl["error"]
)
tbl, sql = get_table()
logging.debug("TRY it again: {}".format(sql))
logging.debug("GET table: {}".format(tbl))
if tbl.get("error") or len(tbl["rows"]) == 0:
return None
docid_idx = set([ii for ii, c in enumerate(
tbl["columns"]) if c["name"] == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(
tbl["columns"]) if c["name"] == "docnm_kwd"])
column_idx = [ii for ii in range(
len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
# compose Markdown table
columns = "|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"],
tbl["columns"][i]["name"])) for i in
column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
("|------|" if docid_idx and docid_idx else "")
rows = ["|" +
"|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
"|" for r in tbl["rows"]]
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
if quota:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
else:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
if not docid_idx or not doc_name_idx:
logging.warning("SQL missing field: " + sql)
return {
"answer": "\n".join([columns, line, rows]),
"reference": {"chunks": [], "doc_aggs": []},
"prompt": sys_prompt
}
docid_idx = list(docid_idx)[0]
doc_name_idx = list(doc_name_idx)[0]
doc_aggs = {}
for r in tbl["rows"]:
if r[docid_idx] not in doc_aggs:
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
doc_aggs[r[docid_idx]]["count"] += 1
return {
"answer": "\n".join([columns, line, rows]),
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
doc_aggs.items()]},
"prompt": sys_prompt
}
def tts(tts_mdl, text):
if not tts_mdl or not text:
return
bin = b""
for chunk in tts_mdl.tts(text):
bin += chunk
return binascii.hexlify(bin).decode("utf-8")
def ask(question, kb_ids, tenant_id):
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
1, 12, 0.1, 0.3, aggs=False,
rank_feature=label_question(question, kbs)
)
knowledges = kb_prompt(kbinfos, max_tokens)
prompt = """
Role: You're a smart assistant. Your name is Miss R.
Task: Summarize the information from knowledge bases and answer user's question.
Requirements and restriction:
- DO NOT make things up, especially for numbers.
- If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
- Answer with markdown format text.
- Answer in language of user's question.
- DO NOT make things up, especially for numbers.
### Information from knowledge bases
%s
The above is information from knowledge bases.
""" % "\n".join(knowledges)
msg = [{"role": "user", "content": question}]
def decorate_answer(answer):
nonlocal knowledges, kbinfos, prompt
answer, idx = retriever.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=0.7,
vtweight=0.3)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs:
recall_docs = kbinfos["doc_aggs"]
kbinfos["doc_aggs"] = recall_docs
refs = deepcopy(kbinfos)
for c in refs["chunks"]:
if c.get("vector"):
del c["vector"]
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
return {"answer": answer, "reference": chunks_format(refs)}
answer = ""
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)