#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import logging
import os
import re
from collections import defaultdict

import json_repair

from api.db import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from api.utils.file_utils import get_project_base_directory
from rag.settings import TAG_FLD
from rag.utils import num_tokens_from_string, encoder


def chunks_format(reference):
    def get_value(d, k1, k2):
        return d.get(k1, d.get(k2))

    return [
        {
            "id": get_value(chunk, "chunk_id", "id"),
            "content": get_value(chunk, "content", "content_with_weight"),
            "document_id": get_value(chunk, "doc_id", "document_id"),
            "document_name": get_value(chunk, "docnm_kwd", "document_name"),
            "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
            "image_id": get_value(chunk, "image_id", "img_id"),
            "positions": get_value(chunk, "positions", "position_int"),
            "url": chunk.get("url"),
            "similarity": chunk.get("similarity"),
            "vector_similarity": chunk.get("vector_similarity"),
            "term_similarity": chunk.get("term_similarity"),
            "doc_type": chunk.get("doc_type_kwd"),
        }
        for chunk in reference.get("chunks", [])
    ]
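
# Usage sketch (values are hypothetical, but the key aliases are exactly the
# ones handled above: internal retrieval keys vs. their API names):
#
#     ref = {"chunks": [{"chunk_id": "c1", "content": "some text",
#                        "doc_id": "d1", "docnm_kwd": "a.pdf", "kb_id": "kb1"}]}
#     chunks_format(ref)
#     # -> [{"id": "c1", "content": "some text", "document_id": "d1",
#     #      "document_name": "a.pdf", "dataset_id": "kb1", ...}]

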
def llm_id2llm_type(llm_id):
    llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
    fnm = os.path.join(get_project_base_directory(), "conf")
    llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
    for llm_factory in llm_factories["factory_llm_infos"]:
        for llm in llm_factory["llm"]:
            if llm_id == llm["llm_name"]:
                # "model_type" may be a comma-separated list; take the last
                # entry. (The original `.strip(",")[-1]` returned the final
                # character of the string, which could never match a type name.)
                return llm["model_type"].split(",")[-1]
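
# Shape of conf/llm_factories.json implied by the lookup above (entries are
# hypothetical):
#
#     {
#       "factory_llm_infos": [
#         {"llm": [{"llm_name": "some-model", "model_type": "chat,image2text"}]}
#       ]
#     }

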
def message_fit_in(msg, max_length=4000):
    """
    Trim a message list so that its total token count stays within max_length.

    Args:
        msg: list of messages, each a dict with "role" and "content" keys
        max_length: token budget, 4000 by default

    Returns:
        tuple: (actual token count, adjusted message list)
    """

    def count():
        """Count the total number of tokens in the current message list."""
        nonlocal msg
        tks_cnts = []
        for m in msg:
            tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
        total = 0
        for m in tks_cnts:
            total += m["count"]
        return total

    c = count()
    # Nothing to do if the messages already fit.
    if c < max_length:
        return c, msg

    # First pass: keep only the system messages and the last message.
    msg_ = [m for m in msg if m["role"] == "system"]
    if len(msg) > 1:
        msg_.append(msg[-1])
    msg = msg_
    c = count()
    if c < max_length:
        return c, msg

    # Token counts of the system message and of the last message.
    ll = num_tokens_from_string(msg_[0]["content"])
    ll2 = num_tokens_from_string(msg_[-1]["content"])
    # If the system message takes more than 80% of the budget, truncate it.
    if ll / (ll + ll2) > 0.8:
        m = msg_[0]["content"]
        m = encoder.decode(encoder.encode(m)[: max_length - ll2])
        msg[0]["content"] = m
        return max_length, msg

    # Otherwise truncate the last message, leaving room for the system message.
    # (The original sliced to `max_length - ll2`, subtracting the last
    # message's own length instead of the system message's.)
    m = msg_[-1]["content"]
    m = encoder.decode(encoder.encode(m)[: max_length - ll])
    msg[-1]["content"] = m
    return max_length, msg
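
# A minimal sketch of the trimming behavior (token counts depend on the
# encoder, so treat the numbers as illustrative):
#
#     msg = [{"role": "system", "content": "You are a helpful assistant."},
#            {"role": "user", "content": "first question"},
#            {"role": "user", "content": "a very long follow-up ..."}]
#     c, fitted = message_fit_in(msg, max_length=4000)
#     # If msg exceeded the budget, `fitted` keeps only the system message and
#     # the last message, truncating one of them to stay under max_length.

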
def kb_prompt(kbinfos, max_tokens):
    """
    Format retrieved knowledge-base chunks into a prompt for the LLM.

    Args:
        kbinfos (dict): retrieval results, containing "chunks" among other fields
        max_tokens (int): the model's token limit

    Steps:
        1. Collect the content of every retrieved chunk.
        2. Count tokens and stop before exceeding the model's limit.
        3. Fetch document metadata.
        4. Group chunks by document name.
        5. Format everything into a structured prompt.

    Returns:
        list: formatted knowledge-base entries, one per document
    """
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
    used_token_count = 0
    chunks_num = 0
    for i, c in enumerate(knowledges):
        used_token_count += num_tokens_from_string(c)
        if max_tokens * 0.97 < used_token_count:
            # Exclude the chunk that overflowed the budget. (The original
            # incremented chunks_num before this check, so the overflowing
            # chunk was dropped from `knowledges` but still counted.)
            knowledges = knowledges[:i]
            logging.warning(f"Not all retrieved chunks fit into the prompt: {i + 1}/{len(kbinfos['chunks'])}")
            break
        chunks_num += 1

    docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
    docs = {d.id: d.meta_fields for d in docs}

    doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
    for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
        cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
        cnt += ck["content_with_weight"]
        doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
        doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})

    knowledges = []
    for nm, cks_meta in doc2chunks.items():
        txt = f"\nDocument: {nm} \n"
        for k, v in cks_meta["meta"].items():
            txt += f"{k}: {v}\n"
        txt += "Relevant fragments are as follows:\n"
        for chunk in cks_meta["chunks"]:
            txt += f"{chunk}\n"
        knowledges.append(txt)
    return knowledges
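
# Each entry returned above renders roughly like this (values hypothetical):
#
#     Document: report.pdf
#     author: J. Smith
#     Relevant fragments are as follows:
#     ---
#     ID: 0
#     URL: https://example.com/report.pdf
#     <chunk text>

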
def citation_prompt():
    return """
# Citation requirements:
- Insert citations in the format '##i$$ ##j$$', where i and j are the IDs of the content being cited, wrapped in '##' and '$$'.
- Insert citations at the end of a sentence, with at most 4 citations per sentence.
- Do NOT insert citations if the answer content does not come from the retrieved chunks.
- Do NOT use standalone document IDs (e.g., `#ID#`).
- Under no circumstances may any other citation style or format be used (e.g., `~~i==`, `[i]`, `(i)`, etc.).
- Citations MUST always use the `##i$$` format.
- Any failure to follow the rules above, including but not limited to malformed formats, forbidden styles, or unsupported citations, is treated as an error, and the citation for that sentence should be skipped.

--- Example ---
<SYSTEM>: Here is the knowledge base:

Document: Elon Musk Breaks Silence on Crypto, Warns Against Going All-In on Dogecoin ...
URL: https://blockworks.co/news/elon-musk-crypto-dogecoin
ID: 0
The Tesla co-founder advised against going all-in on Dogecoin, but Elon Musk says it remains his favorite cryptocurrency ...

Document: Elon Musk's Dogecoin Tweet Sparks Social Media Frenzy
ID: 1
Musk says he is "willing to serve" D.O.G.E., shorthand for Dogecoin.

Document: The Causal Effect of Elon Musk Tweets on Dogecoin Price
ID: 2
If you think of Dogecoin, the meme-based cryptocurrency, you can't help but think of Elon Musk ...

Document: Elon Musk's Tweet Ignites Dogecoin's Future in Public Services
ID: 3
The market is heating up after Elon Musk's announcement about Dogecoin. Could this mean a new era for cryptocurrency? ...

The above is the knowledge base.

<USER>: What is Elon Musk's view on Dogecoin?

<ASSISTANT>: Musk has consistently expressed his fondness for Dogecoin, often citing its humor and the dog element of its branding. He has stated that it is his favorite cryptocurrency ##0$$ ##1$$.
More recently, Musk has hinted at possible new use cases for Dogecoin. His tweets have sparked speculation that Dogecoin could be integrated into public services ##3$$.
Overall, while Musk enjoys Dogecoin and often promotes it, he also warns against over-investing in it, reflecting an attitude that is both fond and cautious about its speculative nature.

--- Example END ---

"""
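
# A small sketch of how a downstream consumer might pull citation IDs out of
# an answer written in the format above (illustrative, not this module's API):
#
#     cited = re.findall(r"##(\d+)\$\$", "His favorite crypto ##0$$ ##1$$.")
#     # -> ["0", "1"]

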
def keyword_extraction(chat_mdl, content, topn=3):
    prompt = f"""
Role: You are a text analyzer.
Task: Extract the most important keywords/phrases from the given text content.
Requirements:
- Summarize the text content, and give the top {topn} important keywords/phrases.
- The keywords MUST be in the language of the original text.
- Separate the keywords with ENGLISH COMMAS.
- Output the keywords ONLY.

### Text Content
{content}
"""
    msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd
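
# Usage sketch (chat_mdl is an LLMBundle chat model, as used elsewhere in this
# module; the exact keywords depend on the model):
#
#     kwds = keyword_extraction(chat_mdl, "RAGFlow is an open-source RAG engine ...")
#     # e.g. "RAGFlow, RAG engine, open-source"

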
def question_proposal(chat_mdl, content, topn=3):
    prompt = f"""
Role: You're a text analyzer.
Task: Propose {topn} questions about a given piece of text content.
Requirements:
- Understand and summarize the text content, and propose the top {topn} important questions.
- The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible.
- The questions MUST be in the language of the given text content.
- One question per line.
- Output questions ONLY.

### Text Content
{content}

"""
    msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd
def full_question(tenant_id, llm_id, messages, language=None):
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    conv = []
    for m in messages:
        if m["role"] not in ["user", "assistant"]:
            continue
        conv.append("{}: {}".format(m["role"].upper(), m["content"]))
    conv = "\n".join(conv)
    today = datetime.date.today().isoformat()
    yesterday = (datetime.date.today() - datetime.timedelta(days=1)).isoformat()
    tomorrow = (datetime.date.today() + datetime.timedelta(days=1)).isoformat()
    prompt = f"""
Role: A helpful assistant

Task and steps:
1. Generate a full user question that would follow the conversation.
2. If the user's question involves a relative date, convert it into an absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}.

Requirements & Restrictions:
- If the user's latest question is already complete, don't do anything, just return the original question.
- DON'T generate anything except a refined question."""
    if language:
        prompt += f"""
- Text generated MUST be in {language}."""
    else:
        prompt += """
- Text generated MUST be in the same language as the original user's question.
"""
    prompt += f"""

######################
-Examples-
######################

# Example 1
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
###############
Output: What's the name of Donald Trump's mother?

------------
# Example 2
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
ASSISTANT: Mary Trump.
USER: What's her full name?
###############
Output: What's the full name of Donald Trump's mother Mary Trump?

------------
# Example 3
## Conversation
USER: What's the weather today in London?
ASSISTANT: Cloudy.
USER: What about tomorrow in Rochester?
###############
Output: What's the weather in Rochester on {tomorrow}?

######################
# Real Data
## Conversation
{conv}
###############
"""
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
    ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
    return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
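
# Usage sketch (tenant and model IDs are hypothetical; the model is resolved
# internally through LLMBundle):
#
#     messages = [{"role": "user", "content": "What is the capital of France?"},
#                 {"role": "assistant", "content": "Paris."},
#                 {"role": "user", "content": "And Germany?"}]
#     full_question(tenant_id, "some-chat-model", messages)
#     # -> e.g. "What is the capital of Germany?"

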
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    prompt = f"""
Role: You're a text analyzer.

Task: Tag (put on some labels) a given piece of text content based on the examples and the entire tag set.

Steps:
- Comprehend the tag/label set.
- Comprehend the examples, each consisting of a text content and its assigned tags with relevance scores, in JSON format.
- Summarize the text content, and tag it with the top {topn} most relevant tags from the tag/label set, along with the corresponding relevance scores.

Requirements
- The tags MUST come from the tag set.
- The output MUST be in JSON format only: the key is the tag and the value is its relevance score.
- Each relevance score must range from 1 to 10.
- Output the JSON ONLY.

# TAG SET
{", ".join(all_tags)}

"""
    for i, ex in enumerate(examples, 1):
        prompt += """
# Example {}
### Text Content
{}

Output:
{}

""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))

    prompt += f"""
# Real Data
### Text Content
{content}

"""
    msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        raise Exception(kwd)

    try:
        return json_repair.loads(kwd)
    except json_repair.JSONDecodeError:
        try:
            result = kwd.replace(prompt[:-1], "").replace("user", "").replace("model", "").strip()
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            return json_repair.loads(result)
        except Exception as e:
            logging.exception(f"JSON parsing error: {result} -> {e}")
            raise e
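
# Usage sketch (tags and scores are hypothetical; `examples` follows the
# structure consumed above: dicts with "content" and the TAG_FLD field):
#
#     content_tagging(chat_mdl, "A tutorial on training CNNs ...",
#                     ["deep-learning", "databases", "cooking"], examples)
#     # -> {"deep-learning": 9, "databases": 2}

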
def vision_llm_describe_prompt(page=None) -> str:
    prompt_en = """
INSTRUCTION:
Transcribe the content from the provided PDF page image into clean Markdown format.
- Only output the content transcribed from the image.
- Do NOT output this instruction or any other explanation.
- If the content is missing or you do not understand the input, return an empty string.

RULES:
1. Do NOT generate examples, demonstrations, or templates.
2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
5. Do NOT explain Markdown or mention that you are using Markdown.
6. Do NOT wrap the output in ```markdown or ``` blocks.
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
8. Preserve the original language, information, and order exactly as shown in the image.
"""

    if page is not None:
        prompt_en += f"\nAt the end of the transcription, add the page divider: `--- Page {page} ---`."

    prompt_en += """
FAILURE HANDLING:
- If you do not detect valid content in the image, return an empty string.
"""
    return prompt_en
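
# Usage sketch: the page divider rule is appended only when a page number is
# given.
#
#     vision_llm_describe_prompt()        # base transcription instructions
#     vision_llm_describe_prompt(page=3)  # adds the `--- Page 3 ---` divider rule

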
def vision_llm_figure_describe_prompt() -> str:
    prompt = """
You are an expert visual data analyst. Analyze the image and provide a comprehensive description of its content. Focus on identifying the type of visual data representation (e.g., bar chart, pie chart, line graph, table, flowchart), its structure, and any text captions or labels included in the image.

Tasks:
1. Describe the overall structure of the visual representation. Specify if it is a chart, graph, table, or diagram.
2. Identify and extract any axes, legends, titles, or labels present in the image. Provide the exact text where available.
3. Extract the data points from the visual elements (e.g., bar heights, line graph coordinates, pie chart segments, table rows and columns).
4. Analyze and explain any trends, comparisons, or patterns shown in the data.
5. Capture any annotations, captions, or footnotes, and explain their relevance to the image.
6. Only include details that are explicitly present in the image. If an element (e.g., axis, legend, or caption) does not exist or is not visible, do not mention it.

Output format (include only sections relevant to the image content):
- Visual Type: [Type]
- Title: [Title text, if available]
- Axes / Legends / Labels: [Details, if available]
- Data Points: [Extracted data]
- Trends / Insights: [Analysis and interpretation]
- Captions / Annotations: [Text and relevance, if available]

Ensure high accuracy, clarity, and completeness in your analysis, and include only the information present in the image. Avoid unnecessary statements about missing elements.
"""
    return prompt