fix: Fix broken image-to-text-chunk association (#78)

Adds processed_text_chunks += 1 so that each image's recorded position tracks the number of text chunks processed before it.
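In essence, each image records how many text chunks had been processed when it was encountered, and every text chunk is later linked to an image whose recorded position lies within a fixed window of the chunk's own index. Without the increment, every image would be recorded at position 0. A minimal sketch of that pairing logic, simplified from perform_parse below (the data here is illustrative, not the parser's actual input):

processed_text_chunks = 0
image_info_list = []
content_list = [{"type": "text"}, {"type": "image", "img_path": "a.png"}, {"type": "text"}]

for chunk_data in content_list:
    if chunk_data["type"] == "text":
        processed_text_chunks += 1  # the increment this commit adds
    elif chunk_data["type"] == "image":
        # Position = number of text chunks seen so far.
        image_info_list.append({"url": chunk_data["img_path"], "position": processed_text_chunks})

# Afterwards, text chunk i is associated with an image fewer than 10 positions away.
for i in range(processed_text_chunks):
    for img_info in image_info_list:
        if abs(i - img_info["position"]) < 10:
            print(f"chunk {i} -> {img_info['url']}")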
This commit is contained in:
zstar 2025-05-11 21:28:39 +08:00 committed by GitHub
parent 100f7f9405
commit e6c18119da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 197 additions and 283 deletions

View File

@@ -14,11 +14,7 @@ on:
jobs:
deploy:
if: |
(github.event_name == 'push' && github.ref == 'refs/heads/main') ||
(github.event_name == 'pull_request' && github.base_ref == 'refs/heads/main')
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -31,11 +27,3 @@ jobs:
else
echo "${{ toJSON(github.event.pull_request.changed_files) }}"
fi
skip-notification:
if: ${{ always() && !contains(github.event.pull_request.changed_files, 'docs/') }}
runs-on: ubuntu-latest
needs: deploy
steps:
- name: Skip irrelevant changes
run: echo "Skipped: No docs/ files modified in this PR."

View File

@@ -118,11 +118,9 @@ pnpm dev
6. Submit a PR and wait for review
## 📄 Discussion Group
If you have usage questions or suggestions, you can join a discussion group. Group 1 is currently full; Group 2 can be joined by scanning the QR code.
If you have usage questions or suggestions, you can join the discussion group.
<div align="center">
<img src="docs/images/group.jpg" width="200" alt="2群二维码">
</div>
Since the group chat now has more than 200 members, it can no longer be joined by scanning a QR code. To join, add me on WeChat (zstar1003) with the note "加群" (join group).
## 🚀 Acknowledgements

View File

@@ -213,7 +213,7 @@
<div class="container">
<div class="footer-content">
<div class="links">
<a href="../LICENSE">许可证</a>
<a href="https://github.com/zstar1003/ragflow-plus/blob/main/LICENSE">许可证</a>
<a href="https://github.com/zstar1003/ragflow-plus">GitHub</a>
</div>
<div class="copyright">

View File

@@ -1,180 +1,121 @@
import os
from flask import jsonify, request, send_file, current_app
from io import BytesIO
from .. import files_bp
from flask import request, jsonify
from werkzeug.utils import secure_filename
from services.files.service import (
get_files_list,
get_file_info,
download_file_from_minio,
delete_file,
batch_delete_files,
get_minio_client,
upload_files_to_server
)
from services.files.service import get_files_list, get_file_info, download_file_from_minio, delete_file, batch_delete_files, get_minio_client, upload_files_to_server
from services.files.utils import FileType
UPLOAD_FOLDER = '/data/uploads'
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif', 'doc', 'docx', 'xls', 'xlsx'}
UPLOAD_FOLDER = "/data/uploads"
ALLOWED_EXTENSIONS = {"txt", "pdf", "png", "jpg", "jpeg", "gif", "doc", "docx", "xls", "xlsx"}
def allowed_file(filename):
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
@files_bp.route('/upload', methods=['POST'])
@files_bp.route("/upload", methods=["POST"])
def upload_file():
if 'files' not in request.files:
return jsonify({'code': 400, 'message': 'No file selected', 'data': None}), 400
files = request.files.getlist('files')
upload_result = upload_files_to_server(files)
# Return the standard response format
return jsonify({
'code': 0,
'message': 'Upload succeeded',
'data': upload_result['data']
})
if "files" not in request.files:
return jsonify({"code": 400, "message": "No file selected", "data": None}), 400
@files_bp.route('', methods=['GET', 'OPTIONS'])
files = request.files.getlist("files")
upload_result = upload_files_to_server(files)
# Return the standard response format
return jsonify({"code": 0, "message": "Upload succeeded", "data": upload_result["data"]})
@files_bp.route("", methods=["GET", "OPTIONS"])
def get_files():
"""获取文件列表的API端点"""
if request.method == 'OPTIONS':
return '', 200
try:
current_page = int(request.args.get('currentPage', 1))
page_size = int(request.args.get('size', 10))
name_filter = request.args.get('name', '')
result, total = get_files_list(current_page, page_size, name_filter)
return jsonify({
"code": 0,
"data": {
"list": result,
"total": total
},
"message": "获取文件列表成功"
})
except Exception as e:
return jsonify({
"code": 500,
"message": f"获取文件列表失败: {str(e)}"
}), 500
if request.method == "OPTIONS":
return "", 200
@files_bp.route('/<string:file_id>/download', methods=['GET', 'OPTIONS'])
try:
current_page = int(request.args.get("currentPage", 1))
page_size = int(request.args.get("size", 10))
name_filter = request.args.get("name", "")
result, total = get_files_list(current_page, page_size, name_filter)
return jsonify({"code": 0, "data": {"list": result, "total": total}, "message": "获取文件列表成功"})
except Exception as e:
return jsonify({"code": 500, "message": f"获取文件列表失败: {str(e)}"}), 500
@files_bp.route("/<string:file_id>/download", methods=["GET", "OPTIONS"])
def download_file(file_id):
try:
current_app.logger.info(f"开始处理文件下载请求: {file_id}")
# 获取文件信息
file = get_file_info(file_id)
if not file:
current_app.logger.error(f"文件不存在: {file_id}")
return jsonify({
"code": 404,
"message": f"文件 {file_id} 不存在",
"details": "文件记录不存在或已被删除"
}), 404
if file['type'] == FileType.FOLDER.value:
return jsonify({"code": 404, "message": f"文件 {file_id} 不存在", "details": "文件记录不存在或已被删除"}), 404
if file["type"] == FileType.FOLDER.value:
current_app.logger.error(f"不能下载文件夹: {file_id}")
return jsonify({
"code": 400,
"message": "不能下载文件夹",
"details": "请选择一个文件进行下载"
}), 400
return jsonify({"code": 400, "message": "不能下载文件夹", "details": "请选择一个文件进行下载"}), 400
current_app.logger.info(f"文件信息获取成功: {file_id}, 存储位置: {file['parent_id']}/{file['location']}")
try:
# Download the file from MinIO
file_data, filename = download_file_from_minio(file_id)
# Create an in-memory file object
file_stream = BytesIO(file_data)
# Return the file
return send_file(
file_stream,
download_name=filename,
as_attachment=True,
mimetype='application/octet-stream'
)
return send_file(file_stream, download_name=filename, as_attachment=True, mimetype="application/octet-stream")
except Exception as e:
current_app.logger.error(f"下载文件失败: {str(e)}")
return jsonify({
"code": 500,
"message": "下载文件失败",
"details": str(e)
}), 500
return jsonify({"code": 500, "message": "下载文件失败", "details": str(e)}), 500
except Exception as e:
current_app.logger.error(f"处理下载请求时出错: {str(e)}")
return jsonify({
"code": 500,
"message": "处理下载请求时出错",
"details": str(e)
}), 500
return jsonify({"code": 500, "message": "处理下载请求时出错", "details": str(e)}), 500
@files_bp.route('/<string:file_id>', methods=['DELETE', 'OPTIONS'])
@files_bp.route("/<string:file_id>", methods=["DELETE", "OPTIONS"])
def delete_file_route(file_id):
"""删除文件的API端点"""
if request.method == 'OPTIONS':
return '', 200
if request.method == "OPTIONS":
return "", 200
try:
success = delete_file(file_id)
if success:
return jsonify({
"code": 0,
"message": "File deleted successfully"
})
else:
return jsonify({
"code": 404,
"message": f"File {file_id} does not exist"
}), 404
except Exception as e:
return jsonify({
"code": 500,
"message": f"Failed to delete file: {str(e)}"
}), 500
@files_bp.route('/batch', methods=['DELETE', 'OPTIONS'])
if success:
return jsonify({"code": 0, "message": "File deleted successfully"})
else:
return jsonify({"code": 404, "message": f"File {file_id} does not exist"}), 404
except Exception as e:
return jsonify({"code": 500, "message": f"Failed to delete file: {str(e)}"}), 500
@files_bp.route("/batch", methods=["DELETE", "OPTIONS"])
def batch_delete_files_route():
"""批量删除文件的API端点"""
if request.method == 'OPTIONS':
return '', 200
if request.method == "OPTIONS":
return "", 200
try:
data = request.json
file_ids = data.get('ids', [])
file_ids = data.get("ids", [])
if not file_ids:
return jsonify({
"code": 400,
"message": "No file IDs provided for deletion"
}), 400
return jsonify({"code": 400, "message": "No file IDs provided for deletion"}), 400
success_count = batch_delete_files(file_ids)
return jsonify({
"code": 0,
"message": f"Deleted {success_count}/{len(file_ids)} files"
})
return jsonify({"code": 0, "message": f"Deleted {success_count}/{len(file_ids)} files"})
except Exception as e:
return jsonify({
"code": 500,
"message": f"Batch file deletion failed: {str(e)}"
}), 500
return jsonify({"code": 500, "message": f"Batch file deletion failed: {str(e)}"}), 500

View File

@@ -3,7 +3,7 @@ import tempfile
import shutil
import json
import mysql.connector
import time
import time
import traceback
import re
import requests
@@ -25,22 +25,23 @@ def tokenize_text(text):
# Simple implementation; real applications may need more sophisticated tokenization logic
return text.split()
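# e.g. tokenize_text("hello world 你好") -> ["hello", "world", "你好"]
# (whitespace split only, so CJK text is not segmented; token counts for
# Chinese text are therefore rough at best)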
def merge_chunks(sections, chunk_token_num=128, delimiter="\n。;!?"):
"""合并文本块替代naive_merge功能"""
if not sections:
return []
chunks = [""]
token_counts = [0]
for section in sections:
# Count the tokens in the current section
text = section[0] if isinstance(section, tuple) else section
position = section[1] if isinstance(section, tuple) and len(section) > 1 else ""
# Rough token-count estimate
token_count = len(text.split())
# If the current chunk already exceeds the limit, start a new chunk
if token_counts[-1] > chunk_token_num:
chunks.append(text)
@@ -49,13 +50,15 @@ def merge_chunks(sections, chunk_token_num=128, delimiter="\n。"):
# Otherwise append to the current chunk
chunks[-1] += text
token_counts[-1] += token_count
return chunks
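# Usage sketch (hypothetical; the hunks above elide part of the body):
#   chunks = merge_chunks(["Para one.", ("Para two.", "pos-2")], chunk_token_num=128)
# Each section may be a plain string or a (text, position) tuple; the return
# value is a list of merged chunk strings.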
def _get_db_connection():
"""创建数据库连接"""
return mysql.connector.connect(**DB_CONFIG)
def _update_document_progress(doc_id, progress=None, message=None, status=None, run=None, chunk_count=None, process_duration=None):
"""更新数据库中文档的进度和状态"""
conn = None
@@ -79,13 +82,12 @@ def _update_document_progress(doc_id, progress=None, message=None, status=None,
updates.append("run = %s")
params.append(run)
if chunk_count is not None:
updates.append("chunk_num = %s")
params.append(chunk_count)
updates.append("chunk_num = %s")
params.append(chunk_count)
if process_duration is not None:
updates.append("process_duation = %s")
params.append(process_duration)
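# Note: "process_duation" (sic) appears to match the existing column name in
# the document table schema, so the misspelling is kept deliberately.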
if not updates:
return
@@ -101,6 +103,7 @@ def _update_document_progress(doc_id, progress=None, message=None, status=None,
if conn:
conn.close()
def _update_kb_chunk_count(kb_id, count_delta):
"""更新知识库的块数量"""
conn = None
@@ -125,6 +128,7 @@ def _update_kb_chunk_count(kb_id, count_delta):
if conn:
conn.close()
def _create_task_record(doc_id, chunk_ids_list):
"""创建task记录"""
conn = None
@@ -137,8 +141,8 @@ def _create_task_record(doc_id, chunk_ids_list):
current_timestamp = int(current_datetime.timestamp() * 1000)
current_time_str = current_datetime.strftime("%Y-%m-%d %H:%M:%S")
current_date_only = current_datetime.strftime("%Y-%m-%d")
digest = f"{doc_id}_{0}_{1}" # 假设 from_page=0, to_page=1
chunk_ids_str = ' '.join(chunk_ids_list)
digest = f"{doc_id}_{0}_{1}" # 假设 from_page=0, to_page=1
chunk_ids_str = " ".join(chunk_ids_list)
task_insert = """
INSERT INTO task (
@@ -148,9 +152,22 @@ def _create_task_record(doc_id, chunk_ids_list):
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
task_params = [
task_id, current_timestamp, current_date_only, current_timestamp, current_date_only,
doc_id, 0, 1, None, 0.0, # begin_at, process_duration
1.0, "MinerU解析完成", 1, digest, chunk_ids_str, "" # progress, msg, retry, digest, chunks, type
task_id,
current_timestamp,
current_date_only,
current_timestamp,
current_date_only,
doc_id,
0,
1,
None,
0.0, # begin_at, process_duration
1.0,
"MinerU解析完成",
1,
digest,
chunk_ids_str,
"", # progress, msg, retry, digest, chunks, type
]
cursor.execute(task_insert, task_params)
conn.commit()
@@ -184,6 +201,7 @@ def get_bbox_from_block(block):
# If block is not a dict, lacks a bbox key, or the bbox format is invalid, return the default
return [0, 0, 0, 0]
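# ([0, 0, 0, 0] serves as the "no position" sentinel; MinerU block bboxes are
# otherwise assumed to be [x0, y0, x1, y1])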
def perform_parse(doc_id, doc_info, file_info, embedding_config):
"""
Core logic for executing document parsing
@@ -199,48 +217,47 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
temp_pdf_path = None
temp_image_dir = None
start_time = time.time()
middle_json_content = None # Initialize middle_json_content
middle_json_content = None  # Initialize middle_json_content
image_info_list = [] # List of image info entries
# Default-value handling
embedding_model_name = embedding_config.get("llm_name") if embedding_config and embedding_config.get("llm_name") else "bge-m3" # Default model
embedding_model_name = embedding_config.get("llm_name") if embedding_config and embedding_config.get("llm_name") else "bge-m3"  # Default model
# Normalize the model name
if embedding_model_name and '___' in embedding_model_name:
embedding_model_name = embedding_model_name.split('___')[0]
embedding_api_base = embedding_config.get("api_base") if embedding_config and embedding_config.get("api_base") else "http://localhost:8000" # 默认基础 URL
embedding_api_key = embedding_config.get("api_key") if embedding_config else None # 可能为 None 或空字符串
if embedding_model_name and "___" in embedding_model_name:
embedding_model_name = embedding_model_name.split("___")[0]
embedding_api_base = embedding_config.get("api_base") if embedding_config and embedding_config.get("api_base") else "http://localhost:8000" # 默认基础 URL
embedding_api_key = embedding_config.get("api_key") if embedding_config else None # 可能为 None 或空字符串
# 构建完整的 Embedding API URL
embedding_url = None # 默认为 None
embedding_url = None # 默认为 None
if embedding_api_base:
# Ensure embedding_api_base includes a protocol prefix (http:// or https://)
if not embedding_api_base.startswith(('http://', 'https://')):
embedding_api_base = 'http://' + embedding_api_base
if not embedding_api_base.startswith(("http://", "https://")):
embedding_api_base = "http://" + embedding_api_base
# --- URL joining optimization (handle /v1) ---
endpoint_segment = "embeddings"
full_endpoint_path = "v1/embeddings"
# Strip the trailing slash to simplify the checks
normalized_base_url = embedding_api_base.rstrip('/')
normalized_base_url = embedding_api_base.rstrip("/")
if normalized_base_url.endswith('/v1'):
if normalized_base_url.endswith("/v1"):
# If base_url is already of the form http://host/v1
embedding_url = normalized_base_url + '/' + endpoint_segment
embedding_url = normalized_base_url + "/" + endpoint_segment
else:
# If base_url is http://host, http://host/api, or some other form
embedding_url = normalized_base_url + '/' + full_endpoint_path
embedding_url = normalized_base_url + "/" + full_endpoint_path
print(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")
try:
kb_id = doc_info['kb_id']
file_location = doc_info['location']
kb_id = doc_info["kb_id"]
file_location = doc_info["location"]
# Extract the original extension from the file path
_, file_extension = os.path.splitext(file_location)
file_type = doc_info['type'].lower()
parser_config = json.loads(doc_info['parser_config']) if isinstance(doc_info['parser_config'], str) else doc_info['parser_config']
bucket_name = file_info['parent_id'] # The file's storage bucket is parent_id
tenant_id = doc_info['created_by'] # The document creator serves as tenant_id
file_type = doc_info["type"].lower()
bucket_name = file_info["parent_id"] # The file's storage bucket is parent_id
tenant_id = doc_info["created_by"] # The document creator serves as tenant_id
# Progress-update callback (calls the internal update function directly)
def update_progress(prog=None, msg=None):
@@ -260,13 +277,13 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
# 2. Choose a parser based on file type
content_list = []
if file_type.endswith('pdf'):
if file_type.endswith("pdf"):
update_progress(0.3, "使用MinerU解析器")
# 创建临时文件保存PDF内容
temp_dir = tempfile.gettempdir()
temp_pdf_path = os.path.join(temp_dir, f"{doc_id}.pdf")
with open(temp_pdf_path, 'wb') as f:
with open(temp_pdf_path, "wb") as f:
f.write(file_content)
# Process with MinerU
@@ -295,12 +312,12 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
middle_content = pipe_result.get_middle_json()
middle_json_content = json.loads(middle_content)
elif file_type.endswith('word') or file_type.endswith('ppt'):
elif file_type.endswith("word") or file_type.endswith("ppt"):
update_progress(0.3, "使用MinerU解析器")
# 创建临时文件保存文件内容
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, f"{doc_id}{file_extension}")
with open(temp_file_path, 'wb') as f:
with open(temp_file_path, "wb") as f:
f.write(file_content)
print(f"[Parser-INFO] 临时文件路径: {temp_file_path}")
@ -313,7 +330,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
os.makedirs(temp_image_dir, exist_ok=True)
image_writer = FileBasedDataWriter(temp_image_dir)
update_progress(0.6, f"处理文件结果")
update_progress(0.6, "处理文件结果")
pipe_result = infer_result.pipe_txt_mode(image_writer)
update_progress(0.8, "提取内容")
@ -332,22 +349,19 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
if middle_json_content:
try:
if isinstance(middle_json_content, dict):
middle_data = middle_json_content # Assign directly
middle_data = middle_json_content  # Assign directly
else:
middle_data = None
print(f"[Parser-WARNING] middle_json_content is not the expected dict format; actual type: {type(middle_json_content)}")
# Extract block info
# Extract block info
for page_idx, page_data in enumerate(middle_data.get("pdf_info", [])):
for block in page_data.get("preproc_blocks", []):
block_bbox = get_bbox_from_block(block)
# Only extract blocks that contain text and have a bbox
if block_bbox != [0, 0, 0, 0]:
block_info_list.append({
"page_idx": page_idx,
"bbox": block_bbox
})
block_info_list.append({"page_idx": page_idx, "bbox": block_bbox})
else:
print(f"[Parser-WARNING] 块的 bbox 格式无效: {bbox},跳过。")
print("[Parser-WARNING] 块的 bbox 格式无效,跳过。")
print(f"[Parser-INFO] 从 middle_data 提取了 {len(block_info_list)} 个块的信息。")
@@ -355,7 +369,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
print("[Parser-ERROR] Failed to parse middle_json_content.")
except Exception as e:
print(f"[Parser-ERROR] Error while processing middle_json_content: {e}")
# 3. Handle the parse results (upload to MinIO, store in ES)
update_progress(0.95, "Saving parse results")
es_client = get_es_client()
@@ -367,34 +381,26 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
index_name = f"ragflow_{tenant_id}"
if not es_client.indices.exists(index=index_name):
# Create the index
# Create the index
es_client.indices.create(
index=index_name,
body={
"settings": {"number_of_replicas": 0},
"mappings": {
"properties": {
"doc_id": {"type": "keyword"},
"kb_id": {"type": "keyword"},
"content_with_weight": {"type": "text"},
"q_1024_vec": {
"type": "dense_vector",
"dims": 1024
}
}
}
}
"properties": {"doc_id": {"type": "keyword"}, "kb_id": {"type": "keyword"}, "content_with_weight": {"type": "text"}, "q_1024_vec": {"type": "dense_vector", "dims": 1024}}
},
},
)
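# (dims=1024 matches the default bge-m3 embedding size; a different embedding
# model would presumably need a matching dims value here)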
print(f"[Parser-INFO] 创建Elasticsearch索引: {index_name}")
chunk_count = 0
chunk_ids_list = []
middle_block_idx = 0 # Used to match block_info_list in order
processed_text_chunks = 0 # Number of text chunks processed
middle_block_idx = 0  # Used to match block_info_list in order
processed_text_chunks = 0  # Number of text chunks processed
for chunk_idx, chunk_data in enumerate(content_list):
page_idx = 0 # Default page index
bbox = [0, 0, 0, 0] # Default bbox
bbox = [0, 0, 0, 0]  # Default bbox
# Try to fetch the matching block info from block_info_list directly by chunk_idx
if chunk_idx < len(block_info_list):
@@ -410,7 +416,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
# Print the warning only on the first out-of-range index, to avoid flooding the log
if chunk_idx == len(block_info_list):
print(f"[Parser-WARNING] block_info_list length ({len(block_info_list)}) is smaller than content_list length ({len(content_list)}). Subsequent chunks will use the default page_idx and bbox.")
if chunk_data["type"] == "text" or chunk_data["type"] == "table":
if chunk_data["type"] == "text":
content = chunk_data["text"]
@@ -419,8 +425,8 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
# Filter out markdown special characters
content = re.sub(r"[!#\\$/]", "", content)
elif chunk_data["type"] == "table":
caption_list = chunk_data.get("table_caption", []) # 获取列表,默认为空列表
table_body = chunk_data.get("table_body", "") # 获取表格主体,默认为空字符串
caption_list = chunk_data.get("table_caption", []) # 获取列表,默认为空列表
table_body = chunk_data.get("table_body", "") # 获取表格主体,默认为空字符串
# 检查 caption_list 是否为列表,并且包含字符串元素
if isinstance(caption_list, list) and all(isinstance(item, str) for item in caption_list):
# 使用空格将列表中的所有字符串拼接起来
@ -433,9 +439,8 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
caption_str = ""
# 将处理后的标题字符串和表格主体拼接
content = caption_str + table_body
q_1024_vec = [] # Initialize as an empty list
q_1024_vec = []  # Initialize as an empty list
# Fetch the embedding vector
try:
# embedding_resp = requests.post(
@@ -451,15 +456,15 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
headers["Authorization"] = f"Bearer {embedding_api_key}"
embedding_resp = requests.post(
embedding_url, # Use the dynamically built URL
headers=headers, # Add headers (including the API key, if any)
embedding_url,  # Use the dynamically built URL
headers=headers,  # Add headers (including the API key, if any)
json={
"model": embedding_model_name, # Use the dynamically resolved or default model name
"input": content
"input": content,
},
timeout=15 # Slightly increased timeout
timeout=15,  # Slightly increased timeout
)
embedding_resp.raise_for_status()
embedding_data = embedding_resp.json()
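# (assumes an OpenAI-compatible response shape: {"data": [{"embedding": [...]}]})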
q_1024_vec = embedding_data["data"][0]["embedding"]
@@ -467,20 +472,20 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
except Exception as e:
print(f"[Parser-ERROR] 获取embedding失败: {e}")
q_1024_vec = []
chunk_id = generate_uuid()
try:
# Upload the text chunk to MinIO
minio_client.put_object(
bucket_name=output_bucket,
object_name=chunk_id,
data=BytesIO(content.encode('utf-8')),
length=len(content.encode('utf-8')) # Use the byte length
data=BytesIO(content.encode("utf-8")),
length=len(content.encode("utf-8")),  # Use the byte length
)
# Prepare the ES document
content_tokens = tokenize_text(content) # Tokenize
content_tokens = tokenize_text(content)  # Tokenize
current_time_es = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
current_timestamp_es = datetime.now().timestamp()
@@ -491,34 +496,34 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
es_doc = {
"doc_id": doc_id,
"kb_id": kb_id,
"docnm_kwd": doc_info['name'],
"title_tks": doc_info['name'],
"title_sm_tks": doc_info['name'],
"docnm_kwd": doc_info["name"],
"title_tks": doc_info["name"],
"title_sm_tks": doc_info["name"],
"content_with_weight": content,
"content_ltks": " ".join(content_tokens), # 字符串类型
"content_ltks": " ".join(content_tokens), # 字符串类型
"content_sm_ltks": " ".join(content_tokens), # 字符串类型
"page_num_int": [page_idx + 1],
"position_int": [[page_idx + 1] + bbox_reordered], # 格式: [[page, x1, x2, y1, y2]]
"position_int": [[page_idx + 1] + bbox_reordered], # 格式: [[page, x1, x2, y1, y2]]
"top_int": [1],
"create_time": current_time_es,
"create_timestamp_flt": current_timestamp_es,
"img_id": "",
"q_1024_vec": q_1024_vec
"q_1024_vec": q_1024_vec,
}
# Store in Elasticsearch
es_client.index(index=index_name, id=chunk_id, document=es_doc) # Use the document parameter
es_client.index(index=index_name, id=chunk_id, document=es_doc)  # Use the document parameter
chunk_count += 1
processed_text_chunks += 1
chunk_ids_list.append(chunk_id)
# print(f"成功处理文本块 {chunk_count}/{len(content_list)}") # 可以取消注释用于详细调试
except Exception as e:
print(f"[Parser-ERROR] 处理文本块 {chunk_idx} (page: {page_idx}, bbox: {bbox}) 失败: {e}")
traceback.print_exc() # 打印更详细的错误
traceback.print_exc() # 打印更详细的错误
elif chunk_data["type"] == "image":
img_path_relative = chunk_data.get('img_path')
img_path_relative = chunk_data.get("img_path")
if not img_path_relative or not temp_image_dir:
continue
@@ -529,31 +534,17 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
img_id = generate_uuid()
img_ext = os.path.splitext(img_path_abs)[1]
img_key = f"images/{img_id}{img_ext}" # MinIO中的对象名
img_key = f"images/{img_id}{img_ext}" # MinIO中的对象名
content_type = f"image/{img_ext[1:].lower()}"
if content_type == "image/jpg": content_type = "image/jpeg"
if content_type == "image/jpg":
content_type = "image/jpeg"
try:
# Upload the image to MinIO (the bucket is kb_id)
minio_client.fput_object(
bucket_name=output_bucket,
object_name=img_key,
file_path=img_path_abs,
content_type=content_type
)
minio_client.fput_object(bucket_name=output_bucket, object_name=img_key, file_path=img_path_abs, content_type=content_type)
# Set public access permissions for the image
policy = {
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {"AWS": "*"},
"Action": ["s3:GetObject"],
"Resource": [f"arn:aws:s3:::{kb_id}/images/*"]
}
]
}
policy = {"Version": "2012-10-17", "Statement": [{"Effect": "Allow", "Principal": {"AWS": "*"}, "Action": ["s3:GetObject"], "Resource": [f"arn:aws:s3:::{kb_id}/images/*"]}]}
minio_client.set_bucket_policy(kb_id, json.dumps(policy))
print(f"成功上传图片: {img_key}")
@ -565,7 +556,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
# 记录图片信息包括URL和位置信息
image_info = {
"url": img_url,
"position": processed_text_chunks # 使用当前处理的文本块数作为位置参考
"position": processed_text_chunks, # 使用当前处理的文本块数作为位置参考
}
image_info_list.append(image_info)
@@ -573,46 +564,42 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
except Exception as e:
print(f"[Parser-ERROR] 上传图片 {img_path_abs} 失败: {e}")
# 打印匹配总结信息
print(f"[Parser-INFO] 共处理 {processed_text_chunks} 个文本块。")
if middle_block_idx < len(block_info_list):
print(f"[Parser-WARNING] middle_data 中还有 {len(block_info_list) - middle_block_idx} 个提取的块信息未被使用。")
# 4. 更新文本块的图像信息
print(f"[Parser-WARNING] middle_data 中还有 {len(block_info_list) - middle_block_idx} 个提取的块信息未被使用。")
# 4. 更新文本块的图像信息
if image_info_list and chunk_ids_list:
conn = None
cursor = None
try:
conn = _get_db_connection()
cursor = conn.cursor()
# Find the nearest image for each text chunk
for i, chunk_id in enumerate(chunk_ids_list):
# Find the image closest to the current text chunk
nearest_image = None
min_distance = None
for img_info in image_info_list:
# Compute the "distance" between the text chunk and the image
distance = abs(i - img_info["position"]) # Use the positional difference as the distance metric
# A chunk and an image fewer than 10 blocks apart are considered related;
# track the minimum so the truly nearest image wins
if distance < 10 and (min_distance is None or distance < min_distance):
min_distance = distance
nearest_image = img_info
# If a nearest image was found, update the chunk's img_id
if nearest_image:
# Update the document in ES
direct_update = {
"doc": {
"img_id": nearest_image["url"]
}
}
direct_update = {"doc": {"img_id": nearest_image["url"]}}
es_client.update(index=index_name, id=chunk_id, body=direct_update, refresh=True)
index_name = f"ragflow_{tenant_id}"
print(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {nearest_image['url']}")
except Exception as e:
print(f"[Parser-ERROR] 更新文本块图片关联失败: {e}")
finally:
@@ -623,23 +610,23 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
# 5. Update the final status
process_duration = time.time() - start_time
_update_document_progress(doc_id, progress=1.0, message="Parsing complete", status='1', run='3', chunk_count=chunk_count, process_duration=process_duration)
_update_kb_chunk_count(kb_id, chunk_count) # Update the KB's total chunk count
_create_task_record(doc_id, chunk_ids_list) # Create the task record
_update_document_progress(doc_id, progress=1.0, message="Parsing complete", status="1", run="3", chunk_count=chunk_count, process_duration=process_duration)
_update_kb_chunk_count(kb_id, chunk_count)  # Update the KB's total chunk count
_create_task_record(doc_id, chunk_ids_list)  # Create the task record
update_progress(1.0, "Parsing complete")
print(f"[Parser-INFO] Parsing complete. Document ID: {doc_id}, duration: {process_duration:.2f}s, chunks: {chunk_count}")
return {"success": True, "chunk_count": chunk_count}
except Exception as e:
process_duration = time.time() - start_time
# error_message = f"解析失败: {str(e)}"
print(f"[Parser-ERROR] 文档 {doc_id} 解析失败: {e}")
error_message = f"解析失败: {e}"
traceback.print_exc() # 打印详细错误堆栈
traceback.print_exc() # 打印详细错误堆栈
# 更新文档状态为失败
_update_document_progress(doc_id, status='1', run='0', message=error_message, process_duration=process_duration) # status=1表示完成run=0表示失败
_update_document_progress(doc_id, status="1", run="0", message=error_message, process_duration=process_duration) # status=1表示完成run=0表示失败
# 不抛出异常,让调用者知道任务已结束(但失败)
return {"success": False, "error": error_message}
@@ -651,4 +638,4 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
if temp_image_dir and os.path.exists(temp_image_dir):
shutil.rmtree(temp_image_dir, ignore_errors=True)
except Exception as clean_e:
print(f"[Parser-WARNING] 清理临时文件失败: {clean_e}")
print(f"[Parser-WARNING] 清理临时文件失败: {clean_e}")