From 5cae4c2d7f2edd6c300d8771cf55d76815d6b1bb Mon Sep 17 00:00:00 2001
From: zstar <65890619+zstar1003@users.noreply.github.com>
Date: Sat, 17 May 2025 15:28:21 +0800
Subject: [PATCH] refactor(prompts): enhance the citation prompt
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 rag/prompts.py | 72 ++++++++++++++++++++++++--------------------------
 1 file changed, 35 insertions(+), 37 deletions(-)

diff --git a/rag/prompts.py b/rag/prompts.py
index 87397b9..2e2723c 100644
--- a/rag/prompts.py
+++ b/rag/prompts.py
@@ -32,16 +32,19 @@ def chunks_format(reference):
     def get_value(d, k1, k2):
         return d.get(k1, d.get(k2))
 
-    return [{
-        "id": get_value(chunk, "chunk_id", "id"),
-        "content": get_value(chunk, "content", "content_with_weight"),
-        "document_id": get_value(chunk, "doc_id", "document_id"),
-        "document_name": get_value(chunk, "docnm_kwd", "document_name"),
-        "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
-        "image_id": get_value(chunk, "image_id", "img_id"),
-        "positions": get_value(chunk, "positions", "position_int"),
-        "url": chunk.get("url")
-    } for chunk in reference.get("chunks", [])]
+    return [
+        {
+            "id": get_value(chunk, "chunk_id", "id"),
+            "content": get_value(chunk, "content", "content_with_weight"),
+            "document_id": get_value(chunk, "doc_id", "document_id"),
+            "document_name": get_value(chunk, "docnm_kwd", "document_name"),
+            "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
+            "image_id": get_value(chunk, "image_id", "img_id"),
+            "positions": get_value(chunk, "positions", "position_int"),
+            "url": chunk.get("url"),
+        }
+        for chunk in reference.get("chunks", [])
+    ]
 
 
 def llm_id2llm_type(llm_id):
@@ -57,21 +60,21 @@ def llm_id2llm_type(llm_id):
 def message_fit_in(msg, max_length=4000):
     """
     Trim the message list so that its total token count stays within the max_length limit.
-    
+
     Args:
         msg: the message list; each element is a dict with "role" and "content" keys
         max_length: the maximum number of tokens allowed, 4000 by default
-    
+
     Returns:
         tuple: (actual token count, adjusted message list)
     """
+
     def count():
         """Count the total number of tokens in the current message list."""
         nonlocal msg
         tks_cnts = []
         for m in msg:
-            tks_cnts.append(
-                {"role": m["role"], "count": num_tokens_from_string(m["content"])})
+            tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
         total = 0
         for m in tks_cnts:
             total += m["count"]
@@ -81,7 +84,7 @@ def message_fit_in(msg, max_length=4000):
     # If within the limit, return as is
     if c < max_length:
         return c, msg
-    
+
     # First pass: keep the system messages and the last message
     msg_ = [m for m in msg if m["role"] == "system"]
     if len(msg) > 1:
@@ -90,20 +93,20 @@ def message_fit_in(msg, max_length=4000):
     c = count()
     if c < max_length:
         return c, msg
-    
+
     # Count the tokens of the system message and the last message
     ll = num_tokens_from_string(msg_[0]["content"])
     ll2 = num_tokens_from_string(msg_[-1]["content"])
     # If the system message takes more than 80% of the tokens, truncate it
     if ll / (ll + ll2) > 0.8:
         m = msg_[0]["content"]
-        m = encoder.decode(encoder.encode(m)[:max_length - ll2])
+        m = encoder.decode(encoder.encode(m)[: max_length - ll2])
         msg[0]["content"] = m
         return max_length, msg
-    
+
     # Otherwise truncate the last message
     m = msg_[-1]["content"]
-    m = encoder.decode(encoder.encode(m)[:max_length - ll2])
+    m = encoder.decode(encoder.encode(m)[: max_length - ll2])
     msg[-1]["content"] = m
     return max_length, msg
 
@@ -111,18 +114,18 @@ def kb_prompt(kbinfos, max_tokens):
     """
     Format the retrieved knowledge-base content into a prompt suitable for the LLM.
-    
+
     Args:
         kbinfos (dict): the retrieval results, including chunks and related info
         max_tokens (int): the model's maximum token limit
-    
+
     Workflow:
         1. Extract the content of every retrieved document chunk
         2. Count tokens and make sure the model limit is not exceeded
         3. Fetch document metadata
        4. Group the chunks by document name
         5. Format everything into a structured prompt
-    
+
     Returns:
         list: the formatted knowledge-base content; each element holds one document's information
     """
 
@@ -134,7 +137,7 @@ def kb_prompt(kbinfos, max_tokens):
         chunks_num += 1
         if max_tokens * 0.97 < used_token_count:
             knowledges = knowledges[:i]
-            logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}")
+            logging.warning(f"Not all the retrieval into prompt: {i + 1}/{len(knowledges)}")
             break
 
     docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
@@ -163,6 +166,10 @@ def citation_prompt():
 - Insert citations in the format '##i$$ ##j$$', where i and j are the IDs of the cited content, wrapped with '##' and '$$'.
 - Insert citations at the end of a sentence, with at most 4 citations per sentence.
 - If the answer content does not come from the retrieved chunks, do not insert any citation.
+- Do not use standalone document IDs (e.g., `#ID#`).
+- Under no circumstances may any other citation style or format be used (e.g., `~~i==`, `[i]`, `(i)`, etc.).
+- Citations must always use the `##i$$` format.
+- Any failure to follow the rules above, including but not limited to malformed formatting, forbidden styles, or unsupported citations, is treated as an error, and adding a citation for that sentence should be skipped.
 
 --- Example ---
 : Here is the knowledge base:
@@ -210,10 +217,7 @@ def keyword_extraction(chat_mdl, content, topn=3):
 ### Text content
 {content}
 """
-    msg = [
-        {"role": "system", "content": prompt},
-        {"role": "user", "content": "Output: "}
-    ]
+    msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
     _, msg = message_fit_in(msg, chat_mdl.max_length)
     kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
     if isinstance(kwd, tuple):
@@ -240,10 +244,7 @@ Requirements:
 {content}
 
 """
-    msg = [
-        {"role": "system", "content": prompt},
-        {"role": "user", "content": "Output: "}
-    ]
+    msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
     _, msg = message_fit_in(msg, chat_mdl.max_length)
     kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
     if isinstance(kwd, tuple):
@@ -368,10 +369,7 @@ Output:
 {content}
 
 """
-    msg = [
-        {"role": "system", "content": prompt},
-        {"role": "user", "content": "Output: "}
-    ]
+    msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
     _, msg = message_fit_in(msg, chat_mdl.max_length)
     kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
     if isinstance(kwd, tuple):
@@ -384,8 +382,8 @@ Output:
         return json_repair.loads(kwd)
     except json_repair.JSONDecodeError:
         try:
-            result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
-            result = '{' + result.split('{')[1].split('}')[0] + '}'
+            result = kwd.replace(prompt[:-1], "").replace("user", "").replace("model", "").strip()
+            result = "{" + result.split("{")[1].split("}")[0] + "}"
             return json_repair.loads(result)
         except Exception as e:
             logging.exception(f"JSON parsing error: {result} -> {e}")
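
Reviewer note: the substantive change in this patch is the stricter citation contract in `citation_prompt()`. As a minimal sketch of how a downstream consumer could enforce that contract, the snippet below extracts `##i$$` citations and flags the styles the new rules forbid. The helper names and the sample answer are hypothetical, not part of this patch.

```python
import re

# The only citation style the enhanced prompt allows: ##i$$, where i is a chunk ID.
CITATION_RE = re.compile(r"##(\d+)\$\$")

# A few of the styles the new rules explicitly forbid: ~~i==, [i], (i), #ID#.
# Sketch only: \(\d+\) would also flag legitimate parenthesized numbers.
FORBIDDEN_RE = re.compile(r"~~\d+==|\[\d+\]|\(\d+\)|#\d+#")


def extract_citations(answer: str) -> list[int]:
    """Return the chunk IDs cited in ##i$$ form, in order of appearance."""
    return [int(m.group(1)) for m in CITATION_RE.finditer(answer)]


def has_forbidden_citations(answer: str) -> bool:
    """Return True if the answer uses any citation style the prompt forbids."""
    return bool(FORBIDDEN_RE.search(answer))


if __name__ == "__main__":
    answer = "RAGFlow supports hybrid retrieval ##0$$ ##3$$. It also reranks results [2]."
    print(extract_citations(answer))        # [0, 3]
    print(has_forbidden_citations(answer))  # True, because of "[2]"
```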
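Reviewer note: most of the remaining hunks are formatter churn around `message_fit_in`. For context, a minimal usage sketch of its truncation behavior, assuming a working ragflow environment; the token budget and messages below are made up.

```python
from rag.prompts import message_fit_in

# A long system prompt plus a short user question. When over budget,
# message_fit_in keeps the system message and the last message, then
# truncates whichever of the two dominates the token count.
msg = [
    {"role": "system", "content": "You are a helpful assistant. " * 2000},
    {"role": "user", "content": "What is RAGFlow?"},
]

token_count, fitted = message_fit_in(msg, max_length=1000)
print(token_count)                 # 1000: the function returns max_length once it truncates
print(fitted[0]["content"][-20:])  # the long system prompt has been cut short
```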
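Reviewer note: the `chunks_format` hunk is purely cosmetic, but the dual-key lookup it reformats is easy to miss. A small illustration, assuming internal field names as they appear elsewhere in ragflow; the sample chunk is fabricated.

```python
from rag.prompts import chunks_format

# get_value(d, k1, k2) falls back from the internal field name to the API
# field name, so either spelling of a chunk normalizes to the same shape.
reference = {
    "chunks": [
        {
            "chunk_id": "c1",
            "content_with_weight": "RAGFlow is a RAG engine.",
            "doc_id": "d1",
            "docnm_kwd": "intro.pdf",
            "kb_id": "kb1",
            "position_int": [[1, 0, 0, 0, 0]],
        }
    ]
}

formatted = chunks_format(reference)[0]
print(formatted["document_name"])  # "intro.pdf", taken from "docnm_kwd"
print(formatted["content"])        # taken from "content_with_weight"
```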