refactor(api): rework the image-insertion logic in the chat module so images can be placed inside the citation blocks, and optimize the ES image read-path logic

- Removed the multi-turn conversation refinement and reasoning-related code
- Added image Markdown insertion logic that fetches images from MinIO
- Improved the handling of cited references
- Simplified the error messages
- Added timing statistics
zstar 2025-06-07 13:00:07 +08:00
parent 370e0e62db
commit 66fbd297aa
4 changed files with 183 additions and 79 deletions


@@ -193,7 +193,7 @@ ollama pull bge-m3:latest
- **Running unmodified code**: even if you only run the project as-is, without modification or derivation, AGPLv3 still applies, including:
  - Providing the complete source code (even if unmodified).
  - If offered as a network service, allowing users to download the corresponding source code (AGPLv3 Section 13).
- **No closed-source commercial use**: for closed-source (not publishing the modified code) commercial use, you must obtain obtain written authorization from all code copyright holders, including the upstream AGPLv3 code authors.
- **No closed-source commercial use**: for closed-source (not publishing the modified code) commercial use, you must obtain written authorization from all code copyright holders, including the upstream AGPLv3 code authors.
3. **Disclaimer**:
This project comes with no warranty of any kind; users bear the compliance risk themselves. If you need legal advice, consult a professional lawyer.

api/db/services/database.py (new file, 133 lines)

@@ -0,0 +1,133 @@
import mysql.connector
import os
import redis
from minio import Minio
from dotenv import load_dotenv
from elasticsearch import Elasticsearch
from pathlib import Path
# Load environment variables
env_path = Path(__file__).parent.parent.parent / "docker" / ".env"
load_dotenv(env_path)
# Detect whether we are running inside a Docker container
def is_running_in_docker():
# Check whether the /.dockerenv file exists
docker_env = os.path.exists("/.dockerenv")
# ...or whether the cgroup info contains the string "docker"
try:
with open("/proc/self/cgroup", "r") as f:
return docker_env or "docker" in f.read()
except: # noqa: E722
return docker_env
# Pick host addresses and ports according to the runtime environment
if is_running_in_docker():
MYSQL_HOST = "mysql"
MYSQL_PORT = 3306
MINIO_HOST = os.getenv("MINIO_VISIT_HOST", "host.docker.internal")
MINIO_PORT = 9000
ES_HOST = "es01"
ES_PORT = 9200
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
else:
MYSQL_HOST = "localhost"
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "5455"))
MINIO_HOST = "localhost"
MINIO_PORT = int(os.getenv("MINIO_PORT", "9000"))
ES_HOST = "localhost"
ES_PORT = int(os.getenv("ES_PORT", "9200"))
REDIS_HOST = "localhost"
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
# MySQL database connection settings
DB_CONFIG = {
"host": MYSQL_HOST,
"port": MYSQL_PORT,
"user": "root",
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
"database": "rag_flow",
}
# MinIO connection settings
MINIO_CONFIG = {
"endpoint": f"{MINIO_HOST}:{MINIO_PORT}",
"access_key": os.getenv("MINIO_USER", "rag_flow"),
"secret_key": os.getenv("MINIO_PASSWORD", "infini_rag_flow"),
"secure": False,
}
# Elasticsearch connection settings
ES_CONFIG = {
"host": f"http://{ES_HOST}:{ES_PORT}",
"user": os.getenv("ELASTIC_USER", "elastic"),
"password": os.getenv("ELASTIC_PASSWORD", "infini_rag_flow"),
"use_ssl": os.getenv("ES_USE_SSL", "false").lower() == "true",
}
# Redis connection settings
REDIS_CONFIG = {
"host": REDIS_HOST,
"port": REDIS_PORT,
"password": os.getenv("REDIS_PASSWORD", "infini_rag_flow"),
"decode_responses": False,
}
def get_db_connection():
"""创建MySQL数据库连接"""
try:
conn = mysql.connector.connect(**DB_CONFIG)
return conn
except Exception as e:
print(f"MySQL连接失败: {str(e)}")
raise e
def get_minio_client():
"""创建MinIO客户端连接"""
try:
minio_client = Minio(endpoint=MINIO_CONFIG["endpoint"], access_key=MINIO_CONFIG["access_key"], secret_key=MINIO_CONFIG["secret_key"], secure=MINIO_CONFIG["secure"])
return minio_client
except Exception as e:
print(f"MinIO连接失败: {str(e)}")
raise e
def get_es_client():
"""创建Elasticsearch客户端连接"""
try:
# Build the connection parameters
es_params = {"hosts": [ES_CONFIG["host"]]}
# Add authentication if configured
if ES_CONFIG["user"] and ES_CONFIG["password"]:
es_params["basic_auth"] = (ES_CONFIG["user"], ES_CONFIG["password"])
# Add SSL configuration
if ES_CONFIG["use_ssl"]:
es_params["use_ssl"] = True
es_params["verify_certs"] = False # 在开发环境中可以设置为False生产环境应该设置为True
es_client = Elasticsearch(**es_params)
return es_client
except Exception as e:
print(f"Elasticsearch连接失败: {str(e)}")
raise e
def get_redis_connection():
"""创建Redis连接"""
try:
# Create the Redis connection from the config
r = redis.Redis(**REDIS_CONFIG)
# Test the connection
r.ping()
return r
except Exception as e:
print(f"Redis连接失败: {str(e)}")
raise e
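For reference, a minimal usage sketch of the new helpers (the bucket and object names below are hypothetical, and the import path assumes the `api` package is importable from the project root):

```python
from api.db.services.database import MINIO_CONFIG, get_es_client, get_minio_client

# Check that Elasticsearch is reachable with the resolved host/credentials.
es_client = get_es_client()
print("ES reachable:", es_client.ping())

# Fetch metadata for an object; img_id values stored in ES are relative
# paths of the form "<bucket>/<object_name>" (names here are made up).
minio_client = get_minio_client()
bucket, object_name = "kb-bucket", "images/figure-1.png"
stat = minio_client.stat_object(bucket, object_name)
print(stat.size, f"http://{MINIO_CONFIG['endpoint']}/{bucket}/{object_name}")
```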


@@ -30,6 +30,7 @@ from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, chunks_format, citation_prompt
from rag.utils import rmSpace, num_tokens_from_string
from .database import MINIO_CONFIG
class DialogService(CommonService):
@@ -152,11 +153,6 @@ def chat(dialog, messages, stream=True, **kwargs):
if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
# Multi-turn conversation refinement is no longer used
# if len(questions) > 1 and prompt_config.get("refine_multiturn"):
# questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
# else:
# questions = questions[-1:]
questions = questions[-1:]
refine_question_ts = timer()
@@ -181,41 +177,6 @@ def chat(dialog, messages, stream=True, **kwargs):
knowledges = []
# Reasoning (deep research) is no longer used
# if prompt_config.get("reasoning", False):
# reasoner = DeepResearcher(chat_mdl,
# prompt_config,
# partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))
# for think in reasoner.thinking(kbinfos, " ".join(questions)):
# if isinstance(think, str):
# thought = think
# knowledges = [t for t in think.split("\n") if t]
# elif stream:
# yield think
# else:
# kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
# dialog.similarity_threshold,
# dialog.vector_similarity_weight,
# doc_ids=attachments,
# top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
# rank_feature=label_question(" ".join(questions), kbs)
# )
# if prompt_config.get("tavily_api_key"):
# tav = Tavily(prompt_config["tavily_api_key"])
# tav_res = tav.retrieve_chunks(" ".join(questions))
# kbinfos["chunks"].extend(tav_res["chunks"])
# kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
# if prompt_config.get("use_kg"):
# ck = settings.kg_retrievaler.retrieval(" ".join(questions),
# tenant_ids,
# dialog.kb_ids,
# embd_mdl,
# LLMBundle(dialog.tenant_id, LLMType.CHAT))
# if ck["content_with_weight"]:
# kbinfos["chunks"].insert(0, ck)
# knowledges = kb_prompt(kbinfos, max_tokens)
kbinfos = retriever.retrieval(
" ".join(questions),
embd_mdl,
@@ -261,15 +222,18 @@ def chat(dialog, messages, stream=True, **kwargs):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
refs = []
image_markdowns = []  # holds the image Markdown strings
ans = answer.split("</think>")
think = ""
if len(ans) == 2:
think = ans[0] + "</think>"
answer = ans[1]
cited_chunk_indices = set()
inserted_images = {}
processed_image_urls = set()
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
cited_chunk_indices = set()  # stores the indices of the cited chunks
# Collect the indices of the cited chunks
if not re.search(r"##[0-9]+\$\$", answer):
answer, idx = retriever.insert_citations(
answer,
@@ -279,35 +243,41 @@ def chat(dialog, messages, stream=True, **kwargs):
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight,
)
cited_chunk_indices = idx  # indices returned by insert_citations
cited_chunk_indices = idx
else:
idx = set([])
for r in re.finditer(r"##([0-9]+)\$\$", answer):
i = int(r.group(1))
if i < len(kbinfos["chunks"]):
idx.add(i)
cited_chunk_indices = idx  # indices extracted from the ##...$$ markers
cited_chunk_indices.add(i)
# Extract image info from the cited chunks and build Markdown for it
cited_doc_ids = set()
processed_image_urls = set()  # avoid adding the same image more than once
print(f"DEBUG: cited_chunk_indices = {cited_chunk_indices}")
for i in cited_chunk_indices:
i_int = int(i)
if i_int < len(kbinfos["chunks"]):
chunk = kbinfos["chunks"][i_int]
cited_doc_ids.add(chunk["doc_id"])
print(f"DEBUG: chunk = {chunk}")
# Check whether the chunk has an associated image_id (URL) that has not been processed yet
print(f"DEBUG: chunk_id={chunk.get('chunk_id', i_int)}, image_id={chunk.get('image_id')}")
img_url = chunk.get("image_id")
if img_url and img_url not in processed_image_urls:
# Build the Markdown string; the alt text can simply be "image" or the chunk ID
image_markdowns.append(f"\n![{img_url}]({img_url})")
processed_image_urls.add(img_url)  # mark as processed
# Handle image insertion
def insert_image_markdown(match):
idx = int(match.group(1))
if idx >= len(kbinfos["chunks"]):
return match.group(0)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
chunk = kbinfos["chunks"][idx]
img_path = chunk.get("image_id")
if not img_path:
return match.group(0)
protocol = "https" if MINIO_CONFIG.get("secure", False) else "http"
img_url = f"{protocol}://{MINIO_CONFIG['endpoint']}/{img_path}"
if img_url in processed_image_urls:
return match.group(0)
processed_image_urls.add(img_url)
inserted_images[idx] = img_url
# Insert the image without wrapping the citation marker in any brackets
return f"{match.group(0)}\n\n![image]({img_url})"
# Insert the images via a regex substitution
answer = re.sub(r"##(\d+)\$\$", insert_image_markdown, answer)
# Clean up the cited-reference information
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in cited_chunk_indices])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs:
recall_docs = kbinfos["doc_aggs"]
@@ -318,14 +288,12 @@ def chat(dialog, messages, stream=True, **kwargs):
if c.get("vector"):
del c["vector"]
# Append the image Markdown strings to the end of the answer
if image_markdowns:
answer += "".join(image_markdowns)
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
# Special-case error hint
if "invalid key" in answer.lower() or "invalid api" in answer.lower():
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
finish_chat_ts = timer()
# Assemble the timing information
finish_chat_ts = timer()
total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
@@ -339,6 +307,7 @@ def chat(dialog, messages, stream=True, **kwargs):
prompt += "\n\n### Query:\n%s" % " ".join(questions)
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
if stream:
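The substitution can be exercised on its own; below is a standalone sketch with made-up chunk data (the real code additionally deduplicates URLs via `processed_image_urls` and records them in `inserted_images`):

```python
import re

# Hypothetical retrieval results: each ##N$$ marker refers to chunks[N].
chunks = [
    {"doc_id": "doc-1", "image_id": "kb-bucket/images/page-3.png"},
    {"doc_id": "doc-2", "image_id": None},
]
endpoint = "localhost:9000"  # stands in for MINIO_CONFIG["endpoint"]
answer = "The pump curve is shown in the datasheet ##0$$ and discussed in ##1$$."

def insert_image_markdown(match):
    idx = int(match.group(1))
    if idx >= len(chunks):
        return match.group(0)  # out-of-range citation: leave the marker untouched
    img_path = chunks[idx].get("image_id")
    if not img_path:
        return match.group(0)  # the cited chunk has no associated image
    return f"{match.group(0)}\n\n![image](http://{endpoint}/{img_path})"

print(re.sub(r"##(\d+)\$\$", insert_image_markdown, answer))
# The ##0$$ marker is now followed by its image Markdown; ##1$$ is left untouched.
```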


@@ -15,6 +15,7 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.read_api import read_local_office, read_local_images
from utils import generate_uuid
from urllib.parse import urlparse
from .rag_tokenizer import RagTokenizer
from .excel_parser import parse_excel
@@ -679,19 +680,20 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
for img_info in image_info_list:
# Compute the "distance" between the text chunk and the image
distance = abs(i - img_info["position"])  # use the position difference as the distance metric
# If the chunk is fewer than 10 blocks away from the image, consider them related
if distance < 10:
# If the chunk is fewer than 5 blocks away from the image, consider them related
if distance < 5:
nearest_image = img_info
# If a nearest image was found, update the chunk's img_id
if nearest_image:
# v0.4.1 update: store the extracted relative path instead
parsed_url = urlparse(nearest_image["url"])
relative_path = parsed_url.path.lstrip("/")  # strip the leading slash
# Update the document in ES
direct_update = {"doc": {"img_id": nearest_image["url"]}}
direct_update = {"doc": {"img_id": relative_path}}
es_client.update(index=index_name, id=chunk_id, body=direct_update, refresh=True)
index_name = f"ragflow_{tenant_id}"
print(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {nearest_image['url']}")
print(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {relative_path}")
except Exception as e:
print(f"[Parser-ERROR] 更新文本块图片关联失败: {e}")