#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import logging
import os
import re
from collections import defaultdict

import json_repair

from api.db import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from api.utils.file_utils import get_project_base_directory
from rag.settings import TAG_FLD
from rag.utils import num_tokens_from_string, encoder

def chunks_format(reference):
    def get_value(d, k1, k2):
        return d.get(k1, d.get(k2))

    return [{
        "id": get_value(chunk, "chunk_id", "id"),
        "content": get_value(chunk, "content", "content_with_weight"),
        "document_id": get_value(chunk, "doc_id", "document_id"),
        "document_name": get_value(chunk, "docnm_kwd", "document_name"),
        "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
        "image_id": get_value(chunk, "image_id", "img_id"),
        "positions": get_value(chunk, "positions", "position_int"),
        "url": chunk.get("url")
    } for chunk in reference.get("chunks", [])]

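# Illustrative sketch (not part of the original module): chunks_format
# normalizes the two key styles that retrieval results may carry. The sample
# reference dict below is made up for demonstration.
def _demo_chunks_format():
    reference = {"chunks": [
        {"chunk_id": "c1", "content": "foo", "doc_id": "d1",
         "docnm_kwd": "a.pdf", "kb_id": "kb1"},
        {"id": "c2", "content_with_weight": "bar", "document_id": "d2",
         "document_name": "b.pdf", "dataset_id": "kb1",
         "url": "https://example.com"},
    ]}
    # Both chunks come back under the same normalized keys ("id", "content",
    # "document_id", ...) regardless of which alias was present.
    return chunks_format(reference)
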
def llm_id2llm_type(llm_id):
    llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
    fnm = os.path.join(get_project_base_directory(), "conf")
    with open(os.path.join(fnm, "llm_factories.json"), "r") as f:
        llm_factories = json.load(f)
    for llm_factory in llm_factories["factory_llm_infos"]:
        for llm in llm_factory["llm"]:
            if llm_id == llm["llm_name"]:
                # "model_type" may hold a comma-separated list; the last
                # segment is the effective type (e.g. "chat", "image2text").
                # Note: split, not strip — strip(",")[-1] would return only
                # the last character of the string.
                return llm["model_type"].split(",")[-1]

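# Minimal sketch of why split(",")[-1] is used above. That "model_type" can be
# comma-separated is an assumption illustrated here with a made-up value; only
# the last segment names the effective type.
def _demo_model_type_split():
    model_type = "chat,image2text"  # hypothetical config value
    assert model_type.split(",")[-1] == "image2text"
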
def message_fit_in(msg, max_length=4000):
    """
    Trim a message list so its total token count stays within max_length.

    Args:
        msg: list of messages, each a dict with "role" and "content".
        max_length: maximum allowed token count, 4000 by default.

    Returns:
        tuple: (actual token count, adjusted message list)
    """
    def count():
        """Count the total tokens of the current message list."""
        nonlocal msg
        tks_cnts = []
        for m in msg:
            tks_cnts.append(
                {"role": m["role"], "count": num_tokens_from_string(m["content"])})
        total = 0
        for m in tks_cnts:
            total += m["count"]
        return total

    c = count()
    # Within the limit: return as-is.
    if c < max_length:
        return c, msg

    # First pass: keep only the system message(s) and the last message.
    msg_ = [m for m in msg if m["role"] == "system"]
    if len(msg) > 1:
        msg_.append(msg[-1])
    msg = msg_
    c = count()
    if c < max_length:
        return c, msg

    # Token counts of the system message and the last message.
    ll = num_tokens_from_string(msg_[0]["content"])
    ll2 = num_tokens_from_string(msg_[-1]["content"])
    # If the system message takes more than 80% of the budget, truncate it,
    # leaving room for the last message.
    if ll / (ll + ll2) > 0.8:
        m = msg_[0]["content"]
        m = encoder.decode(encoder.encode(m)[:max_length - ll2])
        msg[0]["content"] = m
        return max_length, msg

    # Otherwise truncate the last message, leaving room for the system
    # message (the original subtracted ll2, the last message's own size,
    # which could truncate to a negative length).
    m = msg_[-1]["content"]
    m = encoder.decode(encoder.encode(m)[:max_length - ll])
    msg[-1]["content"] = m
    return max_length, msg

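# Usage sketch (illustrative, not called anywhere): squeezing an over-long
# conversation into a small token budget. The message contents are made up.
def _demo_message_fit_in():
    msg = [
        {"role": "system", "content": "You are a helpful assistant. " * 400},
        {"role": "user", "content": "Summarize the retrieved documents."},
    ]
    tokens, fitted = message_fit_in(msg, max_length=1000)
    # The oversized system prompt is truncated so the total fits the budget.
    assert tokens <= 1000 and len(fitted) == 2
    return fitted
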
def kb_prompt(kbinfos, max_tokens):
    """
    Format retrieved knowledge base chunks into a prompt for the LLM.

    Args:
        kbinfos (dict): retrieval results, containing "chunks" among other fields.
        max_tokens (int): the model's maximum token limit.

    Steps:
        1. Extract the content of every retrieved chunk.
        2. Count tokens and make sure the model limit is not exceeded.
        3. Fetch document metadata.
        4. Group chunks by document name.
        5. Format everything into a structured prompt.

    Returns:
        list: formatted knowledge base contents, one element per document.
    """
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
    used_token_count = 0
    chunks_num = 0
    for i, c in enumerate(knowledges):
        used_token_count += num_tokens_from_string(c)
        chunks_num += 1
        if max_tokens * 0.97 < used_token_count:
            # Log before truncating so len(knowledges) is the original count.
            logging.warning(f"Not all the retrieval into prompt: {i + 1}/{len(knowledges)}")
            knowledges = knowledges[:i]
            break

    docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
    docs = {d.id: d.meta_fields for d in docs}

    doc2chunks = defaultdict(lambda: {"chunks": [], "meta": {}})
    for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
        doc2chunks[ck["docnm_kwd"]]["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + f"ID: {i}\n" + ck["content_with_weight"])
        doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})

    knowledges = []
    for nm, cks_meta in doc2chunks.items():
        txt = f"\nDocument: {nm}\n"
        for k, v in cks_meta["meta"].items():
            txt += f"{k}: {v}\n"
        txt += "Relevant fragments are as follows:\n"
        for chunk in cks_meta["chunks"]:
            txt += f"{chunk}\n"
        knowledges.append(txt)
    return knowledges

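# Illustrative sketch of the grouping step inside kb_prompt: chunks are
# bucketed by document name so each document appears once in the prompt,
# followed by its fragments. The chunk dicts below are made-up samples.
def _demo_doc_grouping():
    chunks = [
        {"docnm_kwd": "a.pdf", "doc_id": "d1", "content_with_weight": "alpha"},
        {"docnm_kwd": "a.pdf", "doc_id": "d1", "content_with_weight": "beta"},
        {"docnm_kwd": "b.pdf", "doc_id": "d2", "content_with_weight": "gamma"},
    ]
    doc2chunks = defaultdict(lambda: {"chunks": [], "meta": {}})
    for i, ck in enumerate(chunks):
        doc2chunks[ck["docnm_kwd"]]["chunks"].append(f"ID: {i}\n" + ck["content_with_weight"])
    # Two documents; fragment IDs stay globally unique across both.
    assert len(doc2chunks) == 2
    return doc2chunks
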
def citation_prompt():
    return """

# Citation requirements:
- Insert citations in the format '##i$$ ##j$$', where i, j are the IDs of the contents you are citing, wrapped in '##' and '$$'.
- Insert citations at the end of a sentence, with NO MORE than 4 citations per sentence.
- DO NOT insert citations if the answer content does not come from the retrieved chunks.

--- Example START ---
<SYSTEM>: Here is the knowledge base:

Document: Elon Musk Breaks Silence on Crypto, Warns Against Going All-In on Dogecoin ...
URL: https://blockworks.co/news/elon-musk-crypto-dogecoin
ID: 0
The Tesla co-founder advised against going all-in on Dogecoin, but Elon Musk said it is still his favorite cryptocurrency...

Document: Elon Musk's Dogecoin Tweet Sparks Social Media Frenzy
ID: 1
Musk said he is 'willing to serve' D.O.G.E., shorthand for Dogecoin.

Document: Causal Effect of Elon Musk Tweets on Dogecoin Price
ID: 2
If you think of Dogecoin, the meme-based cryptocurrency, you can't help but think of Elon Musk...

Document: Elon Musk's Tweet Ignites Dogecoin's Future in Public Services
ID: 3
The market is heating up after Elon Musk's announcement about Dogecoin. Does this mean a new era for crypto?...

The above is the knowledge base.

<USER>: What is Elon Musk's view on Dogecoin?

<ASSISTANT>: Musk has consistently expressed his fondness for Dogecoin, often citing its humor and the dog in its branding. He has referred to it as his favorite cryptocurrency ##0$$ ##1$$.
Recently, Musk has hinted at potential new applications for Dogecoin. His tweets have sparked speculation about Dogecoin's possible integration into public services ##3$$.
Overall, while Musk enjoys Dogecoin and often promotes it, he also warns against over-investing in it, reflecting both his fondness for and caution about its speculative nature.

--- Example END ---

"""

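# Illustrative counterpart to the citation format above: pulling the cited
# chunk IDs back out of a model answer with a regex. The answer string is a
# made-up example.
def _demo_extract_citations():
    answer = "Dogecoin is Musk's favorite cryptocurrency ##0$$ ##1$$."
    ids = [int(i) for i in re.findall(r"##(\d+)\$\$", answer)]
    assert ids == [0, 1]
    return ids
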
def keyword_extraction(chat_mdl, content, topn=3):
    prompt = f"""
Role: You are a text analyzer.
Task: extract the most important keywords/phrases from a given piece of text content.
Requirements:
- Summarize the text content, and give the top {topn} important keywords/phrases.
- The keywords MUST be in the same language as the given text content.
- The keywords are delimited by ENGLISH COMMA.
- Output keywords ONLY.

### Text Content
{content}
"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd

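# Usage sketch for keyword_extraction (and the similarly shaped
# question_proposal below): any object exposing .max_length and .chat() will
# do. The stub model here is a hypothetical stand-in for an LLMBundle.
def _demo_keyword_extraction():
    class _StubChatModel:
        max_length = 4000

        def chat(self, system, history, gen_conf):
            return "ragflow, retrieval, prompts"  # canned keywords

    return keyword_extraction(_StubChatModel(), "RAGFlow builds prompts for retrieval.")
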
def question_proposal(chat_mdl, content, topn=3):
    prompt = f"""
Role: You are a text analyzer.
Task: propose {topn} questions about a given piece of text content.
Requirements:
- Understand and summarize the text content, and propose the top {topn} important questions.
- The questions SHOULD NOT overlap in meaning.
- The questions SHOULD cover the main content of the text as much as possible.
- The questions MUST be in the same language as the given text content.
- One question per line.
- Output questions ONLY.

### Text Content
{content}

"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd

def full_question(tenant_id, llm_id, messages, language=None):
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    conv = []
    for m in messages:
        if m["role"] not in ["user", "assistant"]:
            continue
        conv.append("{}: {}".format(m["role"].upper(), m["content"]))
    conv = "\n".join(conv)
    today = datetime.date.today().isoformat()
    yesterday = (datetime.date.today() - datetime.timedelta(days=1)).isoformat()
    tomorrow = (datetime.date.today() + datetime.timedelta(days=1)).isoformat()
    prompt = f"""
Role: A helpful assistant

Task and steps:
1. Generate a full user question that would follow the conversation.
2. If the user's question involves a relative date, convert it into an absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}.

Requirements & Restrictions:
- If the user's latest question is already complete, don't do anything, just return the original question.
- DON'T generate anything except a refined question."""
    if language:
        prompt += f"""
- Text generated MUST be in {language}."""
    else:
        prompt += """
- Text generated MUST be in the same language as the original user's question.
"""
    prompt += f"""

######################
-Examples-
######################

# Example 1
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
###############
Output: What's the name of Donald Trump's mother?

------------
# Example 2
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
ASSISTANT: Mary Trump.
USER: What's her full name?
###############
Output: What's the full name of Donald Trump's mother Mary Trump?

------------
# Example 3
## Conversation
USER: What's the weather today in London?
ASSISTANT: Cloudy.
USER: What about tomorrow in Rochester?
###############
Output: What's the weather in Rochester on {tomorrow}?

######################
# Real Data
## Conversation
{conv}
###############
"""
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
    ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
    return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]

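# Small sketch of the relative-date anchoring used in the prompt above:
# 'today', 'yesterday' and 'tomorrow' are resolved to ISO dates at call time.
def _demo_relative_dates():
    today = datetime.date.today()
    return {
        "today": today.isoformat(),
        "yesterday": (today - datetime.timedelta(days=1)).isoformat(),
        "tomorrow": (today + datetime.timedelta(days=1)).isoformat(),
    }
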
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    prompt = f"""
Role: You are a text analyzer.

Task: Tag (i.e., attach labels to) a given piece of text content based on the examples and the entire tag set.

Steps:
- Comprehend the tag/label set.
- Comprehend the examples, each consisting of text content and its assigned tags with relevance scores, in JSON format.
- Summarize the text content, and tag it with the top {topn} most relevant tags from the tag set, along with their relevance scores.

Requirements:
- The tags MUST come from the tag set.
- The output MUST be in JSON format only: the key is the tag and the value is its relevance score.
- The relevance score MUST range from 1 to 10.
- Output tags ONLY.

# TAG SET
{", ".join(all_tags)}

"""
    for i, ex in enumerate(examples, 1):
        prompt += """
# Example {}
### Text Content
{}

Output:
{}

""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))

    prompt += f"""
# Real Data
### Text Content
{content}

"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        raise Exception(kwd)

    try:
        return json_repair.loads(kwd)
    except json_repair.JSONDecodeError:
        try:
            # Fallback: strip the echoed prompt and pull out the first
            # {...} block before repairing it.
            result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
            result = '{' + result.split('{')[1].split('}')[0] + '}'
            return json_repair.loads(result)
        except Exception as e:
            logging.exception(f"JSON parsing error: {result} -> {e}")
            raise e
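
# Illustrative sketch of the fallback parsing above: json_repair can recover
# a usable dict from the slightly malformed JSON that LLMs often emit. The
# broken string below is a made-up example.
def _demo_json_repair():
    broken = '{"finance": 9, "crypto": 7,}'  # trailing comma, invalid JSON
    tags = json_repair.loads(broken)
    assert tags == {"finance": 9, "crypto": 7}
    return tags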