RAGflow/api/db/services/dialog_service.py

548 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import binascii
import logging
import re
import time
from copy import deepcopy
from timeit import default_timer as timer
from api import settings
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import DB, Dialog
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantLLMService
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
from rag.utils import num_tokens_from_string, rmSpace
from .database import MINIO_CONFIG
class DialogService(CommonService):
    """Service layer for ``Dialog`` (chat assistant) records."""

    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
        """Return one page of a tenant's valid dialogs as plain dicts.

        Args:
            tenant_id: Only dialogs owned by this tenant are returned.
            page_number: Page index forwarded to ``paginate``.
            items_per_page: Page size forwarded to ``paginate``.
            orderby: Field name resolved through ``cls.model.getter_by``.
            desc: Sort descending when truthy, ascending otherwise.
            id: Optional exact-match dialog id filter (ignored when falsy).
            name: Optional exact-match dialog name filter (ignored when falsy).

        Returns:
            list[dict]: The selected rows, one dict per dialog.
        """
        query = cls.model.select()
        if id:
            query = query.where(cls.model.id == id)
        if name:
            query = query.where(cls.model.name == name)
        # Always restrict to the tenant and to records whose status is VALID.
        query = query.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))

        sort_field = cls.model.getter_by(orderby)
        query = query.order_by(sort_field.desc() if desc else sort_field.asc())

        page = query.paginate(page_number, items_per_page)
        return list(page.dicts())
