From 66fbd297aa665cd8408e78bdca82566a3d34c243 Mon Sep 17 00:00:00 2001 From: zstar <65890619+zstar1003@users.noreply.github.com> Date: Sat, 7 Jun 2025 13:00:07 +0800 Subject: [PATCH] =?UTF-8?q?refactor(api):=20=E9=87=8D=E6=9E=84=E8=81=8A?= =?UTF-8?q?=E5=A4=A9=E6=A8=A1=E5=9D=97=E4=B8=AD=E5=9B=BE=E7=89=87=E6=8F=92?= =?UTF-8?q?=E5=85=A5=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BD=BF=E5=85=B6=E8=83=BD?= =?UTF-8?q?=E5=A4=9F=E6=8F=92=E5=85=A5=E5=88=B0=E5=BC=95=E7=94=A8=E5=9D=97?= =?UTF-8?q?=E4=B8=AD=E9=97=B4=EF=BC=8C=E5=B9=B6=E4=BC=98=E5=8C=96es?= =?UTF-8?q?=E5=9B=BE=E7=89=87=E8=AF=BB=E5=8F=96=E9=80=BB=E8=BE=91=E8=B7=AF?= =?UTF-8?q?=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除了多轮对话优化和推理相关代码 - 新增图片 Markdown 插入逻辑,支持从 MinIO 中获取图片 - 优化了引用文献的处理流程 - 简化了错误提示信息 - 添加了时间信息统计 --- README.md | 2 +- api/db/services/database.py | 133 ++++++++++++++++++ api/db/services/dialog_service.py | 113 ++++++--------- .../knowledgebases/document_parser.py | 14 +- 4 files changed, 183 insertions(+), 79 deletions(-) create mode 100644 api/db/services/database.py diff --git a/README.md b/README.md index 154ea9d..9377db6 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ ollama pull bge-m3:latest - **不修改代码**:若仅原样运行(不修改、不衍生),仍需遵守AGPLv3,包括: - 提供完整的源代码(即使未修改)。 - 若作为网络服务提供,需允许用户下载对应源码(AGPLv3第13条)。 - - **不允许闭源商用**:如需闭源(不公开修改后的代码)商用,需获得获得所有代码版权持有人的书面授权(包括上游AGPLv3代码作者) + - **不允许闭源商用**:如需闭源(不公开修改后的代码)商用,需获得所有代码版权持有人的书面授权(包括上游AGPLv3代码作者) 3. **免责声明** 本项目不提供任何担保,使用者需自行承担合规风险。若需法律建议,请咨询专业律师。 diff --git a/api/db/services/database.py b/api/db/services/database.py new file mode 100644 index 0000000..cb9d826 --- /dev/null +++ b/api/db/services/database.py @@ -0,0 +1,133 @@ +import mysql.connector +import os +import redis +from minio import Minio +from dotenv import load_dotenv +from elasticsearch import Elasticsearch +from pathlib import Path + +# 加载环境变量 +env_path = Path(__file__).parent.parent.parent / "docker" / ".env" +load_dotenv(env_path) + + +# 检测是否在Docker容器中运行 +def is_running_in_docker(): + # 检查是否存在/.dockerenv文件 + docker_env = os.path.exists("/.dockerenv") + # 或者检查cgroup中是否包含docker字符串 + try: + with open("/proc/self/cgroup", "r") as f: + return docker_env or "docker" in f.read() + except: # noqa: E722 + return docker_env + + +# 根据运行环境选择合适的主机地址和端口 +if is_running_in_docker(): + MYSQL_HOST = "mysql" + MYSQL_PORT = 3306 + MINIO_HOST = os.getenv("MINIO_VISIT_HOST", "host.docker.internal") + MINIO_PORT = 9000 + ES_HOST = "es01" + ES_PORT = 9200 + REDIS_HOST = os.getenv("REDIS_HOST", "redis") + REDIS_PORT = int(os.getenv("REDIS_PORT", "6379")) +else: + MYSQL_HOST = "localhost" + MYSQL_PORT = int(os.getenv("MYSQL_PORT", "5455")) + MINIO_HOST = "localhost" + MINIO_PORT = int(os.getenv("MINIO_PORT", "9000")) + ES_HOST = "localhost" + ES_PORT = int(os.getenv("ES_PORT", "9200")) + REDIS_HOST = "localhost" + REDIS_PORT = int(os.getenv("REDIS_PORT", "6379")) + + +# 数据库连接配置 +DB_CONFIG = { + "host": MYSQL_HOST, + "port": MYSQL_PORT, + "user": "root", + "password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"), + "database": "rag_flow", +} + +# MinIO连接配置 +MINIO_CONFIG = { + "endpoint": f"{MINIO_HOST}:{MINIO_PORT}", + "access_key": os.getenv("MINIO_USER", "rag_flow"), + "secret_key": os.getenv("MINIO_PASSWORD", "infini_rag_flow"), + "secure": False, +} + +# Elasticsearch连接配置 +ES_CONFIG = { + "host": f"http://{ES_HOST}:{ES_PORT}", + "user": os.getenv("ELASTIC_USER", "elastic"), + "password": os.getenv("ELASTIC_PASSWORD", "infini_rag_flow"), + "use_ssl": os.getenv("ES_USE_SSL", "false").lower() == "true", +} + +# Redis连接配置 +REDIS_CONFIG = { + "host": REDIS_HOST, + "port": REDIS_PORT, + "password": os.getenv("REDIS_PASSWORD", "infini_rag_flow"), + "decode_responses": False, +} + + +def get_db_connection(): + """创建MySQL数据库连接""" + try: + conn = mysql.connector.connect(**DB_CONFIG) + return conn + except Exception as e: + print(f"MySQL连接失败: {str(e)}") + raise e + + +def get_minio_client(): + """创建MinIO客户端连接""" + try: + minio_client = Minio(endpoint=MINIO_CONFIG["endpoint"], access_key=MINIO_CONFIG["access_key"], secret_key=MINIO_CONFIG["secret_key"], secure=MINIO_CONFIG["secure"]) + return minio_client + except Exception as e: + print(f"MinIO连接失败: {str(e)}") + raise e + + +def get_es_client(): + """创建Elasticsearch客户端连接""" + try: + # 构建连接参数 + es_params = {"hosts": [ES_CONFIG["host"]]} + + # 添加认证信息 + if ES_CONFIG["user"] and ES_CONFIG["password"]: + es_params["basic_auth"] = (ES_CONFIG["user"], ES_CONFIG["password"]) + + # 添加SSL配置 + if ES_CONFIG["use_ssl"]: + es_params["use_ssl"] = True + es_params["verify_certs"] = False # 在开发环境中可以设置为False,生产环境应该设置为True + + es_client = Elasticsearch(**es_params) + return es_client + except Exception as e: + print(f"Elasticsearch连接失败: {str(e)}") + raise e + + +def get_redis_connection(): + """创建Redis连接""" + try: + # 使用配置创建Redis连接 + r = redis.Redis(**REDIS_CONFIG) + # 测试连接 + r.ping() + return r + except Exception as e: + print(f"Redis连接失败: {str(e)}") + raise e diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 2f35c4f..50d0ddc 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -30,6 +30,7 @@ from rag.app.tag import label_question from rag.nlp.search import index_name from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, chunks_format, citation_prompt from rag.utils import rmSpace, num_tokens_from_string +from .database import MINIO_CONFIG class DialogService(CommonService): @@ -152,11 +153,6 @@ def chat(dialog, messages, stream=True, **kwargs): if p["key"] not in kwargs: prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ") - # 不再使用多轮对话优化 - # if len(questions) > 1 and prompt_config.get("refine_multiturn"): - # questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] - # else: - # questions = questions[-1:] questions = questions[-1:] refine_question_ts = timer() @@ -181,41 +177,6 @@ def chat(dialog, messages, stream=True, **kwargs): knowledges = [] - # 不再使用推理 - # if prompt_config.get("reasoning", False): - # reasoner = DeepResearcher(chat_mdl, - # prompt_config, - # partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3)) - - # for think in reasoner.thinking(kbinfos, " ".join(questions)): - # if isinstance(think, str): - # thought = think - # knowledges = [t for t in think.split("\n") if t] - # elif stream: - # yield think - # else: - # kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, - # dialog.similarity_threshold, - # dialog.vector_similarity_weight, - # doc_ids=attachments, - # top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl, - # rank_feature=label_question(" ".join(questions), kbs) - # ) - # if prompt_config.get("tavily_api_key"): - # tav = Tavily(prompt_config["tavily_api_key"]) - # tav_res = tav.retrieve_chunks(" ".join(questions)) - # kbinfos["chunks"].extend(tav_res["chunks"]) - # kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) - # if prompt_config.get("use_kg"): - # ck = settings.kg_retrievaler.retrieval(" ".join(questions), - # tenant_ids, - # dialog.kb_ids, - # embd_mdl, - # LLMBundle(dialog.tenant_id, LLMType.CHAT)) - # if ck["content_with_weight"]: - # kbinfos["chunks"].insert(0, ck) - - # knowledges = kb_prompt(kbinfos, max_tokens) kbinfos = retriever.retrieval( " ".join(questions), embd_mdl, @@ -261,15 +222,18 @@ def chat(dialog, messages, stream=True, **kwargs): nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions refs = [] - image_markdowns = [] # 用于存储图片的 Markdown 字符串 ans = answer.split("") think = "" if len(ans) == 2: think = ans[0] + "" answer = ans[1] + + cited_chunk_indices = set() + inserted_images = {} + processed_image_urls = set() + if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): - answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL) - cited_chunk_indices = set() # 用于存储被引用的 chunk 索引 + # 获取引用的 chunk 索引 if not re.search(r"##[0-9]+\$\$", answer): answer, idx = retriever.insert_citations( answer, @@ -279,35 +243,41 @@ def chat(dialog, messages, stream=True, **kwargs): tkweight=1 - dialog.vector_similarity_weight, vtweight=dialog.vector_similarity_weight, ) - cited_chunk_indices = idx # 获取 insert_citations 返回的索引 - + cited_chunk_indices = idx else: - idx = set([]) for r in re.finditer(r"##([0-9]+)\$\$", answer): i = int(r.group(1)) if i < len(kbinfos["chunks"]): - idx.add(i) - cited_chunk_indices = idx # 获取从 ##...$$ 标记中提取的索引 + cited_chunk_indices.add(i) - # 根据引用的 chunk 索引提取图像信息并生成 Markdown - cited_doc_ids = set() - processed_image_urls = set() # 避免重复添加同一张图片 - print(f"DEBUG: cited_chunk_indices = {cited_chunk_indices}") - for i in cited_chunk_indices: - i_int = int(i) - if i_int < len(kbinfos["chunks"]): - chunk = kbinfos["chunks"][i_int] - cited_doc_ids.add(chunk["doc_id"]) - print(f"DEBUG: chunk = {chunk}") - # 检查 chunk 是否有关联的 image_id (URL) 且未被处理过 - print(f"DEBUG: chunk_id={chunk.get('chunk_id', i_int)}, image_id={chunk.get('image_id')}") - img_url = chunk.get("image_id") - if img_url and img_url not in processed_image_urls: - # 生成 Markdown 字符串,alt text 可以简单设为 "image" 或 chunk ID - image_markdowns.append(f"\n![{img_url}]({img_url})") - processed_image_urls.add(img_url) # 标记为已处理 + # 处理图片插入 + def insert_image_markdown(match): + idx = int(match.group(1)) + if idx >= len(kbinfos["chunks"]): + return match.group(0) - idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) + chunk = kbinfos["chunks"][idx] + img_path = chunk.get("image_id") + if not img_path: + return match.group(0) + + protocol = "https" if MINIO_CONFIG.get("secure", False) else "http" + img_url = f"{protocol}://{MINIO_CONFIG['endpoint']}/{img_path}" + + if img_url in processed_image_urls: + return match.group(0) + + processed_image_urls.add(img_url) + inserted_images[idx] = img_url + + # 插入图片,不加任何括号包裹引用标记 + return f"{match.group(0)}\n\n![image]({img_url})" + + # 用正则替换插图 + answer = re.sub(r"##(\d+)\$\$", insert_image_markdown, answer) + + # 清理引用文献信息 + idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in cited_chunk_indices]) recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] if not recall_docs: recall_docs = kbinfos["doc_aggs"] @@ -318,14 +288,12 @@ def chat(dialog, messages, stream=True, **kwargs): if c.get("vector"): del c["vector"] - # 将图片的 Markdown 字符串追加到回答末尾 - if image_markdowns: - answer += "".join(image_markdowns) - - if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: + # 特殊错误提示 + if "invalid key" in answer.lower() or "invalid api" in answer.lower(): answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'" - finish_chat_ts = timer() + # 时间信息拼接 + finish_chat_ts = timer() total_time_cost = (finish_chat_ts - chat_start_ts) * 1000 check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000 create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000 @@ -339,6 +307,7 @@ def chat(dialog, messages, stream=True, **kwargs): prompt += "\n\n### Query:\n%s" % " ".join(questions) prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms" + return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()} if stream: diff --git a/management/server/services/knowledgebases/document_parser.py b/management/server/services/knowledgebases/document_parser.py index d2857be..190272c 100644 --- a/management/server/services/knowledgebases/document_parser.py +++ b/management/server/services/knowledgebases/document_parser.py @@ -15,6 +15,7 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.data.read_api import read_local_office, read_local_images from utils import generate_uuid +from urllib.parse import urlparse from .rag_tokenizer import RagTokenizer from .excel_parser import parse_excel @@ -679,19 +680,20 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info): for img_info in image_info_list: # 计算文本块与图片的"距离" distance = abs(i - img_info["position"]) # 使用位置差作为距离度量 - # 如果文本块与图片的距离间隔小于10个块,则认为块与图片是相关的 - if distance < 10: + # 如果文本块与图片的距离间隔小于5个块,则认为块与图片是相关的 + if distance < 5: nearest_image = img_info # 如果找到了最近的图片,则更新文本块的img_id if nearest_image: + # v0.4.1更新,改成存储提取其相对路径部分 + parsed_url = urlparse(nearest_image["url"]) + relative_path = parsed_url.path.lstrip("/") # 去掉开头的斜杠 # 更新ES中的文档 - direct_update = {"doc": {"img_id": nearest_image["url"]}} + direct_update = {"doc": {"img_id": relative_path}} es_client.update(index=index_name, id=chunk_id, body=direct_update, refresh=True) - index_name = f"ragflow_{tenant_id}" - - print(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {nearest_image['url']}") + print(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {relative_path}") except Exception as e: print(f"[Parser-ERROR] 更新文本块图片关联失败: {e}")