RAGflow/management/server/services/knowledgebases/document_parser.py

# Copyright 2025 zstar1003. All Rights Reserved.
# Project source code: https://github.com/zstar1003/ragflow-plus
import json
import os
import re
import shutil
import tempfile
import time
from datetime import datetime
from urllib.parse import urlparse
import requests
from database import MINIO_CONFIG, get_db_connection, get_es_client, get_minio_client
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import FileBasedDataReader, FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.read_api import read_local_images, read_local_office
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from . import logger
from .excel_parser import parse_excel_file
from .rag_tokenizer import RagTokenizer
from .utils import _create_task_record, _update_document_progress, _update_kb_chunk_count, generate_uuid, get_bbox_from_block

tknzr = RagTokenizer()


def tokenize_text(text):
    """Tokenize the given text with the shared RAG tokenizer."""
    return tknzr.tokenize(text)


def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
    """
    Core logic for parsing a document.

    Args:
        doc_id (str): Document ID.
        doc_info (dict): Document info (name, location, type, kb_id, parser_config, created_by).
        file_info (dict): File info (parent_id / bucket_name).
        embedding_config (dict): Embedding model config (llm_name, api_base, api_key).
        kb_info (dict): Knowledge base info (created_by).

    Returns:
        dict: Parse result (success, chunk_count).
    """
    temp_pdf_path = None
    temp_image_dir = None
    start_time = time.time()
    middle_json_content = None  # initialize middle_json_content
    image_info_list = []  # collected image info (URL + position)

    # Defaults
    embedding_model_name = embedding_config.get("llm_name") if embedding_config and embedding_config.get("llm_name") else "bge-m3"  # default model
    # Normalize the model name
    if embedding_model_name and "___" in embedding_model_name:
        embedding_model_name = embedding_model_name.split("___")[0]
    # Replace specific model names (special-casing for the SiliconFlow platform)
    if embedding_model_name == "netease-youdao/bce-embedding-base_v1":
        embedding_model_name = "BAAI/bge-m3"

    embedding_api_base = embedding_config.get("api_base") if embedding_config and embedding_config.get("api_base") else "http://localhost:11434"  # default base URL
    # If the API base is an empty string, fall back to the SiliconFlow API address
    if embedding_api_base == "":
        embedding_api_base = "https://api.siliconflow.cn/v1/embeddings"
        logger.info(f"[Parser-INFO] API 基础地址为空,已设置为硅基流动的 API 地址: {embedding_api_base}")
    embedding_api_key = embedding_config.get("api_key") if embedding_config else None  # may be None or an empty string

    # Build the full embedding API URL
    embedding_url = None  # defaults to None
    if embedding_api_base:
        # Make sure embedding_api_base carries a scheme (http:// or https://)
        if not embedding_api_base.startswith(("http://", "https://")):
            embedding_api_base = "http://" + embedding_api_base
        # Strip the trailing slash to simplify the suffix checks
        normalized_base_url = embedding_api_base.rstrip("/")
        # If the URL uses port 11434, assume an Ollama server and use Ollama's dedicated API
        is_ollama = "11434" in normalized_base_url
        if is_ollama:
            # Ollama's dedicated endpoint path
            embedding_url = normalized_base_url + "/api/embeddings"
        elif normalized_base_url.endswith("/v1"):
            embedding_url = normalized_base_url + "/embeddings"
        elif normalized_base_url.endswith("/embeddings"):
            embedding_url = normalized_base_url
        else:
            embedding_url = normalized_base_url + "/v1/embeddings"
logger.info(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")

    try:
        kb_id = doc_info["kb_id"]
        file_location = doc_info["location"]
        # Extract the original extension from the file path
        _, file_extension = os.path.splitext(file_location)
        file_type = doc_info["type"].lower()
        bucket_name = file_info["parent_id"]  # files are stored in the bucket named after parent_id
        tenant_id = kb_info["created_by"]  # the knowledge base creator acts as the tenant_id

        # Progress-update callback (calls the internal update helper directly)
        def update_progress(prog=None, msg=None):
            _update_document_progress(doc_id, progress=prog, message=msg)
            logger.info(f"[Parser-PROGRESS] Doc: {doc_id}, Progress: {prog}, Message: {msg}")

        # 1. Fetch the file content from MinIO
        minio_client = get_minio_client()
        if not minio_client.bucket_exists(bucket_name):
            raise Exception(f"存储桶不存在: {bucket_name}")

        update_progress(0.1, f"正在从存储中获取文件: {file_location}")
        response = minio_client.get_object(bucket_name, file_location)
        file_content = response.read()
        response.close()
        update_progress(0.2, "文件获取成功,准备解析")

        # 2. Choose a parser based on the file type
        content_list = []
        if file_type.endswith("pdf"):
            update_progress(0.3, "使用MinerU解析器")
            # Write the PDF content to a temporary file
            temp_dir = tempfile.gettempdir()
            temp_pdf_path = os.path.join(temp_dir, f"{doc_id}.pdf")
            with open(temp_pdf_path, "wb") as f:
                f.write(file_content)

            # Process with MinerU
            reader = FileBasedDataReader("")
            pdf_bytes = reader.read(temp_pdf_path)
            ds = PymuDocDataset(pdf_bytes)

            update_progress(0.3, "分析PDF类型")
            is_ocr = ds.classify() == SupportedPdfParseMethod.OCR
            mode_msg = "OCR模式" if is_ocr else "文本模式"
            update_progress(0.4, f"使用{mode_msg}处理PDF处理中具体进度可查看容器日志")

            infer_result = ds.apply(doc_analyze, ocr=is_ocr)

            # Set up a temporary output directory for extracted images
            temp_image_dir = os.path.join(temp_dir, f"images_{doc_id}")
            os.makedirs(temp_image_dir, exist_ok=True)
            image_writer = FileBasedDataWriter(temp_image_dir)

            update_progress(0.6, f"处理{mode_msg}结果")
            pipe_result = infer_result.pipe_ocr_mode(image_writer) if is_ocr else infer_result.pipe_txt_mode(image_writer)

            update_progress(0.8, "提取内容")
            content_list = pipe_result.get_content_list(os.path.basename(temp_image_dir))
            # Get the intermediate (middle) JSON of the parse
            middle_content = pipe_result.get_middle_json()
            middle_json_content = json.loads(middle_content)
elif file_type.endswith("word") or file_type.endswith("ppt") or file_type.endswith("txt") or file_type.endswith("md") or file_type.endswith("html"):
update_progress(0.3, "使用MinerU解析器")
# 创建临时文件保存文件内容
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, f"{doc_id}{file_extension}")
with open(temp_file_path, "wb") as f:
f.write(file_content)
logger.info(f"[Parser-INFO] 临时文件路径: {temp_file_path}")
# 使用MinerU处理
ds = read_local_office(temp_file_path)[0]
infer_result = ds.apply(doc_analyze, ocr=True)
# 设置临时输出目录
temp_image_dir = os.path.join(temp_dir, f"images_{doc_id}")
os.makedirs(temp_image_dir, exist_ok=True)
image_writer = FileBasedDataWriter(temp_image_dir)
update_progress(0.6, "处理文件结果")
pipe_result = infer_result.pipe_txt_mode(image_writer)
update_progress(0.8, "提取内容")
content_list = pipe_result.get_content_list(os.path.basename(temp_image_dir))
# 获取内容列表JSON格式
middle_content = pipe_result.get_middle_json()
middle_json_content = json.loads(middle_content)
        # Handle Excel files separately
        elif file_type.endswith("excel"):
            update_progress(0.3, "使用MinerU解析器")
            # Write the file content to a temporary file
            temp_dir = tempfile.gettempdir()
            temp_file_path = os.path.join(temp_dir, f"{doc_id}{file_extension}")
            with open(temp_file_path, "wb") as f:
                f.write(file_content)
            logger.info(f"[Parser-INFO] 临时文件路径: {temp_file_path}")

            update_progress(0.8, "提取内容")
            # Build the content list from the spreadsheet
            content_list = parse_excel_file(temp_file_path)
elif file_type.endswith("visual"):
update_progress(0.3, "使用MinerU解析器")
# 创建临时文件保存文件内容
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, f"{doc_id}{file_extension}")
with open(temp_file_path, "wb") as f:
f.write(file_content)
logger.info(f"[Parser-INFO] 临时文件路径: {temp_file_path}")
# 使用MinerU处理
ds = read_local_images(temp_file_path)[0]
infer_result = ds.apply(doc_analyze, ocr=True)
update_progress(0.3, "分析PDF类型")
is_ocr = ds.classify() == SupportedPdfParseMethod.OCR
mode_msg = "OCR模式" if is_ocr else "文本模式"
update_progress(0.4, f"使用{mode_msg}处理PDF处理中具体进度可查看日志")
infer_result = ds.apply(doc_analyze, ocr=is_ocr)
# 设置临时输出目录
temp_image_dir = os.path.join(temp_dir, f"images_{doc_id}")
os.makedirs(temp_image_dir, exist_ok=True)
image_writer = FileBasedDataWriter(temp_image_dir)
update_progress(0.6, f"处理{mode_msg}结果")
pipe_result = infer_result.pipe_ocr_mode(image_writer) if is_ocr else infer_result.pipe_txt_mode(image_writer)
update_progress(0.8, "提取内容")
content_list = pipe_result.get_content_list(os.path.basename(temp_image_dir))
# 获取内容列表JSON格式
middle_content = pipe_result.get_middle_json()
middle_json_content = json.loads(middle_content)
        else:
            update_progress(0.3, f"暂不支持的文件类型: {file_type}")
            raise NotImplementedError(f"文件类型 '{file_type}' 的解析器尚未实现")

        # Parse middle_json_content and extract block info
        block_info_list = []
        if middle_json_content:
            try:
                if isinstance(middle_json_content, dict):
                    middle_data = middle_json_content  # use it directly
                else:
                    middle_data = None
                    logger.warning(f"[Parser-WARNING] middle_json_content 不是预期的字典格式,实际类型: {type(middle_json_content)}")

                # Extract the block info (guard against middle_data being None)
                for page_idx, page_data in enumerate((middle_data or {}).get("pdf_info", [])):
                    for block in page_data.get("preproc_blocks", []):
                        block_bbox = get_bbox_from_block(block)
                        # Only keep blocks that carry text and a valid bbox
                        if block_bbox != [0, 0, 0, 0]:
                            block_info_list.append({"page_idx": page_idx, "bbox": block_bbox})
                        else:
                            logger.warning("[Parser-WARNING] 块的 bbox 格式无效,跳过。")

                logger.info(f"[Parser-INFO] 从 middle_data 提取了 {len(block_info_list)} 个块的信息。")
            except json.JSONDecodeError:
                logger.error("[Parser-ERROR] 解析 middle_json_content 失败。")
                raise Exception("[Parser-ERROR] 解析 middle_json_content 失败。")
            except Exception as e:
                logger.error(f"[Parser-ERROR] 处理 middle_json_content 时出错: {e}")
                raise Exception(f"[Parser-ERROR] 处理 middle_json_content 时出错: {e}")

        # 3. Process the parse results (upload to MinIO, index into Elasticsearch)
        update_progress(0.95, "保存解析结果")
        es_client = get_es_client()

        # Note: the MinIO output bucket should be the knowledge base ID (kb_id), not the file's parent_id
        output_bucket = kb_id
        if not minio_client.bucket_exists(output_bucket):
            minio_client.make_bucket(output_bucket)
            logger.info(f"[Parser-INFO] 创建MinIO桶: {output_bucket}")

        index_name = f"ragflow_{tenant_id}"
        if not es_client.indices.exists(index=index_name):
            # Create the index
            es_client.indices.create(
                index=index_name,
                body={
                    "settings": {"number_of_replicas": 0},
                    "mappings": {
                        "properties": {
                            "doc_id": {"type": "keyword"},
                            "kb_id": {"type": "keyword"},
                            "content_with_weight": {"type": "text"},
                            "q_1024_vec": {"type": "dense_vector", "dims": 1024},
                        }
                    },
                },
            )
            logger.info(f"[Parser-INFO] 创建Elasticsearch索引: {index_name}")

        chunk_count = 0
        chunk_ids_list = []
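        # For orientation, the content_list entries produced above are consumed by the loop below
        # roughly in these shapes (field values are illustrative, not from a real document):
        #   {"type": "text",     "text": "..."}
        #   {"type": "equation", "text": "..."}
        #   {"type": "table",    "table_caption": ["..."], "table_body": "..."}
        #   {"type": "image",    "img_path": "images_<doc_id>/<file name>"}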
        for chunk_idx, chunk_data in enumerate(content_list):
            page_idx = 0  # default page index
            bbox = [0, 0, 0, 0]  # default bbox

            # Try to look up the matching block info directly by chunk_idx
            if chunk_idx < len(block_info_list):
                block_info = block_info_list[chunk_idx]
                page_idx = block_info.get("page_idx", 0)
                bbox = block_info.get("bbox", [0, 0, 0, 0])
                # Validate the bbox and fall back to the default if it is malformed (optional, depending on how strict we want to be)
                if not (isinstance(bbox, list) and len(bbox) == 4 and all(isinstance(n, (int, float)) for n in bbox)):
                    logger.info(f"[Parser-WARNING] Chunk {chunk_idx} 对应的 bbox 格式无效: {bbox},将使用默认值。")
                    bbox = [0, 0, 0, 0]
            else:
                # block_info_list is shorter than content_list; warn only once
                # (at the first out-of-range index) to avoid flooding the log
                if chunk_idx == len(block_info_list):
                    logger.warning(f"[Parser-WARNING] block_info_list 的长度 ({len(block_info_list)}) 小于 content_list 的长度 ({len(content_list)})。后续块将使用默认 page_idx 和 bbox。")
if chunk_data["type"] == "text" or chunk_data["type"] == "table" or chunk_data["type"] == "equation":
if chunk_data["type"] == "text":
content = chunk_data["text"]
if not content or not content.strip():
continue
# 过滤 markdown 特殊符号
content = re.sub(r"[!#\\$/]", "", content)
elif chunk_data["type"] == "equation":
content = chunk_data["text"]
if not content or not content.strip():
continue
elif chunk_data["type"] == "table":
caption_list = chunk_data.get("table_caption", []) # 获取列表,默认为空列表
table_body = chunk_data.get("table_body", "") # 获取表格主体,默认为空字符串
# 如果表格主体为空,说明无实际内容,跳过该表格块
if not table_body.strip():
continue
# 检查 caption_list 是否为列表,并且包含字符串元素
if isinstance(caption_list, list) and all(isinstance(item, str) for item in caption_list):
# 使用空格将列表中的所有字符串拼接起来
caption_str = " ".join(caption_list)
elif isinstance(caption_list, str):
# 如果 caption 本身就是字符串,直接使用
caption_str = caption_list
else:
# 其他情况如空列表、None 或非字符串列表),使用空字符串
caption_str = ""
# 将处理后的标题字符串和表格主体拼接
content = caption_str + table_body

                q_1024_vec = []  # initialize as an empty list
                # Request the embedding vector
                try:
                    # embedding_resp = requests.post(
                    #     "http://localhost:8000/v1/embeddings",
                    #     json={
                    #         "model": "bge-m3",  # your embedding model name
                    #         "input": content
                    #     },
                    #     timeout=10
                    # )
                    headers = {"Content-Type": "application/json"}
                    if embedding_api_key:
                        headers["Authorization"] = f"Bearer {embedding_api_key}"

                    if is_ollama:
                        embedding_resp = requests.post(
                            embedding_url,  # dynamically built URL
                            headers=headers,  # carries the API key when one is configured
                            json={
                                "model": embedding_model_name,  # configured or default model name
                                "prompt": content,
                            },
                            timeout=15,  # slightly longer timeout
                        )
                    else:
                        embedding_resp = requests.post(
                            embedding_url,  # dynamically built URL
                            headers=headers,  # carries the API key when one is configured
                            json={
                                "model": embedding_model_name,  # configured or default model name
                                "input": content,
                            },
                            timeout=15,  # slightly longer timeout
                        )
                    embedding_resp.raise_for_status()
                    embedding_data = embedding_resp.json()
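                    # The two endpoint styles return differently shaped payloads, roughly:
                    #   Ollama /api/embeddings:      {"embedding": [ ...1024 floats... ]}
                    #   OpenAI-style /v1/embeddings: {"data": [{"embedding": [...], "index": 0}], ...}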
                    # Special handling for the Ollama embedding endpoint's response format
                    if is_ollama:
                        q_1024_vec = embedding_data.get("embedding")
                    else:
                        q_1024_vec = embedding_data["data"][0]["embedding"]
                    # logger.info(f"[Parser-INFO] 获取embedding成功长度: {len(q_1024_vec)}")

                    # Make sure the vector dimension is 1024
                    if len(q_1024_vec) != 1024:
                        error_msg = f"[Parser-ERROR] Embedding向量维度不是1024实际维度: {len(q_1024_vec)}, 建议使用bge-m3模型"
                        logger.error(error_msg)
                        update_progress(-5, error_msg)
                        raise ValueError(error_msg)
                except Exception as e:
                    logger.error(f"[Parser-ERROR] 获取embedding失败: {e}")
                    raise Exception(f"[Parser-ERROR] 获取embedding失败: {e}")

                chunk_id = generate_uuid()
                try:
                    # Prepare the ES document
                    current_time_es = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    current_timestamp_es = datetime.now().timestamp()

                    # Convert the coordinate format
                    x1, y1, x2, y2 = bbox
                    bbox_reordered = [x1, x2, y1, y2]

                    es_doc = {
                        "doc_id": doc_id,
                        "kb_id": kb_id,
                        "docnm_kwd": doc_info["name"],
                        "title_tks": tokenize_text(doc_info["name"]),
                        "title_sm_tks": tokenize_text(doc_info["name"]),
                        "content_with_weight": content,
                        "content_ltks": tokenize_text(content),
                        "content_sm_ltks": tokenize_text(content),
                        "page_num_int": [page_idx + 1],
                        "position_int": [[page_idx + 1] + bbox_reordered],  # format: [[page, x1, x2, y1, y2]]
                        "top_int": [1],
                        "create_time": current_time_es,
                        "create_timestamp_flt": current_timestamp_es,
                        "img_id": "",
                        "q_1024_vec": q_1024_vec,
                    }

                    # Index into Elasticsearch
                    es_client.index(index=index_name, id=chunk_id, document=es_doc)  # use the document parameter

                    chunk_count += 1
                    chunk_ids_list.append(chunk_id)
                except Exception as e:
                    logger.error(f"[Parser-ERROR] 处理文本块 {chunk_idx} (page: {page_idx}, bbox: {bbox}) 失败: {e}")
                    raise Exception(f"[Parser-ERROR] 处理文本块 {chunk_idx} (page: {page_idx}, bbox: {bbox}) 失败: {e}")
elif chunk_data["type"] == "image":
img_path_relative = chunk_data.get("img_path")
if not img_path_relative or not temp_image_dir:
continue
img_path_abs = os.path.join(temp_image_dir, os.path.basename(img_path_relative))
if not os.path.exists(img_path_abs):
logger.warning(f"[Parser-WARNING] 图片文件不存在: {img_path_abs}")
continue
img_id = generate_uuid()
img_ext = os.path.splitext(img_path_abs)[1]
img_key = f"images/{img_id}{img_ext}" # MinIO中的对象名
content_type = f"image/{img_ext[1:].lower()}"
if content_type == "image/jpg":
content_type = "image/jpeg"
try:
# 上传图片到MinIO (桶为kb_id)
minio_client.fput_object(bucket_name=output_bucket, object_name=img_key, file_path=img_path_abs, content_type=content_type)
# 设置图片的公共访问权限
policy = {"Version": "2012-10-17", "Statement": [{"Effect": "Allow", "Principal": {"AWS": "*"}, "Action": ["s3:GetObject"], "Resource": [f"arn:aws:s3:::{kb_id}/images/*"]}]}
minio_client.set_bucket_policy(kb_id, json.dumps(policy))
logger.info(f"成功上传图片: {img_key}")
minio_endpoint = MINIO_CONFIG["endpoint"]
use_ssl = MINIO_CONFIG.get("secure", False)
protocol = "https" if use_ssl else "http"
img_url = f"{protocol}://{minio_endpoint}/{output_bucket}/{img_key}"
# 记录图片信息包括URL和位置信息
image_info = {
"url": img_url,
"position": chunk_count, # 使用当前处理的文本块数作为位置参考
}
image_info_list.append(image_info)
logger.info(f"图片访问链接: {img_url}")
except Exception as e:
logger.error(f"[Parser-ERROR] 上传图片 {img_path_abs} 失败: {e}")
raise Exception(f"[Parser-ERROR] 上传图片 {img_path_abs} 失败: {e}")

        # Log a summary of the chunk processing
        logger.info(f"[Parser-INFO] 共处理 {chunk_count} 个文本块。")

        # 4. Attach image info to the text chunks
        if image_info_list and chunk_ids_list:
            conn = None
            cursor = None
            try:
                conn = get_db_connection()
                cursor = conn.cursor()

                # For each text chunk, find the nearest image
                for i, chunk_id in enumerate(chunk_ids_list):
                    nearest_image = None
                    for img_info in image_info_list:
                        # "Distance" between the chunk and the image, measured as the position difference
                        distance = abs(i - img_info["position"])
                        # Treat the chunk and the image as related if they are fewer than 5 chunks apart
                        if distance < 5:
                            nearest_image = img_info

                    # If a nearby image was found, update the chunk's img_id
                    if nearest_image:
                        # Store only the relative path part of the URL
                        parsed_url = urlparse(nearest_image["url"])
                        relative_path = parsed_url.path.lstrip("/")  # drop the leading slash

                        # Update the document in ES
                        direct_update = {"doc": {"img_id": relative_path}}
                        es_client.update(index=index_name, id=chunk_id, body=direct_update, refresh=True)
                        index_name = f"ragflow_{tenant_id}"
                        logger.info(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {relative_path}")
            except Exception as e:
                logger.error(f"[Parser-ERROR] 更新文本块图片关联失败: {e}")
                raise Exception(f"[Parser-ERROR] 更新文本块图片关联失败: {e}")
            finally:
                if cursor:
                    cursor.close()
                if conn:
                    conn.close()

        # 5. Write the final status
        process_duration = time.time() - start_time
        _update_document_progress(doc_id, progress=1.0, message="解析完成", status="1", run="3", chunk_count=chunk_count, process_duration=process_duration)
        _update_kb_chunk_count(kb_id, chunk_count)  # update the knowledge base's total chunk count
        _create_task_record(doc_id, chunk_ids_list)  # create the task record

        update_progress(1.0, "解析完成")
        logger.info(f"[Parser-INFO] 解析完成文档ID: {doc_id}, 耗时: {process_duration:.2f}s, 块数: {chunk_count}")

        return {"success": True, "chunk_count": chunk_count}
    except Exception as e:
        process_duration = time.time() - start_time
        # error_message = f"解析失败: {str(e)}"
        logger.error(f"[Parser-ERROR] 文档 {doc_id} 解析失败: {e}")
        error_message = f"解析失败: {e}"
        # Mark the document as failed
        _update_document_progress(doc_id, status="1", run="0", message=error_message, process_duration=process_duration)  # status=1 means finished, run=0 means failed
        return {"success": False, "error": error_message}
    finally:
        # Clean up temporary files
        try:
            if temp_pdf_path and os.path.exists(temp_pdf_path):
                os.remove(temp_pdf_path)
            if temp_image_dir and os.path.exists(temp_image_dir):
                shutil.rmtree(temp_image_dir, ignore_errors=True)
        except Exception as clean_e:
            logger.error(f"[Parser-WARNING] 清理临时文件失败: {clean_e}")