def chat_solo(dialog, messages, stream=True):
    """Answer a conversation with the dialog's LLM alone (no knowledge-base retrieval).

    Args:
        dialog: Dialog record; ``tenant_id``, ``llm_id``, ``llm_setting`` and
            ``prompt_config`` are read.
        messages: Conversation history as dicts with ``role`` and ``content``.
            System messages are dropped and ``##N$$`` citation markers are
            stripped from the remaining contents.
        stream: When true, yield answers incrementally; otherwise yield once.

    Yields:
        dict: ``answer`` (the answer text so far -- presumably cumulative, since
        deltas are computed by slicing off the previous answer), an empty
        ``reference``, ``audio_binary`` (hex TTS of the new text, or None when
        TTS is off), an empty ``prompt`` and ``created_at``. An empty stream
        yields nothing.
    """
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)

    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]

    if stream:
        # Bound up front so an empty stream does not raise UnboundLocalError below.
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans) :]
            # Batch small increments instead of emitting (and synthesising) every few tokens.
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
        # Flush only text that has not been emitted yet. Recomputing the delta avoids
        # re-yielding the last emitted chunk (and re-running TTS on it).
        delta_ans = answer[len(last_ans) :]
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        # msg may be empty if only system messages were passed in.
        user_content = msg[-1].get("content", "[content not available]") if msg else "[content not available]"
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
def chat(dialog, messages, stream=True, **kwargs):
    """Answer the latest user message using retrieval over the dialog's knowledge bases.

    Falls back to :func:`chat_solo` when the dialog has no knowledge bases, and
    first tries :func:`use_sql` when the knowledge bases expose a field map.

    Args:
        dialog: Dialog record (tenant, LLM / rerank ids, retrieval settings,
            ``kb_ids``, ``prompt_config`` and ``llm_setting``). Its
            ``prompt_config`` and ``llm_setting`` are copied, not modified.
        messages: Conversation history; the last entry must come from the user.
        stream: Yield incremental answers when true, one answer otherwise.
        **kwargs: Values for the system prompt's ``{placeholders}``.
            ``doc_ids`` (comma-separated string) restricts retrieval to those
            documents; ``quote`` (default True) toggles citations.

    Yields:
        dict: Possibly an early error / SQL / empty-response answer; otherwise
        partial answers while streaming, then a final answer carrying
        ``reference``, ``prompt`` (with a timing breakdown) and ``created_at``.

    Raises:
        AssertionError: If the last message is not from the user.
        LookupError: If the embedding model could not be bound.
        KeyError: If a non-optional prompt parameter is missing from ``kwargs``.
    """
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids:
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

    embedding_model_name = embedding_list[0]

    retriever = settings.retrievaler

    # Up to the last three user questions; narrowed to the last one further down.
    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    create_retriever_ts = timer()

    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embedding_model_name)

    bind_embedding_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    bind_llm_ts = timer()

    # Shallow copy: the placeholder stripping below rewrites "system" and must not
    # leak into the dialog's own prompt_config.
    prompt_config = dict(dialog.prompt_config)
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)

    # try to use sql if field mapping is good to go
    if field_map:
        logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return

    # Validate prompt parameters; blank out placeholders for missing optional ones.
    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")

    questions = questions[-1:]
    refine_question_ts = timer()

    rerank_mdl = None
    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

    bind_reranker_ts = timer()
    generate_keyword_ts = bind_reranker_ts
    # Never reassigned in this function, so the `if thought:` branch below is inert.
    thought = ""
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}

    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        knowledges = []
    else:
        if prompt_config.get("keyword", False):
            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
            generate_keyword_ts = timer()

        tenant_ids = list(set([kb.tenant_id for kb in kbs]))

        knowledges = []
        kbinfos = retriever.retrieval(
            " ".join(questions),
            embd_mdl,
            tenant_ids,
            dialog.kb_ids,
            1,
            dialog.top_n,
            dialog.similarity_threshold,
            dialog.vector_similarity_weight,
            doc_ids=attachments,
            top=dialog.top_k,
            aggs=False,
            rerank_mdl=rerank_mdl,
            rank_feature=label_question(" ".join(questions), kbs),
        )
        knowledges = kb_prompt(kbinfos, max_tokens)

    logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
    retrieval_ts = timer()

    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    # Copy so the max_tokens clamp below (and anything the model layer does to the
    # dict) does not modify dialog.llm_setting.
    gen_conf = dict(dialog.llm_setting)

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    prompt4citation = ""
    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        prompt4citation = citation_prompt()
    # Drop system messages: the system prompt was built separately above.
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]

    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)

    def decorate_answer(answer):
        """Add citations, images, references and timing info to the final answer."""
        # Only `prompt` is rebound here; the other enclosing names are only read.
        nonlocal prompt
        refs = []
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
            think = ans[0] + "</think>"
            answer = ans[1]

        cited_chunk_indices = set()
        processed_image_urls = set()
        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            # Collect the indices of the cited chunks.
            if not re.search(r"##[0-9]+\$\$", answer):
                # No markers from the model: let the retriever insert them.
                answer, idx = retriever.insert_citations(
                    answer,
                    [ck["content_ltks"] for ck in kbinfos["chunks"]],
                    [ck["vector"] for ck in kbinfos["chunks"]],
                    embd_mdl,
                    tkweight=1 - dialog.vector_similarity_weight,
                    vtweight=dialog.vector_similarity_weight,
                )
                cited_chunk_indices = idx
            else:
                for r in re.finditer(r"##([0-9]+)\$\$", answer):
                    i = int(r.group(1))
                    if i < len(kbinfos["chunks"]):
                        cited_chunk_indices.add(i)

            # Insert the image of each cited chunk (once per URL) after its marker.
            def insert_image_markdown(match):
                idx = int(match.group(1))
                if idx >= len(kbinfos["chunks"]):
                    return match.group(0)
                chunk = kbinfos["chunks"][idx]
                img_path = chunk.get("image_id")
                if not img_path:
                    return match.group(0)
                protocol = "https" if MINIO_CONFIG.get("secure", False) else "http"
                img_url = f"{protocol}://{MINIO_CONFIG['visit_point']}/{img_path}"
                if img_url in processed_image_urls:
                    return match.group(0)
                processed_image_urls.add(img_url)
                # Insert the image with a capped display width.
                return f'{match.group(0)}\n\n<img src="{img_url}" alt="{img_url}" style="max-width:800px;">'

            answer = re.sub(r"##(\d+)\$\$", insert_image_markdown, answer)

            # Keep only the documents that were actually cited (fall back to all).
            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in cited_chunk_indices])
            recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs
            refs = deepcopy(kbinfos)
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]

        # Hint for the common "missing/invalid API key" failure.
        if "invalid key" in answer.lower() or "invalid api" in answer.lower():
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"

        # Per-stage timing, appended to the prompt for debugging.
        finish_chat_ts = timer()
        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
        prompt += "\n\n### Query:\n%s" % " ".join(questions)
        prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}

    if stream:
        last_ans = ""  # full answer as of the last yield
        answer = ""  # full answer accumulated so far
        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
            # Strip an inline reasoning block when an external thought is present.
            if thought:
                ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            delta_ans = ans[len(last_ans) :]
            # Skip tiny increments (< 16 tokens) to avoid sending many small fragments.
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        delta_ans = answer[len(last_ans) :]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    """Try to answer ``question`` by having the LLM write SQL over the tenant's index.

    The generated SQL is normalised, ``doc_id``/``docnm_kwd`` are forced into
    non-aggregate projections, and one repair attempt is made if the database
    reports an error. Results are rendered as a Markdown table with one
    ``##i$$`` marker per row; marker ``i`` refers to ``reference["chunks"][i]``.

    Args:
        question: The user question.
        field_map: Mapping of index field name -> description; used in the
            prompt and (with ``/...`` and full-width-parenthesised parts
            stripped) as table headers.
        tenant_id: Tenant whose index (``index_name(tenant_id)``) is queried.
        chat_mdl: Chat model used to write (and possibly repair) the SQL.
        quota: Citation flag from the caller. NOTE(review): it does not change
            the output -- row markers are appended either way.

    Returns:
        dict | None: ``answer`` (Markdown), ``reference`` (``chunks``,
        ``doc_aggs``) and ``prompt``; ``None`` when no usable SQL or no rows
        were produced.
    """
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
    tried_times = 0

    def get_table():
        """Ask for SQL, normalise it, execute it. Returns (result, sql) or (None, None)."""
        nonlocal tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
        sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;]|```).*", "", sql)
        if sql[: len("select ")] != "select ":
            return None, None
        # Non-aggregate queries must also return the source document columns.
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[: len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                # join() so an empty field list cannot leave "docnm_kwd, from".
                sql = "select " + ",".join(["doc_id", "docnm_kwd"] + flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
Error issued by database as follows:
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))
        # The repaired reply may not be a SELECT at all; get_table then returns (None, None).
        if tbl is None:
            return None

    logging.debug("GET table: {}".format(tbl))
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None

    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # compose Markdown table
    # Header text: field description with any "/..." tail and any full-width
    # parenthesised annotation removed.
    header_cells = [re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]
    columns = "|" + "|".join(header_cells) + ("|Source|" if docid_idx else "|")
    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx else "")

    # Drop rows whose visible cells are all blank, keeping the raw rows in step so
    # that marker ##i$$ and reference["chunks"][i] refer to the same row.
    rendered_rows = []
    kept_rows = []
    for r in tbl["rows"]:
        rendered = "|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|"
        if re.sub(r"[ |]+", "", rendered):
            rendered_rows.append(rendered)
            kept_rows.append(r)
    rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rendered_rows)])
    # Drop the time-of-day part from ISO timestamps in cells.
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {
            "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in kept_rows],
            "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
        },
        "prompt": sys_prompt,
    }
def tts(tts_mdl, text):
    """Synthesise ``text`` with ``tts_mdl`` and return the audio as a hex string.

    Args:
        tts_mdl: Object whose ``tts(text)`` yields bytes-like audio chunks, or None.
        text: Text to speak.

    Returns:
        str | None: Hex-encoded concatenated audio, or None when there is no
        model or no text.
    """
    if not tts_mdl or not text:
        return None
    # join() is linear; repeated bytes += is quadratic (and `bin` shadowed a builtin).
    audio = b"".join(tts_mdl.tts(text))
    return binascii.hexlify(audio).decode("utf-8")
def ask(question, kb_ids, tenant_id):
    """Answer a search query from the given knowledge bases, streaming the result.

    Steps: load the knowledge bases and check they share one embedding model,
    pick the regular or knowledge-graph retriever, bind the embedding and chat
    models, retrieve chunks, format them into the prompt, stream the LLM
    answer, and finally yield the answer with citation markers inserted.

    Args:
        question (str): The user's question or query.
        kb_ids (list): IDs of the knowledge bases to search.
        tenant_id (str): Tenant whose models are bound.

    Yields:
        dict: ``{"answer", "reference"}``. Streaming updates carry an empty
        ``reference``; the last item carries the cited chunks and documents.
        If the knowledge bases do not use exactly one embedding model, a
        single error answer is yielded instead.
    """
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    # Same guard as chat(): one query vector cannot serve indexes built with
    # different embedding models, and embedding_list[0] below needs an entry.
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": {}}
        return

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    # Embedding model: turns the query into a vector.
    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    # Chat model: generates the answer.
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    # Context budget for the formatted knowledge.
    max_tokens = chat_mdl.max_length
    # Deduplicated tenant IDs owning the knowledge bases.
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    # Lowered similarity threshold for better results (previously 0.1).
    similarity_threshold = 0.01
    # Positional args follow chat()'s call: page 1, top_n 12, threshold, vector weight 0.3.
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, similarity_threshold, 0.3, aggs=False, rank_feature=label_question(question, kbs))
    # Format retrieved chunks into prompt text within the model's token limit.
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
角色:你是一个聪明的助手。
任务:总结知识库中的信息并回答用户的问题。
要求与限制:
- 绝不要捏造内容,尤其是数字。
- 如果知识库中的信息与用户问题无关,**只需回答:对不起,未提供相关信息。
- 使用Markdown格式进行回答。
- 使用用户提问所用的语言作答。
- 绝不要捏造内容,尤其是数字。
### 来自知识库的信息
%s
以上是来自知识库的信息。
""" % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
        """Insert citation markers and attach the cited references to the final answer."""
        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
        # Keep only documents that were actually cited (fall back to all).
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if "invalid key" in answer.lower() or "invalid api" in answer.lower():
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        refs["chunks"] = chunks_format(refs)
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)