# RAGflow/rag/prompts.py
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import logging
import os
import re
from collections import defaultdict
import json_repair
from api.db import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from api.utils.file_utils import get_project_base_directory
from rag.settings import TAG_FLD
from rag.utils import num_tokens_from_string, encoder


def chunks_format(reference):
def get_value(d, k1, k2):
return d.get(k1, d.get(k2))
return [{
"id": get_value(chunk, "chunk_id", "id"),
"content": get_value(chunk, "content", "content_with_weight"),
"document_id": get_value(chunk, "doc_id", "document_id"),
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
"image_id": get_value(chunk, "image_id", "img_id"),
"positions": get_value(chunk, "positions", "position_int"),
"url": chunk.get("url")
} for chunk in reference.get("chunks", [])]
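
# Illustrative usage (editor's sketch, not part of the original module):
# chunks_format accepts either internal or API-style keys for each chunk.
#
#   ref = {"chunks": [{"chunk_id": "c0", "content_with_weight": "some text",
#                      "doc_id": "d0", "docnm_kwd": "report.pdf", "kb_id": "kb0"}]}
#   chunks_format(ref)[0]["document_name"]  # -> "report.pdf"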


def llm_id2llm_type(llm_id):
    llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
    fnm = os.path.join(get_project_base_directory(), "conf")
    with open(os.path.join(fnm, "llm_factories.json"), "r") as f:
        llm_factories = json.load(f)
    for llm_factory in llm_factories["factory_llm_infos"]:
        for llm in llm_factory["llm"]:
            if llm_id == llm["llm_name"]:
                # "model_type" may be a comma-separated list; use the last entry.
                return llm["model_type"].split(",")[-1]
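
# For instance, llm_id2llm_type("qwen-vl-max@Tongyi-Qianwen") would return
# "image2text" if that model is declared with that model_type (hypothetical
# example; the real mapping depends on conf/llm_factories.json).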


def message_fit_in(msg, max_length=4000):
    """Trim a message list so that its total token count fits within max_length.

    Args:
        msg: list of message dicts, each with "role" and "content" keys.
        max_length: token budget, 4000 by default.

    Returns:
        tuple: (token count, adjusted message list)
    """

    def count():
        """Count the total tokens of the current message list."""
        nonlocal msg
        tks_cnts = []
        for m in msg:
            tks_cnts.append(
                {"role": m["role"], "count": num_tokens_from_string(m["content"])})
        total = 0
        for m in tks_cnts:
            total += m["count"]
        return total

    c = count()
    # Within budget: return unchanged.
    if c < max_length:
        return c, msg

    # First pass: keep only the system message(s) and the last message.
    msg_ = [m for m in msg if m["role"] == "system"]
    if len(msg) > 1:
        msg_.append(msg[-1])
    msg = msg_
    c = count()
    if c < max_length:
        return c, msg

    # Token counts of the system message and the last message.
    ll = num_tokens_from_string(msg_[0]["content"])
    ll2 = num_tokens_from_string(msg_[-1]["content"])
    # If the system message takes more than 80% of the remainder, truncate it.
    if ll / (ll + ll2) > 0.8:
        m = msg_[0]["content"]
        m = encoder.decode(encoder.encode(m)[:max_length - ll2])
        msg[0]["content"] = m
        return max_length, msg

    # Otherwise truncate the last message, reserving room for the system message.
    m = msg_[-1]["content"]
    m = encoder.decode(encoder.encode(m)[:max(0, max_length - ll)])
    msg[-1]["content"] = m
    return max_length, msg
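
# A minimal usage sketch (hypothetical messages, not from the original module):
#
#   msgs = [{"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "A very long question ..."}]
#   tk_cnt, msgs = message_fit_in(msgs, max_length=4000)
#   # tk_cnt <= 4000; intermediate turns are dropped first, then the dominant
#   # message is truncated at the token level.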


def kb_prompt(kbinfos, max_tokens):
    """Format retrieved knowledge-base chunks into prompt-ready text for the LLM.

    Args:
        kbinfos (dict): retrieval results, including "chunks".
        max_tokens (int): the model's token limit.

    Steps:
        1. Collect the content of every retrieved chunk.
        2. Count tokens and stop before the model limit is exceeded.
        3. Fetch document metadata.
        4. Group the chunks by document name.
        5. Render each document as a structured prompt block.

    Returns:
        list: formatted knowledge-base entries, one per document.
    """
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
    used_token_count = 0
    chunks_num = 0
    for i, c in enumerate(knowledges):
        used_token_count += num_tokens_from_string(c)
        chunks_num += 1
        if max_tokens * 0.97 < used_token_count:
            # Log before truncating so the total reflects everything retrieved.
            logging.warning(f"Not all the retrieval into prompt: {i + 1}/{len(knowledges)}")
            knowledges = knowledges[:i]
            break

    docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
    docs = {d.id: d.meta_fields for d in docs}

    doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
    for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
        doc2chunks[ck["docnm_kwd"]]["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + f"ID: {i}\n" + ck["content_with_weight"])
        doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})

    knowledges = []
    for nm, cks_meta in doc2chunks.items():
        txt = f"\nDocument: {nm} \n"
        for k, v in cks_meta["meta"].items():
            txt += f"{k}: {v}\n"
        txt += "Relevant fragments are as follows:\n"
        for i, chunk in enumerate(cks_meta["chunks"], 1):
            txt += f"{chunk}\n"
        knowledges.append(txt)
    return knowledges
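
# One returned entry looks roughly like the block below (illustrative values;
# the metadata lines come from each document's meta_fields):
#
#   Document: report.pdf
#   author: Alice
#   Relevant fragments are as follows:
#   URL: https://example.com/report
#   ID: 0
#   <chunk text>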


def citation_prompt():
    return """
# Citation requirements:
- Insert citations in the format '##i$$ ##j$$', where i and j are the IDs of the content you are citing, wrapped in '##' and '$$'.
- Insert citations at the end of a sentence, AT MOST 4 citations per sentence.
- DO NOT insert citations if the answer content does not come from the retrieved chunks.

--- Example START ---
<SYSTEM>: Here is the knowledge base:

Document: Elon Musk Breaks Silence on Crypto, Warns Against Going All-in on Dogecoin ...
URL: https://blockworks.co/news/elon-musk-crypto-dogecoin
ID: 0
The Tesla co-founder advised against going all-in on Dogecoin, but Elon Musk said it is still his favorite cryptocurrency...

Document: Elon Musk's Dogecoin tweet sparks social media frenzy
ID: 1
Musk said he is willing to serve D.O.G.E. (shorthand for Dogecoin)...

Document: Causal effect of Elon Musk tweets on Dogecoin price
ID: 2
If you think of Dogecoin, the meme-based cryptocurrency, you cannot help but think of Elon Musk...

Document: Elon Musk's tweet ignites Dogecoin's future in public services
ID: 3
The market is heating up after Elon Musk's announcement about Dogecoin. Does this mean a new era for crypto...

The above is the knowledge base.

<USER>: What is Elon Musk's view on Dogecoin?

<ASSISTANT>: Musk has consistently expressed his fondness for Dogecoin, often citing its humor and the dog in its branding. He has said it is his favorite cryptocurrency ##0$$ ##1$$.
Recently, Musk has hinted at potential new use cases for Dogecoin; his tweets have sparked speculation that Dogecoin might be integrated into public services ##3$$.
Overall, while Musk likes Dogecoin and frequently promotes it, he also warns against over-investing in it, reflecting an attitude of both fondness and caution toward its speculative nature.

--- Example END ---
"""


def keyword_extraction(chat_mdl, content, topn=3):
    prompt = f"""
Role: You are a text analyzer.
Task: Extract the most important keywords/phrases from a given piece of text content.
Requirements:
- Summarize the text content, and give the top {topn} important keywords/phrases.
- The keywords MUST be in the language of the given text content.
- The keywords are delimited by ENGLISH COMMA.
- Output the keywords ONLY.

### Text Content
{content}
"""
msg = [
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
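
# Illustrative call (chat_mdl is any LLMBundle-style chat model; the input
# text and output are made up):
#
#   keyword_extraction(chat_mdl, "Transformers rely on self-attention ...", topn=3)
#   # -> "transformers, self-attention, sequence modeling"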


def question_proposal(chat_mdl, content, topn=3):
    prompt = f"""
Role: You are a text analyzer.
Task: Propose {topn} questions about a given piece of text content.
Requirements:
- Understand and summarize the text content, and propose the top {topn} important questions.
- The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible.
- The questions MUST be in the language of the given text content.
- One question per line.
- Output the questions ONLY.

### Text Content
{content}
"""
msg = [
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
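
# Analogous to keyword_extraction: the model returns up to topn questions,
# one per line, e.g. "What is self-attention?\nWhy do transformers scale?"
# (made-up output for illustration).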


def full_question(tenant_id, llm_id, messages, language=None):
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
conv = []
for m in messages:
if m["role"] not in ["user", "assistant"]:
continue
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
conv = "\n".join(conv)
today = datetime.date.today().isoformat()
yesterday = (datetime.date.today() - datetime.timedelta(days=1)).isoformat()
tomorrow = (datetime.date.today() + datetime.timedelta(days=1)).isoformat()
prompt = f"""
Role: A helpful assistant
Task and steps:
1. Generate a full user question that would follow the conversation.
2. If the user's question involves a relative date, convert it into an absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}.
Requirements & Restrictions:
- If the user's latest question is already complete, don't do anything, just return the original question.
- DON'T generate anything except a refined question."""
if language:
prompt += f"""
- Text generated MUST be in {language}."""
else:
prompt += """
- Text generated MUST be in the same language as the original user's question.
"""
    prompt += f"""

######################
-Examples-
######################
# Example 1
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
###############
Output: What's the name of Donald Trump's mother?
------------
# Example 2
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
ASSISTANT: Mary Trump.
USER: What's her full name?
###############
Output: What's the full name of Donald Trump's mother Mary Trump?
------------
# Example 3
## Conversation
USER: What's the weather today in London?
ASSISTANT: Cloudy.
USER: What's about tomorrow in Rochester?
###############
Output: What's the weather in Rochester on {tomorrow}?
######################
# Real Data
## Conversation
{conv}
###############
"""
ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
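
# Illustrative call (hypothetical tenant and model IDs):
#
#   full_question("tenant-1", "qwen-max@Tongyi-Qianwen", [
#       {"role": "user", "content": "What's the weather today in London?"},
#       {"role": "assistant", "content": "Cloudy."},
#       {"role": "user", "content": "And tomorrow in Rochester?"},
#   ])
#   # -> "What's the weather in Rochester on <tomorrow's ISO date>?"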


def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    prompt = f"""
Role: You are a text analyzer.
Task: Tag (put labels on) a given piece of text content based on the examples and the entire tag set.
Steps:
- Comprehend the tag/label set.
- Comprehend the examples, each consisting of text content and assigned tags with relevance scores, in JSON format.
- Summarize the text content, and tag it with the top {topn} most relevant tags from the tag/label set along with their relevance scores.
Requirements:
- The tags MUST come from the tag set.
- The output MUST be in JSON format only: the key is the tag and the value is its relevance score.
- The relevance score must range from 1 to 10.
- Output the tags ONLY.

# TAG SET
{", ".join(all_tags)}
"""
for i, ex in enumerate(examples):
prompt += """
# Example {}
### Text Content
{}
Output:
{}
""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))
prompt += f"""
# Real Data
### Text Content
{content}
"""
msg = [
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
raise Exception(kwd)
try:
return json_repair.loads(kwd)
except json_repair.JSONDecodeError:
try:
result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
return json_repair.loads(result)
except Exception as e:
logging.exception(f"JSON parsing error: {result} -> {e}")
raise e
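
# Illustrative call (hypothetical tag set and example; TAG_FLD names the field
# holding each example's tag-to-score mapping):
#
#   content_tagging(chat_mdl, "CUDA kernels for matrix multiplication ...",
#                   ["hardware", "machine learning", "finance"],
#                   [{"content": "GPU benchmark results ...", TAG_FLD: {"hardware": 8}}],
#                   topn=2)
#   # -> {"hardware": 9, "machine learning": 7}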