RAGflow/management/server/services/knowledgebases/document_parser.py

# Copyright 2025 zstar1003. All Rights Reserved.
# Project source code: https://github.com/zstar1003/ragflow-plus
import json
import os
import re
import shutil
import tempfile
import time
from datetime import datetime
from urllib.parse import urlparse
import requests
from database import MINIO_CONFIG, get_db_connection, get_es_client, get_minio_client
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import FileBasedDataReader, FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.read_api import read_local_images, read_local_office
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from . import logger
from .excel_parser import parse_excel_file
from .rag_tokenizer import RagTokenizer
from .utils import _create_task_record, _update_document_progress, _update_kb_chunk_count, generate_uuid, get_bbox_from_block

tknzr = RagTokenizer()


def tokenize_text(text):
    """Tokenize text with the shared RAG tokenizer."""
    return tknzr.tokenize(text)
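
# Illustrative (assumed) behavior: RagTokenizer returns a space-separated token
# string, which is what the *_tks / *_ltks ES fields below expect, e.g.
#   tokenize_text("深度学习入门") -> "深度 学习 入门"   (hypothetical output)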


def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
    """
    Core document-parsing logic.

    Args:
        doc_id (str): Document ID.
        doc_info (dict): Document info (name, location, type, kb_id, parser_config, created_by).
        file_info (dict): File info (parent_id serves as the bucket name).
        embedding_config (dict): Embedding service config (llm_name, api_base, api_key).
        kb_info (dict): Knowledge-base info (created_by).

    Returns:
        dict: Parse result (success, chunk_count).
    """
    temp_pdf_path = None
    temp_file_path = None  # temp copy for non-PDF inputs, cleaned up in `finally`
    temp_image_dir = None
    start_time = time.time()
    middle_json_content = None  # initialize middle_json_content
    image_info_list = []  # list of extracted image info

    # Defaults
    embedding_model_name = embedding_config.get("llm_name") if embedding_config and embedding_config.get("llm_name") else "bge-m3"  # default model
    # Normalize the model name: keep only the part before "___"
    if embedding_model_name and "___" in embedding_model_name:
        embedding_model_name = embedding_model_name.split("___")[0]
    # Remap specific model names (SiliconFlow-specific handling)
    if embedding_model_name == "netease-youdao/bce-embedding-base_v1":
        embedding_model_name = "BAAI/bge-m3"

    # Distinguish a missing api_base (default to the local Ollama URL) from an
    # explicitly empty string (fall back to the SiliconFlow endpoint)
    raw_api_base = embedding_config.get("api_base") if embedding_config else None
    if raw_api_base:
        embedding_api_base = raw_api_base
    elif raw_api_base == "":
        embedding_api_base = "https://api.siliconflow.cn/v1/embeddings"
        logger.info(f"[Parser-INFO] api_base is empty; falling back to the SiliconFlow endpoint: {embedding_api_base}")
    else:
        embedding_api_base = "http://localhost:11434"  # default base URL
    embedding_api_key = embedding_config.get("api_key") if embedding_config else None  # may be None or an empty string
    # Build the full embedding API URL
    embedding_url = None  # default None
    is_ollama = False
    if embedding_api_base:
        # Make sure embedding_api_base carries a scheme (http:// or https://)
        if not embedding_api_base.startswith(("http://", "https://")):
            embedding_api_base = "http://" + embedding_api_base
        # Strip the trailing slash to simplify the checks below
        normalized_base_url = embedding_api_base.rstrip("/")
        # Port 11434 is treated as an Ollama server, which has its own embeddings API
        is_ollama = "11434" in normalized_base_url
        if is_ollama:
            # Ollama's dedicated endpoint path
            embedding_url = normalized_base_url + "/api/embeddings"
        elif normalized_base_url.endswith("/v1"):
            embedding_url = normalized_base_url + "/embeddings"
        elif normalized_base_url.endswith("/embeddings"):
            embedding_url = normalized_base_url
        else:
            embedding_url = normalized_base_url + "/v1/embeddings"
logger.info(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")
    try:
        kb_id = doc_info["kb_id"]
        file_location = doc_info["location"]
        # Keep the original extension from the file path
        _, file_extension = os.path.splitext(file_location)
        file_type = doc_info["type"].lower()
        bucket_name = file_info["parent_id"]  # files are stored in the bucket named after parent_id
        tenant_id = kb_info["created_by"]  # the KB creator acts as the tenant_id

        # Progress callback (calls the internal update function directly)
        def update_progress(prog=None, msg=None):
            _update_document_progress(doc_id, progress=prog, message=msg)
            logger.info(f"[Parser-PROGRESS] Doc: {doc_id}, Progress: {prog}, Message: {msg}")

        # 1. Fetch the file content from MinIO
        minio_client = get_minio_client()
        if not minio_client.bucket_exists(bucket_name):
            raise Exception(f"Bucket does not exist: {bucket_name}")

        update_progress(0.1, f"Fetching file from storage: {file_location}")
        response = minio_client.get_object(bucket_name, file_location)
        file_content = response.read()
        response.close()
        update_progress(0.2, "File fetched; preparing to parse")
        # 2. Pick a parser based on the file type
        content_list = []
        if file_type.endswith("pdf"):
            update_progress(0.3, "Using the MinerU parser")
            # Save the PDF content to a temporary file
            temp_dir = tempfile.gettempdir()
            temp_pdf_path = os.path.join(temp_dir, f"{doc_id}.pdf")
            with open(temp_pdf_path, "wb") as f:
                f.write(file_content)

            # Process with MinerU
            reader = FileBasedDataReader("")
            pdf_bytes = reader.read(temp_pdf_path)
            ds = PymuDocDataset(pdf_bytes)

            update_progress(0.3, "Analyzing PDF type")
            is_ocr = ds.classify() == SupportedPdfParseMethod.OCR
            mode_msg = "OCR mode" if is_ocr else "text mode"
            update_progress(0.4, f"Processing PDF in {mode_msg}; see the container logs for detailed progress")
            infer_result = ds.apply(doc_analyze, ocr=is_ocr)

            # Set up a temporary output directory
            temp_image_dir = os.path.join(temp_dir, f"images_{doc_id}")
            os.makedirs(temp_image_dir, exist_ok=True)
            image_writer = FileBasedDataWriter(temp_image_dir)

            update_progress(0.6, f"Processing {mode_msg} results")
            pipe_result = infer_result.pipe_ocr_mode(image_writer) if is_ocr else infer_result.pipe_txt_mode(image_writer)

            update_progress(0.8, "Extracting content")
            content_list = pipe_result.get_content_list(os.path.basename(temp_image_dir))
            # Also grab the intermediate JSON
            middle_content = pipe_result.get_middle_json()
            middle_json_content = json.loads(middle_content)
elif file_type.endswith("word") or file_type.endswith("ppt") or file_type.endswith("txt") or file_type.endswith("md") or file_type.endswith("html"):
update_progress(0.3, "使用MinerU解析器")
# 创建临时文件保存文件内容
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, f"{doc_id}{file_extension}")
with open(temp_file_path, "wb") as f:
f.write(file_content)
logger.info(f"[Parser-INFO] 临时文件路径: {temp_file_path}")
# 使用MinerU处理
ds = read_local_office(temp_file_path)[0]
infer_result = ds.apply(doc_analyze, ocr=True)
# 设置临时输出目录
temp_image_dir = os.path.join(temp_dir, f"images_{doc_id}")
os.makedirs(temp_image_dir, exist_ok=True)
image_writer = FileBasedDataWriter(temp_image_dir)
update_progress(0.6, "处理文件结果")
pipe_result = infer_result.pipe_txt_mode(image_writer)
update_progress(0.8, "提取内容")
content_list = pipe_result.get_content_list(os.path.basename(temp_image_dir))
# 获取内容列表JSON格式
middle_content = pipe_result.get_middle_json()
middle_json_content = json.loads(middle_content)
        # Excel files get dedicated handling
        elif file_type.endswith("excel"):
            update_progress(0.3, "Using the Excel parser")
            # Save the content to a temporary file
            temp_dir = tempfile.gettempdir()
            temp_file_path = os.path.join(temp_dir, f"{doc_id}{file_extension}")
            with open(temp_file_path, "wb") as f:
                f.write(file_content)
            logger.info(f"[Parser-INFO] Temporary file path: {temp_file_path}")

            update_progress(0.8, "Extracting content")
            # Build the content list
            content_list = parse_excel_file(temp_file_path)
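        # Assumption: parse_excel_file returns items shaped like MinerU's
        # content_list entries (dicts with a "type" key such as "text" or
        # "table"), so the chunk loop below can consume them unchanged.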
elif file_type.endswith("visual"):
update_progress(0.3, "使用MinerU解析器")
# 创建临时文件保存文件内容
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, f"{doc_id}{file_extension}")
with open(temp_file_path, "wb") as f:
f.write(file_content)
logger.info(f"[Parser-INFO] 临时文件路径: {temp_file_path}")
# 使用MinerU处理
ds = read_local_images(temp_file_path)[0]
infer_result = ds.apply(doc_analyze, ocr=True)
update_progress(0.3, "分析PDF类型")
is_ocr = ds.classify() == SupportedPdfParseMethod.OCR
mode_msg = "OCR模式" if is_ocr else "文本模式"
update_progress(0.4, f"使用{mode_msg}处理PDF处理中具体进度可查看日志")
infer_result = ds.apply(doc_analyze, ocr=is_ocr)
# 设置临时输出目录
temp_image_dir = os.path.join(temp_dir, f"images_{doc_id}")
os.makedirs(temp_image_dir, exist_ok=True)
image_writer = FileBasedDataWriter(temp_image_dir)
update_progress(0.6, f"处理{mode_msg}结果")
pipe_result = infer_result.pipe_ocr_mode(image_writer) if is_ocr else infer_result.pipe_txt_mode(image_writer)
update_progress(0.8, "提取内容")
content_list = pipe_result.get_content_list(os.path.basename(temp_image_dir))
# 获取内容列表JSON格式
middle_content = pipe_result.get_middle_json()
middle_json_content = json.loads(middle_content)
        else:
            update_progress(0.3, f"Unsupported file type: {file_type}")
            raise NotImplementedError(f"No parser implemented for file type '{file_type}'")
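        # For reference, content_list items are dicts keyed by "type"; the loop
        # below handles these shapes (illustrative, inferred from the handling code):
        #   {"type": "text",     "text": "..."}
        #   {"type": "equation", "text": "..."}
        #   {"type": "table",    "table_caption": [...], "table_body": "<table>...</table>"}
        #   {"type": "image",    "img_path": "images_<doc_id>/xxx.jpg"}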
        # Parse middle_json_content and pull out block info
        block_info_list = []
        if middle_json_content:
            try:
                if isinstance(middle_json_content, dict):
                    middle_data = middle_json_content  # use it directly
                else:
                    middle_data = None
                    logger.warning(f"[Parser-WARNING] middle_json_content is not the expected dict, got: {type(middle_json_content)}")

                # Extract block info (guard against the None case above)
                if middle_data:
                    for page_idx, page_data in enumerate(middle_data.get("pdf_info", [])):
                        for block in page_data.get("preproc_blocks", []):
                            block_bbox = get_bbox_from_block(block)
                            # Only keep blocks that carry text and a bbox
                            if block_bbox != [0, 0, 0, 0]:
                                block_info_list.append({"page_idx": page_idx, "bbox": block_bbox})
                            else:
                                logger.warning("[Parser-WARNING] Block bbox format is invalid; skipping.")
                logger.info(f"[Parser-INFO] Extracted info for {len(block_info_list)} blocks from middle_data.")
            except json.JSONDecodeError:
                logger.error("[Parser-ERROR] Failed to parse middle_json_content.")
                raise Exception("[Parser-ERROR] Failed to parse middle_json_content.")
            except Exception as e:
                logger.error(f"[Parser-ERROR] Error while processing middle_json_content: {e}")
                raise Exception(f"[Parser-ERROR] Error while processing middle_json_content: {e}")
        # 3. Persist the parse results (upload to MinIO, index into ES)
        update_progress(0.95, "Saving parse results")
        es_client = get_es_client()

        # Note: the MinIO bucket here is the knowledge-base ID (kb_id), not the file's parent_id
        output_bucket = kb_id
        if not minio_client.bucket_exists(output_bucket):
            minio_client.make_bucket(output_bucket)
            logger.info(f"[Parser-INFO] Created MinIO bucket: {output_bucket}")

        index_name = f"ragflow_{tenant_id}"
        if not es_client.indices.exists(index=index_name):
            # Create the index
            es_client.indices.create(
                index=index_name,
                body={
                    "settings": {"number_of_replicas": 0},
                    "mappings": {
                        "properties": {
                            "doc_id": {"type": "keyword"},
                            "kb_id": {"type": "keyword"},
                            "content_with_weight": {"type": "text"},
                            "q_1024_vec": {"type": "dense_vector", "dims": 1024},
                        }
                    },
                },
            )
            logger.info(f"[Parser-INFO] Created Elasticsearch index: {index_name}")
        chunk_count = 0
        chunk_ids_list = []
        for chunk_idx, chunk_data in enumerate(content_list):
            page_idx = 0  # default page index
            bbox = [0, 0, 0, 0]  # default bbox
            # Look up the matching block info directly by chunk_idx
            if chunk_idx < len(block_info_list):
                block_info = block_info_list[chunk_idx]
                page_idx = block_info.get("page_idx", 0)
                bbox = block_info.get("bbox", [0, 0, 0, 0])
                # Validate the bbox and fall back to the default if it is malformed
                # (optional, depending on how strict validation needs to be)
                if not (isinstance(bbox, list) and len(bbox) == 4 and all(isinstance(n, (int, float)) for n in bbox)):
                    logger.info(f"[Parser-WARNING] Chunk {chunk_idx} has a malformed bbox: {bbox}; using the default.")
                    bbox = [0, 0, 0, 0]
            else:
                # block_info_list is shorter than content_list; warn only once
                # (at the first out-of-range index) to avoid flooding the log
                if chunk_idx == len(block_info_list):
                    logger.warning(f"[Parser-WARNING] block_info_list ({len(block_info_list)}) is shorter than content_list ({len(content_list)}). Remaining chunks will use the default page_idx and bbox.")
if chunk_data["type"] == "text" or chunk_data["type"] == "table" or chunk_data["type"] == "equation":
if chunk_data["type"] == "text":
content = chunk_data["text"]
if not content or not content.strip():
continue
# 过滤 markdown 特殊符号
content = re.sub(r"[!#\\$/]", "", content)
elif chunk_data["type"] == "equation":
content = chunk_data["text"]
if not content or not content.strip():
continue
elif chunk_data["type"] == "table":
caption_list = chunk_data.get("table_caption", []) # 获取列表,默认为空列表
table_body = chunk_data.get("table_body", "") # 获取表格主体,默认为空字符串
# 如果表格主体为空,说明无实际内容,跳过该表格块
if not table_body.strip():
continue
# 检查 caption_list 是否为列表,并且包含字符串元素
if isinstance(caption_list, list) and all(isinstance(item, str) for item in caption_list):
# 使用空格将列表中的所有字符串拼接起来
caption_str = " ".join(caption_list)
elif isinstance(caption_list, str):
# 如果 caption 本身就是字符串,直接使用
caption_str = caption_list
else:
# 其他情况如空列表、None 或非字符串列表),使用空字符串
caption_str = ""
# 将处理后的标题字符串和表格主体拼接
content = caption_str + table_body
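                # For table chunks, e.g. a caption list ["Table 1", "Quarterly results"]
                # plus a "<table>...</table>" body yields
                # "Table 1 Quarterly results<table>...</table>".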
                q_1024_vec = []  # initialize as an empty list
                # Fetch the embedding vector
                try:
                    headers = {"Content-Type": "application/json"}
                    if embedding_api_key:
                        headers["Authorization"] = f"Bearer {embedding_api_key}"
                    if is_ollama:
                        embedding_resp = requests.post(
                            embedding_url,  # the URL built above
                            headers=headers,  # includes the API key, if any
                            json={
                                "model": embedding_model_name,  # configured or default model name
                                "prompt": content,  # Ollama expects "prompt" rather than "input"
                            },
                            timeout=15,
                        )
                    else:
                        embedding_resp = requests.post(
                            embedding_url,  # the URL built above
                            headers=headers,  # includes the API key, if any
                            json={
                                "model": embedding_model_name,  # configured or default model name
                                "input": content,
                            },
                            timeout=15,
                        )
                    embedding_resp.raise_for_status()
                    embedding_data = embedding_resp.json()
                    # Ollama's embeddings endpoint returns its own response shape
                    if is_ollama:
                        q_1024_vec = embedding_data.get("embedding")
                    else:
                        q_1024_vec = embedding_data["data"][0]["embedding"]
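                    # Response shapes handled above (illustrative):
                    #   Ollama:        {"embedding": [0.1, 0.2, ...]}
                    #   OpenAI-style:  {"data": [{"embedding": [0.1, 0.2, ...]}], ...}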
                    # The index mapping requires 1024-dim vectors; reject anything else
                    if len(q_1024_vec) != 1024:
                        error_msg = f"[Parser-ERROR] Embedding dimension is {len(q_1024_vec)}, not 1024; the bge-m3 model is recommended"
                        logger.error(error_msg)
                        update_progress(-5, error_msg)
                        raise ValueError(error_msg)
                except Exception as e:
                    logger.error(f"[Parser-ERROR] Failed to fetch embedding: {e}")
                    raise Exception(f"[Parser-ERROR] Failed to fetch embedding: {e}")
                chunk_id = generate_uuid()
                try:
                    # Build the ES document
                    current_time_es = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    current_timestamp_es = datetime.now().timestamp()
                    # Reorder the coordinates: [x1, y1, x2, y2] -> [x1, x2, y1, y2]
                    x1, y1, x2, y2 = bbox
                    bbox_reordered = [x1, x2, y1, y2]
                    es_doc = {
                        "doc_id": doc_id,
                        "kb_id": kb_id,
                        "docnm_kwd": doc_info["name"],
                        "title_tks": tokenize_text(doc_info["name"]),
                        "title_sm_tks": tokenize_text(doc_info["name"]),
                        "content_with_weight": content,
                        "content_ltks": tokenize_text(content),
                        "content_sm_ltks": tokenize_text(content),
                        "page_num_int": [page_idx + 1],
                        "position_int": [[page_idx + 1] + bbox_reordered],  # format: [[page, x1, x2, y1, y2]]
                        "top_int": [1],
                        "create_time": current_time_es,
                        "create_timestamp_flt": current_timestamp_es,
                        "img_id": "",
                        "q_1024_vec": q_1024_vec,
                    }
                    # Index into Elasticsearch
                    es_client.index(index=index_name, id=chunk_id, document=es_doc)  # use the `document` parameter
                    chunk_count += 1
                    chunk_ids_list.append(chunk_id)
                except Exception as e:
                    logger.error(f"[Parser-ERROR] Failed to process chunk {chunk_idx} (page: {page_idx}, bbox: {bbox}): {e}")
                    raise Exception(f"[Parser-ERROR] Failed to process chunk {chunk_idx} (page: {page_idx}, bbox: {bbox}): {e}")
elif chunk_data["type"] == "image":
img_path_relative = chunk_data.get("img_path")
if not img_path_relative or not temp_image_dir:
continue
img_path_abs = os.path.join(temp_image_dir, os.path.basename(img_path_relative))
if not os.path.exists(img_path_abs):
logger.warning(f"[Parser-WARNING] 图片文件不存在: {img_path_abs}")
continue
img_id = generate_uuid()
img_ext = os.path.splitext(img_path_abs)[1]
img_key = f"images/{img_id}{img_ext}" # MinIO中的对象名
content_type = f"image/{img_ext[1:].lower()}"
if content_type == "image/jpg":
content_type = "image/jpeg"
try:
# 上传图片到MinIO (桶为kb_id)
minio_client.fput_object(bucket_name=output_bucket, object_name=img_key, file_path=img_path_abs, content_type=content_type)
# 设置图片的公共访问权限
policy = {"Version": "2012-10-17", "Statement": [{"Effect": "Allow", "Principal": {"AWS": "*"}, "Action": ["s3:GetObject"], "Resource": [f"arn:aws:s3:::{kb_id}/images/*"]}]}
minio_client.set_bucket_policy(kb_id, json.dumps(policy))
logger.info(f"成功上传图片: {img_key}")
minio_endpoint = MINIO_CONFIG["endpoint"]
use_ssl = MINIO_CONFIG.get("secure", False)
protocol = "https" if use_ssl else "http"
img_url = f"{protocol}://{minio_endpoint}/{output_bucket}/{img_key}"
                    # Record the image info, including its URL and position
                    image_info = {
                        "url": img_url,
                        "position": chunk_count,  # number of text chunks processed so far, used as a position anchor
                    }
                    image_info_list.append(image_info)

                    logger.info(f"Image URL: {img_url}")
                except Exception as e:
                    logger.error(f"[Parser-ERROR] Failed to upload image {img_path_abs}: {e}")
                    raise Exception(f"[Parser-ERROR] Failed to upload image {img_path_abs}: {e}")
        # Summary
        logger.info(f"[Parser-INFO] Processed {chunk_count} text chunks in total.")

        # 4. Attach image info to the text chunks
        if image_info_list and chunk_ids_list:
            try:
                # Find the nearest image for each text chunk
                for i, chunk_id in enumerate(chunk_ids_list):
                    nearest_image = None
                    min_distance = None
                    for img_info in image_info_list:
                        # Use the position difference as the distance metric
                        distance = abs(i - img_info["position"])
                        # A chunk and an image within 5 positions of each other count as related;
                        # keep the closest such image
                        if distance < 5 and (min_distance is None or distance < min_distance):
                            min_distance = distance
                            nearest_image = img_info
                    # If a nearby image was found, update the chunk's img_id
                    if nearest_image:
                        # Store only the relative path part of the URL
                        parsed_url = urlparse(nearest_image["url"])
                        relative_path = parsed_url.path.lstrip("/")  # drop the leading slash
                        # Update the document in ES
                        direct_update = {"doc": {"img_id": relative_path}}
                        es_client.update(index=index_name, id=chunk_id, body=direct_update, refresh=True)
                        logger.info(f"[Parser-INFO] Linked chunk {chunk_id} to image: {relative_path}")
            except Exception as e:
                logger.error(f"[Parser-ERROR] Failed to link chunks to images: {e}")
                raise Exception(f"[Parser-ERROR] Failed to link chunks to images: {e}")
        # 5. Final status update
        process_duration = time.time() - start_time
        _update_document_progress(doc_id, progress=1.0, message="Parsing finished", status="1", run="3", chunk_count=chunk_count, process_duration=process_duration)
        _update_kb_chunk_count(kb_id, chunk_count)  # update the KB's total chunk count
        _create_task_record(doc_id, chunk_ids_list)  # create the task record
        update_progress(1.0, "Parsing finished")
        logger.info(f"[Parser-INFO] Parsing finished for doc {doc_id}, took {process_duration:.2f}s, chunks: {chunk_count}")

        return {"success": True, "chunk_count": chunk_count}

    except Exception as e:
        process_duration = time.time() - start_time
        logger.error(f"[Parser-ERROR] Parsing failed for doc {doc_id}: {e}")
        error_message = f"Parsing failed: {e}"
        # Mark the document as failed (status=1 means finished, run=0 means failed)
        _update_document_progress(doc_id, status="1", run="0", message=error_message, process_duration=process_duration)
        return {"success": False, "error": error_message}
    finally:
        # Clean up temporary files
        try:
            if temp_pdf_path and os.path.exists(temp_pdf_path):
                os.remove(temp_pdf_path)
            if temp_file_path and os.path.exists(temp_file_path):
                os.remove(temp_file_path)
            if temp_image_dir and os.path.exists(temp_image_dir):
                shutil.rmtree(temp_image_dir, ignore_errors=True)
        except Exception as clean_e:
            logger.error(f"[Parser-WARNING] Failed to clean up temporary files: {clean_e}")