refactor(api): 重构聊天模块中图片插入逻辑,使其能够插入到引用块中间,并优化es图片读取逻辑路径
- 移除了多轮对话优化和推理相关代码 - 新增图片 Markdown 插入逻辑,支持从 MinIO 中获取图片 - 优化了引用文献的处理流程 - 简化了错误提示信息 - 添加了时间信息统计
This commit is contained in:
parent
370e0e62db
commit
66fbd297aa
|
@ -193,7 +193,7 @@ ollama pull bge-m3:latest
|
|||
- **不修改代码**:若仅原样运行(不修改、不衍生),仍需遵守AGPLv3,包括:
|
||||
- 提供完整的源代码(即使未修改)。
|
||||
- 若作为网络服务提供,需允许用户下载对应源码(AGPLv3第13条)。
|
||||
- **不允许闭源商用**:如需闭源(不公开修改后的代码)商用,需获得获得所有代码版权持有人的书面授权(包括上游AGPLv3代码作者)
|
||||
- **不允许闭源商用**:如需闭源(不公开修改后的代码)商用,需获得所有代码版权持有人的书面授权(包括上游AGPLv3代码作者)
|
||||
|
||||
3. **免责声明**
|
||||
本项目不提供任何担保,使用者需自行承担合规风险。若需法律建议,请咨询专业律师。
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
import mysql.connector
|
||||
import os
|
||||
import redis
|
||||
from minio import Minio
|
||||
from dotenv import load_dotenv
|
||||
from elasticsearch import Elasticsearch
|
||||
from pathlib import Path
|
||||
|
||||
# 加载环境变量
|
||||
env_path = Path(__file__).parent.parent.parent / "docker" / ".env"
|
||||
load_dotenv(env_path)
|
||||
|
||||
|
||||
# 检测是否在Docker容器中运行
|
||||
def is_running_in_docker():
|
||||
# 检查是否存在/.dockerenv文件
|
||||
docker_env = os.path.exists("/.dockerenv")
|
||||
# 或者检查cgroup中是否包含docker字符串
|
||||
try:
|
||||
with open("/proc/self/cgroup", "r") as f:
|
||||
return docker_env or "docker" in f.read()
|
||||
except: # noqa: E722
|
||||
return docker_env
|
||||
|
||||
|
||||
# 根据运行环境选择合适的主机地址和端口
|
||||
if is_running_in_docker():
|
||||
MYSQL_HOST = "mysql"
|
||||
MYSQL_PORT = 3306
|
||||
MINIO_HOST = os.getenv("MINIO_VISIT_HOST", "host.docker.internal")
|
||||
MINIO_PORT = 9000
|
||||
ES_HOST = "es01"
|
||||
ES_PORT = 9200
|
||||
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
|
||||
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
|
||||
else:
|
||||
MYSQL_HOST = "localhost"
|
||||
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "5455"))
|
||||
MINIO_HOST = "localhost"
|
||||
MINIO_PORT = int(os.getenv("MINIO_PORT", "9000"))
|
||||
ES_HOST = "localhost"
|
||||
ES_PORT = int(os.getenv("ES_PORT", "9200"))
|
||||
REDIS_HOST = "localhost"
|
||||
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
|
||||
|
||||
|
||||
# 数据库连接配置
|
||||
DB_CONFIG = {
|
||||
"host": MYSQL_HOST,
|
||||
"port": MYSQL_PORT,
|
||||
"user": "root",
|
||||
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
|
||||
"database": "rag_flow",
|
||||
}
|
||||
|
||||
# MinIO连接配置
|
||||
MINIO_CONFIG = {
|
||||
"endpoint": f"{MINIO_HOST}:{MINIO_PORT}",
|
||||
"access_key": os.getenv("MINIO_USER", "rag_flow"),
|
||||
"secret_key": os.getenv("MINIO_PASSWORD", "infini_rag_flow"),
|
||||
"secure": False,
|
||||
}
|
||||
|
||||
# Elasticsearch连接配置
|
||||
ES_CONFIG = {
|
||||
"host": f"http://{ES_HOST}:{ES_PORT}",
|
||||
"user": os.getenv("ELASTIC_USER", "elastic"),
|
||||
"password": os.getenv("ELASTIC_PASSWORD", "infini_rag_flow"),
|
||||
"use_ssl": os.getenv("ES_USE_SSL", "false").lower() == "true",
|
||||
}
|
||||
|
||||
# Redis连接配置
|
||||
REDIS_CONFIG = {
|
||||
"host": REDIS_HOST,
|
||||
"port": REDIS_PORT,
|
||||
"password": os.getenv("REDIS_PASSWORD", "infini_rag_flow"),
|
||||
"decode_responses": False,
|
||||
}
|
||||
|
||||
|
||||
def get_db_connection():
|
||||
"""创建MySQL数据库连接"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
return conn
|
||||
except Exception as e:
|
||||
print(f"MySQL连接失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
|
||||
def get_minio_client():
|
||||
"""创建MinIO客户端连接"""
|
||||
try:
|
||||
minio_client = Minio(endpoint=MINIO_CONFIG["endpoint"], access_key=MINIO_CONFIG["access_key"], secret_key=MINIO_CONFIG["secret_key"], secure=MINIO_CONFIG["secure"])
|
||||
return minio_client
|
||||
except Exception as e:
|
||||
print(f"MinIO连接失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
|
||||
def get_es_client():
|
||||
"""创建Elasticsearch客户端连接"""
|
||||
try:
|
||||
# 构建连接参数
|
||||
es_params = {"hosts": [ES_CONFIG["host"]]}
|
||||
|
||||
# 添加认证信息
|
||||
if ES_CONFIG["user"] and ES_CONFIG["password"]:
|
||||
es_params["basic_auth"] = (ES_CONFIG["user"], ES_CONFIG["password"])
|
||||
|
||||
# 添加SSL配置
|
||||
if ES_CONFIG["use_ssl"]:
|
||||
es_params["use_ssl"] = True
|
||||
es_params["verify_certs"] = False # 在开发环境中可以设置为False,生产环境应该设置为True
|
||||
|
||||
es_client = Elasticsearch(**es_params)
|
||||
return es_client
|
||||
except Exception as e:
|
||||
print(f"Elasticsearch连接失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
|
||||
def get_redis_connection():
|
||||
"""创建Redis连接"""
|
||||
try:
|
||||
# 使用配置创建Redis连接
|
||||
r = redis.Redis(**REDIS_CONFIG)
|
||||
# 测试连接
|
||||
r.ping()
|
||||
return r
|
||||
except Exception as e:
|
||||
print(f"Redis连接失败: {str(e)}")
|
||||
raise e
|
|
@ -30,6 +30,7 @@ from rag.app.tag import label_question
|
|||
from rag.nlp.search import index_name
|
||||
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, chunks_format, citation_prompt
|
||||
from rag.utils import rmSpace, num_tokens_from_string
|
||||
from .database import MINIO_CONFIG
|
||||
|
||||
|
||||
class DialogService(CommonService):
|
||||
|
@ -152,11 +153,6 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
if p["key"] not in kwargs:
|
||||
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
||||
|
||||
# 不再使用多轮对话优化
|
||||
# if len(questions) > 1 and prompt_config.get("refine_multiturn"):
|
||||
# questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
|
||||
# else:
|
||||
# questions = questions[-1:]
|
||||
questions = questions[-1:]
|
||||
|
||||
refine_question_ts = timer()
|
||||
|
@ -181,41 +177,6 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
|
||||
knowledges = []
|
||||
|
||||
# 不再使用推理
|
||||
# if prompt_config.get("reasoning", False):
|
||||
# reasoner = DeepResearcher(chat_mdl,
|
||||
# prompt_config,
|
||||
# partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))
|
||||
|
||||
# for think in reasoner.thinking(kbinfos, " ".join(questions)):
|
||||
# if isinstance(think, str):
|
||||
# thought = think
|
||||
# knowledges = [t for t in think.split("\n") if t]
|
||||
# elif stream:
|
||||
# yield think
|
||||
# else:
|
||||
# kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
||||
# dialog.similarity_threshold,
|
||||
# dialog.vector_similarity_weight,
|
||||
# doc_ids=attachments,
|
||||
# top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
||||
# rank_feature=label_question(" ".join(questions), kbs)
|
||||
# )
|
||||
# if prompt_config.get("tavily_api_key"):
|
||||
# tav = Tavily(prompt_config["tavily_api_key"])
|
||||
# tav_res = tav.retrieve_chunks(" ".join(questions))
|
||||
# kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
# kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
# if prompt_config.get("use_kg"):
|
||||
# ck = settings.kg_retrievaler.retrieval(" ".join(questions),
|
||||
# tenant_ids,
|
||||
# dialog.kb_ids,
|
||||
# embd_mdl,
|
||||
# LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
# if ck["content_with_weight"]:
|
||||
# kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
# knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
kbinfos = retriever.retrieval(
|
||||
" ".join(questions),
|
||||
embd_mdl,
|
||||
|
@ -261,15 +222,18 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
|
||||
|
||||
refs = []
|
||||
image_markdowns = [] # 用于存储图片的 Markdown 字符串
|
||||
ans = answer.split("</think>")
|
||||
think = ""
|
||||
if len(ans) == 2:
|
||||
think = ans[0] + "</think>"
|
||||
answer = ans[1]
|
||||
|
||||
cited_chunk_indices = set()
|
||||
inserted_images = {}
|
||||
processed_image_urls = set()
|
||||
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
|
||||
cited_chunk_indices = set() # 用于存储被引用的 chunk 索引
|
||||
# 获取引用的 chunk 索引
|
||||
if not re.search(r"##[0-9]+\$\$", answer):
|
||||
answer, idx = retriever.insert_citations(
|
||||
answer,
|
||||
|
@ -279,35 +243,41 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight,
|
||||
)
|
||||
cited_chunk_indices = idx # 获取 insert_citations 返回的索引
|
||||
|
||||
cited_chunk_indices = idx
|
||||
else:
|
||||
idx = set([])
|
||||
for r in re.finditer(r"##([0-9]+)\$\$", answer):
|
||||
i = int(r.group(1))
|
||||
if i < len(kbinfos["chunks"]):
|
||||
idx.add(i)
|
||||
cited_chunk_indices = idx # 获取从 ##...$$ 标记中提取的索引
|
||||
cited_chunk_indices.add(i)
|
||||
|
||||
# 根据引用的 chunk 索引提取图像信息并生成 Markdown
|
||||
cited_doc_ids = set()
|
||||
processed_image_urls = set() # 避免重复添加同一张图片
|
||||
print(f"DEBUG: cited_chunk_indices = {cited_chunk_indices}")
|
||||
for i in cited_chunk_indices:
|
||||
i_int = int(i)
|
||||
if i_int < len(kbinfos["chunks"]):
|
||||
chunk = kbinfos["chunks"][i_int]
|
||||
cited_doc_ids.add(chunk["doc_id"])
|
||||
print(f"DEBUG: chunk = {chunk}")
|
||||
# 检查 chunk 是否有关联的 image_id (URL) 且未被处理过
|
||||
print(f"DEBUG: chunk_id={chunk.get('chunk_id', i_int)}, image_id={chunk.get('image_id')}")
|
||||
img_url = chunk.get("image_id")
|
||||
if img_url and img_url not in processed_image_urls:
|
||||
# 生成 Markdown 字符串,alt text 可以简单设为 "image" 或 chunk ID
|
||||
image_markdowns.append(f"\n")
|
||||
processed_image_urls.add(img_url) # 标记为已处理
|
||||
# 处理图片插入
|
||||
def insert_image_markdown(match):
|
||||
idx = int(match.group(1))
|
||||
if idx >= len(kbinfos["chunks"]):
|
||||
return match.group(0)
|
||||
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
chunk = kbinfos["chunks"][idx]
|
||||
img_path = chunk.get("image_id")
|
||||
if not img_path:
|
||||
return match.group(0)
|
||||
|
||||
protocol = "https" if MINIO_CONFIG.get("secure", False) else "http"
|
||||
img_url = f"{protocol}://{MINIO_CONFIG['endpoint']}/{img_path}"
|
||||
|
||||
if img_url in processed_image_urls:
|
||||
return match.group(0)
|
||||
|
||||
processed_image_urls.add(img_url)
|
||||
inserted_images[idx] = img_url
|
||||
|
||||
# 插入图片,不加任何括号包裹引用标记
|
||||
return f"{match.group(0)}\n\n"
|
||||
|
||||
# 用正则替换插图
|
||||
answer = re.sub(r"##(\d+)\$\$", insert_image_markdown, answer)
|
||||
|
||||
# 清理引用文献信息
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in cited_chunk_indices])
|
||||
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs:
|
||||
recall_docs = kbinfos["doc_aggs"]
|
||||
|
@ -318,14 +288,12 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
if c.get("vector"):
|
||||
del c["vector"]
|
||||
|
||||
# 将图片的 Markdown 字符串追加到回答末尾
|
||||
if image_markdowns:
|
||||
answer += "".join(image_markdowns)
|
||||
|
||||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||||
# 特殊错误提示
|
||||
if "invalid key" in answer.lower() or "invalid api" in answer.lower():
|
||||
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
|
||||
finish_chat_ts = timer()
|
||||
|
||||
# 时间信息拼接
|
||||
finish_chat_ts = timer()
|
||||
total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
|
||||
check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
|
||||
create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
|
||||
|
@ -339,6 +307,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
|
||||
prompt += "\n\n### Query:\n%s" % " ".join(questions)
|
||||
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
|
||||
|
||||
return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
|
||||
|
||||
if stream:
|
||||
|
|
|
@ -15,6 +15,7 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
|||
from magic_pdf.config.enums import SupportedPdfParseMethod
|
||||
from magic_pdf.data.read_api import read_local_office, read_local_images
|
||||
from utils import generate_uuid
|
||||
from urllib.parse import urlparse
|
||||
from .rag_tokenizer import RagTokenizer
|
||||
from .excel_parser import parse_excel
|
||||
|
||||
|
@ -679,19 +680,20 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
|
|||
for img_info in image_info_list:
|
||||
# 计算文本块与图片的"距离"
|
||||
distance = abs(i - img_info["position"]) # 使用位置差作为距离度量
|
||||
# 如果文本块与图片的距离间隔小于10个块,则认为块与图片是相关的
|
||||
if distance < 10:
|
||||
# 如果文本块与图片的距离间隔小于5个块,则认为块与图片是相关的
|
||||
if distance < 5:
|
||||
nearest_image = img_info
|
||||
|
||||
# 如果找到了最近的图片,则更新文本块的img_id
|
||||
if nearest_image:
|
||||
# v0.4.1更新,改成存储提取其相对路径部分
|
||||
parsed_url = urlparse(nearest_image["url"])
|
||||
relative_path = parsed_url.path.lstrip("/") # 去掉开头的斜杠
|
||||
# 更新ES中的文档
|
||||
direct_update = {"doc": {"img_id": nearest_image["url"]}}
|
||||
direct_update = {"doc": {"img_id": relative_path}}
|
||||
es_client.update(index=index_name, id=chunk_id, body=direct_update, refresh=True)
|
||||
|
||||
index_name = f"ragflow_{tenant_id}"
|
||||
|
||||
print(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {nearest_image['url']}")
|
||||
print(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {relative_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Parser-ERROR] 更新文本块图片关联失败: {e}")
|
||||
|
|
Loading…
Reference in New Issue