refactor(prompts): 引用提示词增强
This commit is contained in:
parent
899b49ddc6
commit
5cae4c2d7f
|
@ -32,16 +32,19 @@ def chunks_format(reference):
|
|||
def get_value(d, k1, k2):
|
||||
return d.get(k1, d.get(k2))
|
||||
|
||||
return [{
|
||||
"id": get_value(chunk, "chunk_id", "id"),
|
||||
"content": get_value(chunk, "content", "content_with_weight"),
|
||||
"document_id": get_value(chunk, "doc_id", "document_id"),
|
||||
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
|
||||
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
|
||||
"image_id": get_value(chunk, "image_id", "img_id"),
|
||||
"positions": get_value(chunk, "positions", "position_int"),
|
||||
"url": chunk.get("url")
|
||||
} for chunk in reference.get("chunks", [])]
|
||||
return [
|
||||
{
|
||||
"id": get_value(chunk, "chunk_id", "id"),
|
||||
"content": get_value(chunk, "content", "content_with_weight"),
|
||||
"document_id": get_value(chunk, "doc_id", "document_id"),
|
||||
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
|
||||
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
|
||||
"image_id": get_value(chunk, "image_id", "img_id"),
|
||||
"positions": get_value(chunk, "positions", "position_int"),
|
||||
"url": chunk.get("url"),
|
||||
}
|
||||
for chunk in reference.get("chunks", [])
|
||||
]
|
||||
|
||||
|
||||
def llm_id2llm_type(llm_id):
|
||||
|
@ -57,21 +60,21 @@ def llm_id2llm_type(llm_id):
|
|||
def message_fit_in(msg, max_length=4000):
|
||||
"""
|
||||
调整消息列表使其token总数不超过max_length限制
|
||||
|
||||
|
||||
参数:
|
||||
msg: 消息列表,每个元素为包含role和content的字典
|
||||
max_length: 最大token数限制,默认4000
|
||||
|
||||
|
||||
返回:
|
||||
tuple: (实际token数, 调整后的消息列表)
|
||||
"""
|
||||
|
||||
def count():
|
||||
"""计算当前消息列表的总token数"""
|
||||
nonlocal msg
|
||||
tks_cnts = []
|
||||
for m in msg:
|
||||
tks_cnts.append(
|
||||
{"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
||||
tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
||||
total = 0
|
||||
for m in tks_cnts:
|
||||
total += m["count"]
|
||||
|
@ -81,7 +84,7 @@ def message_fit_in(msg, max_length=4000):
|
|||
# 如果不超限制,直接返回
|
||||
if c < max_length:
|
||||
return c, msg
|
||||
|
||||
|
||||
# 第一次精简:保留系统消息和最后一条消息
|
||||
msg_ = [m for m in msg if m["role"] == "system"]
|
||||
if len(msg) > 1:
|
||||
|
@ -90,20 +93,20 @@ def message_fit_in(msg, max_length=4000):
|
|||
c = count()
|
||||
if c < max_length:
|
||||
return c, msg
|
||||
|
||||
|
||||
# 计算系统消息和最后一条消息的token数
|
||||
ll = num_tokens_from_string(msg_[0]["content"])
|
||||
ll2 = num_tokens_from_string(msg_[-1]["content"])
|
||||
# 如果系统消息占比超过80%,则截断系统消息
|
||||
if ll / (ll + ll2) > 0.8:
|
||||
m = msg_[0]["content"]
|
||||
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
|
||||
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
|
||||
msg[0]["content"] = m
|
||||
return max_length, msg
|
||||
|
||||
|
||||
# 否则截断最后一条消息
|
||||
m = msg_[-1]["content"]
|
||||
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
|
||||
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
|
||||
msg[-1]["content"] = m
|
||||
return max_length, msg
|
||||
|
||||
|
@ -111,18 +114,18 @@ def message_fit_in(msg, max_length=4000):
|
|||
def kb_prompt(kbinfos, max_tokens):
|
||||
"""
|
||||
将检索到的知识库内容格式化为适合大语言模型的提示词
|
||||
|
||||
|
||||
参数:
|
||||
kbinfos (dict): 检索结果,包含chunks等信息
|
||||
max_tokens (int): 模型的最大token限制
|
||||
|
||||
|
||||
流程:
|
||||
1. 提取所有检索到的文档片段内容
|
||||
2. 计算token数量,确保不超过模型限制
|
||||
3. 获取文档元数据
|
||||
4. 按文档名组织文档片段
|
||||
5. 格式化为结构化提示词
|
||||
|
||||
|
||||
返回:
|
||||
list: 格式化后的知识库内容列表,每个元素是一个文档的相关信息
|
||||
"""
|
||||
|
@ -134,7 +137,7 @@ def kb_prompt(kbinfos, max_tokens):
|
|||
chunks_num += 1
|
||||
if max_tokens * 0.97 < used_token_count:
|
||||
knowledges = knowledges[:i]
|
||||
logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}")
|
||||
logging.warning(f"Not all the retrieval into prompt: {i + 1}/{len(knowledges)}")
|
||||
break
|
||||
|
||||
docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
|
||||
|
@ -163,6 +166,10 @@ def citation_prompt():
|
|||
- 以格式 '##i$$ ##j$$'插入引用,其中 i, j 是所引用内容的 ID,并用 '##' 和 '$$' 包裹。
|
||||
- 在句子末尾插入引用,每个句子最多 4 个引用。
|
||||
- 如果答案内容不来自检索到的文本块,则不要插入引用。
|
||||
- 不要使用独立的文档 ID(例如 `#ID#`)。
|
||||
- 在任何情况下,不得使用其他引用样式或格式(例如 `~~i==`、`[i]`、`(i)` 等)。
|
||||
- 引用必须始终使用 `##i$$` 格式。
|
||||
- 任何未能遵守上述规则的情况,包括但不限于格式错误、使用禁止的样式或不支持的引用,都将被视为错误,应跳过为该句添加引用。
|
||||
|
||||
--- 示例 ---
|
||||
<SYSTEM>: 以下是知识库:
|
||||
|
@ -210,10 +217,7 @@ def keyword_extraction(chat_mdl, content, topn=3):
|
|||
### 文本内容
|
||||
{content}
|
||||
"""
|
||||
msg = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": "Output: "}
|
||||
]
|
||||
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple):
|
||||
|
@ -240,10 +244,7 @@ Requirements:
|
|||
{content}
|
||||
|
||||
"""
|
||||
msg = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": "Output: "}
|
||||
]
|
||||
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple):
|
||||
|
@ -368,10 +369,7 @@ Output:
|
|||
{content}
|
||||
|
||||
"""
|
||||
msg = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": "Output: "}
|
||||
]
|
||||
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
|
||||
if isinstance(kwd, tuple):
|
||||
|
@ -384,8 +382,8 @@ Output:
|
|||
return json_repair.loads(kwd)
|
||||
except json_repair.JSONDecodeError:
|
||||
try:
|
||||
result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
|
||||
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
||||
result = kwd.replace(prompt[:-1], "").replace("user", "").replace("model", "").strip()
|
||||
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||
return json_repair.loads(result)
|
||||
except Exception as e:
|
||||
logging.exception(f"JSON parsing error: {result} -> {e}")
|
||||
|
|
Loading…
Reference in New Issue