parent 100f7f9405
commit e6c18119da
@@ -14,11 +14,7 @@ on:
 jobs:
   deploy:
-    if: |
-      (github.event_name == 'push' && github.ref == 'refs/heads/main') ||
-      (github.event_name == 'pull_request' && github.base_ref == 'refs/heads/main')
-
     runs-on: ubuntu-latest

     steps:
       - name: Checkout code
         uses: actions/checkout@v4
@@ -31,11 +27,3 @@ jobs:
           else
             echo "${{ toJSON(github.event.pull_request.changed_files) }}"
           fi
-
-  skip-notification:
-    if: ${{ always() && !contains(github.event.pull_request.changed_files, 'docs/') }}
-    runs-on: ubuntu-latest
-    needs: deploy
-    steps:
-      - name: Skip irrelevant changes
-        run: echo "Skipped: No docs/ files modified in this PR."
@@ -118,11 +118,9 @@ pnpm dev
 6. 提交PR等待审核

 ## 📄 交流群
-如果有使用问题或建议,可加入交流群进行讨论,目前1群已满,2群可扫码加入。
+如果有使用问题或建议,可加入交流群进行讨论。

-<div align="center">
-  <img src="docs/images/group.jpg" width="200" alt="2群二维码">
-</div>
+由于群聊超过200人,无法通过扫码加入,如需加群,加我微信zstar1003,备注"加群"即可。

 ## 🚀 鸣谢
@@ -213,7 +213,7 @@
 <div class="container">
   <div class="footer-content">
     <div class="links">
-      <a href="../LICENSE">许可证</a>
+      <a href="https://github.com/zstar1003/ragflow-plus/blob/main/LICENSE">许可证</a>
       <a href="https://github.com/zstar1003/ragflow-plus">GitHub</a>
     </div>
     <div class="copyright">
@@ -1,180 +1,121 @@
 import os
 from flask import jsonify, request, send_file, current_app
 from io import BytesIO
 from .. import files_bp
 from flask import request, jsonify
 from werkzeug.utils import secure_filename


-from services.files.service import (
-    get_files_list,
-    get_file_info,
-    download_file_from_minio,
-    delete_file,
-    batch_delete_files,
-    get_minio_client,
-    upload_files_to_server
-)
+from services.files.service import get_files_list, get_file_info, download_file_from_minio, delete_file, batch_delete_files, get_minio_client, upload_files_to_server
 from services.files.utils import FileType

-UPLOAD_FOLDER = '/data/uploads'
-ALLOWED_EXTENSIONS = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif', 'doc', 'docx', 'xls', 'xlsx'}
+UPLOAD_FOLDER = "/data/uploads"
+ALLOWED_EXTENSIONS = {"txt", "pdf", "png", "jpg", "jpeg", "gif", "doc", "docx", "xls", "xlsx"}


 def allowed_file(filename):
-    return '.' in filename and \
-           filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+    return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS


-@files_bp.route('/upload', methods=['POST'])
+@files_bp.route("/upload", methods=["POST"])
 def upload_file():
-    if 'files' not in request.files:
-        return jsonify({'code': 400, 'message': '未选择文件', 'data': None}), 400
-
-    files = request.files.getlist('files')
-    upload_result = upload_files_to_server(files)
-
-    # 返回标准格式
-    return jsonify({
-        'code': 0,
-        'message': '上传成功',
-        'data': upload_result['data']
-    })
+    if "files" not in request.files:
+        return jsonify({"code": 400, "message": "未选择文件", "data": None}), 400
+
+    files = request.files.getlist("files")
+    upload_result = upload_files_to_server(files)
+
+    # 返回标准格式
+    return jsonify({"code": 0, "message": "上传成功", "data": upload_result["data"]})


-@files_bp.route('', methods=['GET', 'OPTIONS'])
+@files_bp.route("", methods=["GET", "OPTIONS"])
 def get_files():
     """获取文件列表的API端点"""
-    if request.method == 'OPTIONS':
-        return '', 200
-
-    try:
-        current_page = int(request.args.get('currentPage', 1))
-        page_size = int(request.args.get('size', 10))
-        name_filter = request.args.get('name', '')
-
-        result, total = get_files_list(current_page, page_size, name_filter)
-
-        return jsonify({
-            "code": 0,
-            "data": {
-                "list": result,
-                "total": total
-            },
-            "message": "获取文件列表成功"
-        })
-
-    except Exception as e:
-        return jsonify({
-            "code": 500,
-            "message": f"获取文件列表失败: {str(e)}"
-        }), 500
+    if request.method == "OPTIONS":
+        return "", 200
+
+    try:
+        current_page = int(request.args.get("currentPage", 1))
+        page_size = int(request.args.get("size", 10))
+        name_filter = request.args.get("name", "")
+
+        result, total = get_files_list(current_page, page_size, name_filter)
+
+        return jsonify({"code": 0, "data": {"list": result, "total": total}, "message": "获取文件列表成功"})
+
+    except Exception as e:
+        return jsonify({"code": 500, "message": f"获取文件列表失败: {str(e)}"}), 500


-@files_bp.route('/<string:file_id>/download', methods=['GET', 'OPTIONS'])
+@files_bp.route("/<string:file_id>/download", methods=["GET", "OPTIONS"])
 def download_file(file_id):
     try:
         current_app.logger.info(f"开始处理文件下载请求: {file_id}")

         # 获取文件信息
         file = get_file_info(file_id)

         if not file:
             current_app.logger.error(f"文件不存在: {file_id}")
-            return jsonify({
-                "code": 404,
-                "message": f"文件 {file_id} 不存在",
-                "details": "文件记录不存在或已被删除"
-            }), 404
+            return jsonify({"code": 404, "message": f"文件 {file_id} 不存在", "details": "文件记录不存在或已被删除"}), 404

-        if file['type'] == FileType.FOLDER.value:
+        if file["type"] == FileType.FOLDER.value:
             current_app.logger.error(f"不能下载文件夹: {file_id}")
-            return jsonify({
-                "code": 400,
-                "message": "不能下载文件夹",
-                "details": "请选择一个文件进行下载"
-            }), 400
+            return jsonify({"code": 400, "message": "不能下载文件夹", "details": "请选择一个文件进行下载"}), 400

         current_app.logger.info(f"文件信息获取成功: {file_id}, 存储位置: {file['parent_id']}/{file['location']}")

         try:
             # 从MinIO下载文件
             file_data, filename = download_file_from_minio(file_id)

             # 创建内存文件对象
             file_stream = BytesIO(file_data)

             # 返回文件
-            return send_file(
-                file_stream,
-                download_name=filename,
-                as_attachment=True,
-                mimetype='application/octet-stream'
-            )
+            return send_file(file_stream, download_name=filename, as_attachment=True, mimetype="application/octet-stream")

         except Exception as e:
             current_app.logger.error(f"下载文件失败: {str(e)}")
-            return jsonify({
-                "code": 500,
-                "message": "下载文件失败",
-                "details": str(e)
-            }), 500
+            return jsonify({"code": 500, "message": "下载文件失败", "details": str(e)}), 500

     except Exception as e:
         current_app.logger.error(f"处理下载请求时出错: {str(e)}")
-        return jsonify({
-            "code": 500,
-            "message": "处理下载请求时出错",
-            "details": str(e)
-        }), 500
+        return jsonify({"code": 500, "message": "处理下载请求时出错", "details": str(e)}), 500


-@files_bp.route('/<string:file_id>', methods=['DELETE', 'OPTIONS'])
+@files_bp.route("/<string:file_id>", methods=["DELETE", "OPTIONS"])
 def delete_file_route(file_id):
     """删除文件的API端点"""
-    if request.method == 'OPTIONS':
-        return '', 200
+    if request.method == "OPTIONS":
+        return "", 200

     try:
         success = delete_file(file_id)

         if success:
-            return jsonify({
-                "code": 0,
-                "message": "文件删除成功"
-            })
+            return jsonify({"code": 0, "message": "文件删除成功"})
         else:
-            return jsonify({
-                "code": 404,
-                "message": f"文件 {file_id} 不存在"
-            }), 404
+            return jsonify({"code": 404, "message": f"文件 {file_id} 不存在"}), 404

     except Exception as e:
-        return jsonify({
-            "code": 500,
-            "message": f"删除文件失败: {str(e)}"
-        }), 500
+        return jsonify({"code": 500, "message": f"删除文件失败: {str(e)}"}), 500


-@files_bp.route('/batch', methods=['DELETE', 'OPTIONS'])
+@files_bp.route("/batch", methods=["DELETE", "OPTIONS"])
 def batch_delete_files_route():
     """批量删除文件的API端点"""
-    if request.method == 'OPTIONS':
-        return '', 200
+    if request.method == "OPTIONS":
+        return "", 200

     try:
         data = request.json
-        file_ids = data.get('ids', [])
+        file_ids = data.get("ids", [])

         if not file_ids:
-            return jsonify({
-                "code": 400,
-                "message": "未提供要删除的文件ID"
-            }), 400
+            return jsonify({"code": 400, "message": "未提供要删除的文件ID"}), 400

         success_count = batch_delete_files(file_ids)

-        return jsonify({
-            "code": 0,
-            "message": f"成功删除 {success_count}/{len(file_ids)} 个文件"
-        })
-
-    except Exception as e:
-        return jsonify({
-            "code": 500,
-            "message": f"批量删除文件失败: {str(e)}"
-        }), 500
+        return jsonify({"code": 0, "message": f"成功删除 {success_count}/{len(file_ids)} 个文件"})
+
+    except Exception as e:
+        return jsonify({"code": 500, "message": f"批量删除文件失败: {str(e)}"}), 500
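
A quick way to sanity-check the reformatted routes above is to exercise them from a small client. This is a minimal sketch, assuming the blueprint is mounted at http://localhost:5000/api/v1/files (the mount point is not visible in this diff) and relying only on the request and response shapes the routes themselves show:

import requests

BASE = "http://localhost:5000/api/v1/files"  # assumed mount point

# POST /upload reads the multipart field "files" and answers with the
# {"code", "message", "data"} envelope used throughout the file
with open("a.pdf", "rb") as f1, open("b.txt", "rb") as f2:
    resp = requests.post(f"{BASE}/upload", files=[("files", f1), ("files", f2)])
assert resp.json()["code"] == 0

# GET pages through the list with the currentPage/size/name query parameters
resp = requests.get(BASE, params={"currentPage": 1, "size": 10, "name": ""})
data = resp.json()["data"]
print(data["total"], len(data["list"]))

# DELETE /batch expects {"ids": [...]}; the "id" field inside each list item
# is an assumption, since the diff does not show the list item schema
ids = [item["id"] for item in data["list"][:2]]
resp = requests.delete(f"{BASE}/batch", json={"ids": ids})
print(resp.json()["message"])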
@@ -3,7 +3,7 @@ import tempfile
 import shutil
 import json
 import mysql.connector
-import time
+import time
 import traceback
 import re
 import requests
@@ -25,22 +25,23 @@ def tokenize_text(text):
     # 简单实现,实际应用中可能需要更复杂的分词逻辑
     return text.split()


 def merge_chunks(sections, chunk_token_num=128, delimiter="\n。;!?"):
     """合并文本块,替代naive_merge功能"""
     if not sections:
         return []

     chunks = [""]
     token_counts = [0]

     for section in sections:
         # 计算当前部分的token数量
         text = section[0] if isinstance(section, tuple) else section
         position = section[1] if isinstance(section, tuple) and len(section) > 1 else ""

         # 简单估算token数量
         token_count = len(text.split())

         # 如果当前chunk已经超过限制,创建新chunk
         if token_counts[-1] > chunk_token_num:
             chunks.append(text)
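
The loop above opens a new chunk once the running token estimate exceeds chunk_token_num, and otherwise concatenates the section text onto the current chunk (the else branch appears in the next hunk). A toy run of the same logic with a deliberately tiny limit; the token_counts.append on the new-chunk branch is assumed, since the hunk is cut right after chunks.append(text):

sections = [("alpha beta", "pos1"), ("gamma delta epsilon", "pos2"), ("zeta", "pos3")]

chunks, token_counts = [""], [0]
for section in sections:
    text = section[0] if isinstance(section, tuple) else section
    token_count = len(text.split())       # whitespace token estimate, as in tokenize_text
    if token_counts[-1] > 2:              # chunk_token_num=2 for the toy run
        chunks.append(text)
        token_counts.append(token_count)  # assumed continuation of the cut-off branch
    else:
        chunks[-1] += text
        token_counts[-1] += token_count

print(chunks)  # ['alpha betagamma delta epsilon', 'zeta']; note no delimiter is inserted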
@@ -49,13 +50,15 @@ def merge_chunks(sections, chunk_token_num=128, delimiter="\n。;!?"):
             # 否则添加到当前chunk
             chunks[-1] += text
             token_counts[-1] += token_count

     return chunks


 def _get_db_connection():
     """创建数据库连接"""
     return mysql.connector.connect(**DB_CONFIG)


 def _update_document_progress(doc_id, progress=None, message=None, status=None, run=None, chunk_count=None, process_duration=None):
     """更新数据库中文档的进度和状态"""
     conn = None
@@ -79,13 +82,12 @@ def _update_document_progress(doc_id, progress=None, message=None, status=None, run=None, chunk_count=None, process_duration=None):
         updates.append("run = %s")
         params.append(run)
     if chunk_count is not None:
-        updates.append("chunk_num = %s")
-        params.append(chunk_count)
+        updates.append("chunk_num = %s")
+        params.append(chunk_count)
     if process_duration is not None:
         updates.append("process_duation = %s")
         params.append(process_duration)

     if not updates:
         return
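
_update_document_progress builds its UPDATE statement from whichever keyword arguments were actually supplied, as the hunk above shows for chunk_count and process_duration. A standalone sketch of the same pattern; the document table and the progress/progress_msg column names here are assumptions for illustration:

def build_update(doc_id, progress=None, message=None):
    # collect SET clauses and parameters only for the fields provided
    updates, params = [], []
    if progress is not None:
        updates.append("progress = %s")
        params.append(progress)
    if message is not None:
        updates.append("progress_msg = %s")  # assumed column name
        params.append(message)
    if not updates:
        return None, None
    sql = f"UPDATE document SET {', '.join(updates)} WHERE id = %s"  # assumed table
    params.append(doc_id)
    return sql, params

print(build_update("doc1", progress=0.5))
# ('UPDATE document SET progress = %s WHERE id = %s', [0.5, 'doc1'])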
@@ -101,6 +103,7 @@ def _update_document_progress(doc_id, progress=None, message=None, status=None, run=None, chunk_count=None, process_duration=None):
         if conn:
             conn.close()

+
 def _update_kb_chunk_count(kb_id, count_delta):
     """更新知识库的块数量"""
     conn = None
@@ -125,6 +128,7 @@ def _update_kb_chunk_count(kb_id, count_delta):
         if conn:
             conn.close()

+
 def _create_task_record(doc_id, chunk_ids_list):
     """创建task记录"""
     conn = None
@@ -137,8 +141,8 @@ def _create_task_record(doc_id, chunk_ids_list):
     current_timestamp = int(current_datetime.timestamp() * 1000)
     current_time_str = current_datetime.strftime("%Y-%m-%d %H:%M:%S")
     current_date_only = current_datetime.strftime("%Y-%m-%d")
-    digest = f"{doc_id}_{0}_{1}" # 假设 from_page=0, to_page=1
-    chunk_ids_str = ' '.join(chunk_ids_list)
+    digest = f"{doc_id}_{0}_{1}"  # 假设 from_page=0, to_page=1
+    chunk_ids_str = " ".join(chunk_ids_list)

     task_insert = """
         INSERT INTO task (
@@ -148,9 +152,22 @@ def _create_task_record(doc_id, chunk_ids_list):
         ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
     """
     task_params = [
-        task_id, current_timestamp, current_date_only, current_timestamp, current_date_only,
-        doc_id, 0, 1, None, 0.0, # begin_at, process_duration
-        1.0, "MinerU解析完成", 1, digest, chunk_ids_str, "" # progress, msg, retry, digest, chunks, type
+        task_id,
+        current_timestamp,
+        current_date_only,
+        current_timestamp,
+        current_date_only,
+        doc_id,
+        0,
+        1,
+        None,
+        0.0,  # begin_at, process_duration
+        1.0,
+        "MinerU解析完成",
+        1,
+        digest,
+        chunk_ids_str,
+        "",  # progress, msg, retry, digest, chunks, type
     ]
     cursor.execute(task_insert, task_params)
     conn.commit()
@@ -184,6 +201,7 @@ def get_bbox_from_block(block):
     # 如果 block 不是字典或没有 bbox 键,或 bbox 格式无效,返回默认值
     return [0, 0, 0, 0]

+
 def perform_parse(doc_id, doc_info, file_info, embedding_config):
     """
     执行文档解析的核心逻辑
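
Only the fallback return of get_bbox_from_block is visible in this hunk. For orientation, a hedged sketch of what a defensive accessor with this contract might look like; the actual body is not shown in the diff:

def get_bbox_from_block(block):
    # expect a dict whose "bbox" key holds four numbers
    if isinstance(block, dict):
        bbox = block.get("bbox")
        if isinstance(bbox, list) and len(bbox) == 4 and all(isinstance(v, (int, float)) for v in bbox):
            return bbox
    # fall back to the default when block is not a dict or bbox is malformed
    return [0, 0, 0, 0]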
@@ -199,48 +217,47 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
     temp_pdf_path = None
     temp_image_dir = None
     start_time = time.time()
-    middle_json_content = None # 初始化 middle_json_content
+    middle_json_content = None  # 初始化 middle_json_content
     image_info_list = [] # 图片信息列表

     # 默认值处理
-    embedding_model_name = embedding_config.get("llm_name") if embedding_config and embedding_config.get("llm_name") else "bge-m3" # 默认模型
+    embedding_model_name = embedding_config.get("llm_name") if embedding_config and embedding_config.get("llm_name") else "bge-m3"  # 默认模型
     # 对模型名称进行处理
-    if embedding_model_name and '___' in embedding_model_name:
-        embedding_model_name = embedding_model_name.split('___')[0]
-    embedding_api_base = embedding_config.get("api_base") if embedding_config and embedding_config.get("api_base") else "http://localhost:8000" # 默认基础 URL
-    embedding_api_key = embedding_config.get("api_key") if embedding_config else None # 可能为 None 或空字符串
+    if embedding_model_name and "___" in embedding_model_name:
+        embedding_model_name = embedding_model_name.split("___")[0]
+    embedding_api_base = embedding_config.get("api_base") if embedding_config and embedding_config.get("api_base") else "http://localhost:8000"  # 默认基础 URL
+    embedding_api_key = embedding_config.get("api_key") if embedding_config else None  # 可能为 None 或空字符串

     # 构建完整的 Embedding API URL
-    embedding_url = None # 默认为 None
+    embedding_url = None  # 默认为 None
     if embedding_api_base:
         # 确保 embedding_api_base 包含协议头 (http:// 或 https://)
-        if not embedding_api_base.startswith(('http://', 'https://')):
-            embedding_api_base = 'http://' + embedding_api_base
+        if not embedding_api_base.startswith(("http://", "https://")):
+            embedding_api_base = "http://" + embedding_api_base

         # --- URL 拼接优化 (处理 /v1) ---
         endpoint_segment = "embeddings"
         full_endpoint_path = "v1/embeddings"
         # 移除末尾斜杠以方便判断
-        normalized_base_url = embedding_api_base.rstrip('/')
+        normalized_base_url = embedding_api_base.rstrip("/")

-        if normalized_base_url.endswith('/v1'):
+        if normalized_base_url.endswith("/v1"):
             # 如果 base_url 已经是 http://host/v1 形式
-            embedding_url = normalized_base_url + '/' + endpoint_segment
+            embedding_url = normalized_base_url + "/" + endpoint_segment
         else:
             # 如果 base_url 是 http://host 或 http://host/api 等其他形式
-            embedding_url = normalized_base_url + '/' + full_endpoint_path
+            embedding_url = normalized_base_url + "/" + full_endpoint_path

     print(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")

     try:
-        kb_id = doc_info['kb_id']
-        file_location = doc_info['location']
+        kb_id = doc_info["kb_id"]
+        file_location = doc_info["location"]
         # 从文件路径中提取原始后缀名
         _, file_extension = os.path.splitext(file_location)
-        file_type = doc_info['type'].lower()
-        parser_config = json.loads(doc_info['parser_config']) if isinstance(doc_info['parser_config'], str) else doc_info['parser_config']
-        bucket_name = file_info['parent_id'] # 文件存储的桶是 parent_id
-        tenant_id = doc_info['created_by'] # 文档创建者作为 tenant_id
+        file_type = doc_info["type"].lower()
+        bucket_name = file_info["parent_id"]  # 文件存储的桶是 parent_id
+        tenant_id = doc_info["created_by"]  # 文档创建者作为 tenant_id

         # 进度更新回调 (直接调用内部更新函数)
         def update_progress(prog=None, msg=None):
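
The /v1 handling above exists so that a base URL which already ends in /v1 does not get the version segment twice. The same rule as a standalone helper, directly mirroring the logic in this hunk:

def build_embedding_url(base: str) -> str:
    # add a scheme if the configured base lacks one
    if not base.startswith(("http://", "https://")):
        base = "http://" + base
    normalized = base.rstrip("/")
    # append only "embeddings" when the base already carries /v1
    if normalized.endswith("/v1"):
        return normalized + "/embeddings"
    return normalized + "/v1/embeddings"

assert build_embedding_url("localhost:8000") == "http://localhost:8000/v1/embeddings"
assert build_embedding_url("http://host/v1/") == "http://host/v1/embeddings"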
@@ -260,13 +277,13 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):

         # 2. 根据文件类型选择解析器
         content_list = []
-        if file_type.endswith('pdf'):
+        if file_type.endswith("pdf"):
             update_progress(0.3, "使用MinerU解析器")

             # 创建临时文件保存PDF内容
             temp_dir = tempfile.gettempdir()
             temp_pdf_path = os.path.join(temp_dir, f"{doc_id}.pdf")
-            with open(temp_pdf_path, 'wb') as f:
+            with open(temp_pdf_path, "wb") as f:
                 f.write(file_content)

             # 使用MinerU处理
@@ -295,12 +312,12 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
             middle_content = pipe_result.get_middle_json()
             middle_json_content = json.loads(middle_content)

-        elif file_type.endswith('word') or file_type.endswith('ppt'):
+        elif file_type.endswith("word") or file_type.endswith("ppt"):
             update_progress(0.3, "使用MinerU解析器")
             # 创建临时文件保存文件内容
             temp_dir = tempfile.gettempdir()
             temp_file_path = os.path.join(temp_dir, f"{doc_id}{file_extension}")
-            with open(temp_file_path, 'wb') as f:
+            with open(temp_file_path, "wb") as f:
                 f.write(file_content)

             print(f"[Parser-INFO] 临时文件路径: {temp_file_path}")
@@ -313,7 +330,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
             os.makedirs(temp_image_dir, exist_ok=True)
             image_writer = FileBasedDataWriter(temp_image_dir)

-            update_progress(0.6, f"处理文件结果")
+            update_progress(0.6, "处理文件结果")
             pipe_result = infer_result.pipe_txt_mode(image_writer)

             update_progress(0.8, "提取内容")
@@ -332,22 +349,19 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
         if middle_json_content:
             try:
                 if isinstance(middle_json_content, dict):
-                    middle_data = middle_json_content # 直接赋值
+                    middle_data = middle_json_content  # 直接赋值
                 else:
                     middle_data = None
                     print(f"[Parser-WARNING] middle_json_content 不是预期的字典格式,实际类型: {type(middle_json_content)}。")
-                    # 提取信息
+                # 提取信息
                 for page_idx, page_data in enumerate(middle_data.get("pdf_info", [])):
                     for block in page_data.get("preproc_blocks", []):
                         block_bbox = get_bbox_from_block(block)
                         # 仅提取包含文本且有 bbox 的块
                         if block_bbox != [0, 0, 0, 0]:
-                            block_info_list.append({
-                                "page_idx": page_idx,
-                                "bbox": block_bbox
-                            })
+                            block_info_list.append({"page_idx": page_idx, "bbox": block_bbox})
                         else:
-                            print(f"[Parser-WARNING] 块的 bbox 格式无效: {bbox},跳过。")
+                            print("[Parser-WARNING] 块的 bbox 格式无效,跳过。")

                 print(f"[Parser-INFO] 从 middle_data 提取了 {len(block_info_list)} 个块的信息。")
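
The extraction loop above walks MinerU's middle JSON: pdf_info is a list of pages, each carrying preproc_blocks whose bbox feeds block_info_list. A toy payload showing just the shape the loop relies on (real middle JSON carries many more fields), with the bbox check inlined instead of calling get_bbox_from_block:

middle_data = {
    "pdf_info": [
        {"preproc_blocks": [{"bbox": [10, 20, 110, 40]}, {"lines": []}]},  # page 0
    ]
}

block_info_list = []
for page_idx, page_data in enumerate(middle_data.get("pdf_info", [])):
    for block in page_data.get("preproc_blocks", []):
        bbox = block.get("bbox") if isinstance(block, dict) else None
        if isinstance(bbox, list) and len(bbox) == 4:
            block_info_list.append({"page_idx": page_idx, "bbox": bbox})

print(block_info_list)  # [{'page_idx': 0, 'bbox': [10, 20, 110, 40]}]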
@@ -355,7 +369,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
                 print("[Parser-ERROR] 解析 middle_json_content 失败。")
             except Exception as e:
                 print(f"[Parser-ERROR] 处理 middle_json_content 时出错: {e}")

         # 3. 处理解析结果 (上传到MinIO, 存储到ES)
         update_progress(0.95, "保存解析结果")
         es_client = get_es_client()
@@ -367,34 +381,26 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):

         index_name = f"ragflow_{tenant_id}"
         if not es_client.indices.exists(index=index_name):
-        # 创建索引
+            # 创建索引
             es_client.indices.create(
                 index=index_name,
                 body={
                     "settings": {"number_of_replicas": 0},
                     "mappings": {
-                        "properties": {
-                            "doc_id": {"type": "keyword"},
-                            "kb_id": {"type": "keyword"},
-                            "content_with_weight": {"type": "text"},
-                            "q_1024_vec": {
-                                "type": "dense_vector",
-                                "dims": 1024
-                            }
-                        }
-                    }
-                }
+                        "properties": {"doc_id": {"type": "keyword"}, "kb_id": {"type": "keyword"}, "content_with_weight": {"type": "text"}, "q_1024_vec": {"type": "dense_vector", "dims": 1024}}
+                    },
+                },
             )
             print(f"[Parser-INFO] 创建Elasticsearch索引: {index_name}")

         chunk_count = 0
         chunk_ids_list = []
-        middle_block_idx = 0 # 用于按顺序匹配 block_info_list
-        processed_text_chunks = 0 # 记录处理的文本块数量
+        middle_block_idx = 0  # 用于按顺序匹配 block_info_list
+        processed_text_chunks = 0  # 记录处理的文本块数量

         for chunk_idx, chunk_data in enumerate(content_list):
             page_idx = 0  # 默认页面索引
-            bbox = [0, 0, 0, 0] # 默认 bbox
+            bbox = [0, 0, 0, 0]  # 默认 bbox

             # 尝试使用 chunk_idx 直接从 block_info_list 获取对应的块信息
             if chunk_idx < len(block_info_list):
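
The mapping above stores each chunk with a 1024-dimensional dense_vector (q_1024_vec). For context, one hedged way to query such a field is a script_score search; whether this matches the project's actual retrieval path is not visible in this diff, and the host, knowledge-base id, and tenant id below are placeholders:

from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")  # placeholder address
query_vec = [0.0] * 1024                     # placeholder embedding, must match dims

body = {
    "size": 5,
    "query": {
        "script_score": {
            "query": {"term": {"kb_id": "my_kb"}},  # restrict to one knowledge base
            "script": {
                "source": "cosineSimilarity(params.qv, 'q_1024_vec') + 1.0",
                "params": {"qv": query_vec},
            },
        }
    },
}
hits = es.search(index="ragflow_tenant123", body=body)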
@@ -410,7 +416,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
                 # 仅在第一次索引越界时打印一次警告,避免刷屏
                 if chunk_idx == len(block_info_list):
                     print(f"[Parser-WARNING] block_info_list 的长度 ({len(block_info_list)}) 小于 content_list 的长度 ({len(content_list)})。后续块将使用默认 page_idx 和 bbox。")

-            if chunk_data["type"] == "text" or chunk_data["type"] == "table":
+            if chunk_data["type"] == "text":
                 content = chunk_data["text"]
@@ -419,8 +425,8 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
                 # 过滤 markdown 特殊符号
                 content = re.sub(r"[!#\\$/]", "", content)
             elif chunk_data["type"] == "table":
-                caption_list = chunk_data.get("table_caption", []) # 获取列表,默认为空列表
-                table_body = chunk_data.get("table_body", "") # 获取表格主体,默认为空字符串
+                caption_list = chunk_data.get("table_caption", [])  # 获取列表,默认为空列表
+                table_body = chunk_data.get("table_body", "")  # 获取表格主体,默认为空字符串
                 # 检查 caption_list 是否为列表,并且包含字符串元素
                 if isinstance(caption_list, list) and all(isinstance(item, str) for item in caption_list):
                     # 使用空格将列表中的所有字符串拼接起来
@@ -433,9 +439,8 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
                     caption_str = ""
                 # 将处理后的标题字符串和表格主体拼接
                 content = caption_str + table_body
-
-            q_1024_vec = [] # 初始化为空列表
+
+            q_1024_vec = []  # 初始化为空列表
             # 获取embedding向量
             try:
                 # embedding_resp = requests.post(
@@ -451,15 +456,15 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
                     headers["Authorization"] = f"Bearer {embedding_api_key}"

                 embedding_resp = requests.post(
-                    embedding_url, # 使用动态构建的 URL
-                    headers=headers, # 添加 headers (包含可能的 API Key)
+                    embedding_url,  # 使用动态构建的 URL
+                    headers=headers,  # 添加 headers (包含可能的 API Key)
                     json={
-                        "model": embedding_model_name, # 使用动态获取或默认的模型名
-                        "input": content
+                        "model": embedding_model_name,  # 使用动态获取或默认的模型名
+                        "input": content,
                     },
-                    timeout=15 # 稍微增加超时时间
+                    timeout=15,  # 稍微增加超时时间
                 )

                 embedding_resp.raise_for_status()
                 embedding_data = embedding_resp.json()
                 q_1024_vec = embedding_data["data"][0]["embedding"]
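
The request above targets an OpenAI-compatible /v1/embeddings endpoint and reads the vector from data[0].embedding. A minimal fake of that response shape, inferred purely from the access pattern in this hunk:

fake_response = {
    "object": "list",
    "data": [{"object": "embedding", "index": 0, "embedding": [0.1] * 1024}],
    "model": "bge-m3",
}

q_1024_vec = fake_response["data"][0]["embedding"]
assert len(q_1024_vec) == 1024  # must match the dense_vector dims in the ES mapping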
@@ -467,20 +472,20 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
             except Exception as e:
                 print(f"[Parser-ERROR] 获取embedding失败: {e}")
                 q_1024_vec = []

             chunk_id = generate_uuid()

             try:
                 # 上传文本块到 MinIO
                 minio_client.put_object(
                     bucket_name=output_bucket,
                     object_name=chunk_id,
-                    data=BytesIO(content.encode('utf-8')),
-                    length=len(content.encode('utf-8')) # 使用字节长度
+                    data=BytesIO(content.encode("utf-8")),
+                    length=len(content.encode("utf-8")),  # 使用字节长度
                 )

                 # 准备ES文档
-                content_tokens = tokenize_text(content) # 分词
+                content_tokens = tokenize_text(content)  # 分词
                 current_time_es = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                 current_timestamp_es = datetime.now().timestamp()
@@ -491,34 +496,34 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
                 es_doc = {
                     "doc_id": doc_id,
                     "kb_id": kb_id,
-                    "docnm_kwd": doc_info['name'],
-                    "title_tks": doc_info['name'],
-                    "title_sm_tks": doc_info['name'],
+                    "docnm_kwd": doc_info["name"],
+                    "title_tks": doc_info["name"],
+                    "title_sm_tks": doc_info["name"],
                     "content_with_weight": content,
-                    "content_ltks": " ".join(content_tokens), # 字符串类型
+                    "content_ltks": " ".join(content_tokens),  # 字符串类型
                     "content_sm_ltks": " ".join(content_tokens), # 字符串类型
                     "page_num_int": [page_idx + 1],
-                    "position_int": [[page_idx + 1] + bbox_reordered], # 格式: [[page, x1, x2, y1, y2]]
+                    "position_int": [[page_idx + 1] + bbox_reordered],  # 格式: [[page, x1, x2, y1, y2]]
                     "top_int": [1],
                     "create_time": current_time_es,
                     "create_timestamp_flt": current_timestamp_es,
                     "img_id": "",
-                    "q_1024_vec": q_1024_vec
+                    "q_1024_vec": q_1024_vec,
                 }

                 # 存储到Elasticsearch
-                es_client.index(index=index_name, id=chunk_id, document=es_doc) # 使用 document 参数
+                es_client.index(index=index_name, id=chunk_id, document=es_doc)  # 使用 document 参数

                 chunk_count += 1
                 processed_text_chunks += 1
                 chunk_ids_list.append(chunk_id)
                 # print(f"成功处理文本块 {chunk_count}/{len(content_list)}") # 可以取消注释用于详细调试

             except Exception as e:
                 print(f"[Parser-ERROR] 处理文本块 {chunk_idx} (page: {page_idx}, bbox: {bbox}) 失败: {e}")
-                traceback.print_exc() # 打印更详细的错误
+                traceback.print_exc()  # 打印更详细的错误

             elif chunk_data["type"] == "image":
-                img_path_relative = chunk_data.get('img_path')
+                img_path_relative = chunk_data.get("img_path")
                 if not img_path_relative or not temp_image_dir:
                     continue
|
|||
|
||||
img_id = generate_uuid()
|
||||
img_ext = os.path.splitext(img_path_abs)[1]
|
||||
img_key = f"images/{img_id}{img_ext}" # MinIO中的对象名
|
||||
img_key = f"images/{img_id}{img_ext}" # MinIO中的对象名
|
||||
content_type = f"image/{img_ext[1:].lower()}"
|
||||
if content_type == "image/jpg": content_type = "image/jpeg"
|
||||
if content_type == "image/jpg":
|
||||
content_type = "image/jpeg"
|
||||
|
||||
try:
|
||||
# 上传图片到MinIO (桶为kb_id)
|
||||
minio_client.fput_object(
|
||||
bucket_name=output_bucket,
|
||||
object_name=img_key,
|
||||
file_path=img_path_abs,
|
||||
content_type=content_type
|
||||
)
|
||||
minio_client.fput_object(bucket_name=output_bucket, object_name=img_key, file_path=img_path_abs, content_type=content_type)
|
||||
|
||||
# 设置图片的公共访问权限
|
||||
policy = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"AWS": "*"},
|
||||
"Action": ["s3:GetObject"],
|
||||
"Resource": [f"arn:aws:s3:::{kb_id}/images/*"]
|
||||
}
|
||||
]
|
||||
}
|
||||
policy = {"Version": "2012-10-17", "Statement": [{"Effect": "Allow", "Principal": {"AWS": "*"}, "Action": ["s3:GetObject"], "Resource": [f"arn:aws:s3:::{kb_id}/images/*"]}]}
|
||||
minio_client.set_bucket_policy(kb_id, json.dumps(policy))
|
||||
|
||||
print(f"成功上传图片: {img_key}")
|
||||
|
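
The bucket policy above grants anonymous s3:GetObject on the images/ prefix so that stored chunks can reference images by plain URL. The same policy as a reusable helper; minio's set_bucket_policy takes the policy as a JSON string:

import json
from minio import Minio

def make_images_public(client: Minio, bucket: str) -> None:
    # anonymous read access, restricted to the images/ prefix
    policy = {
        "Version": "2012-10-17",
        "Statement": [{
            "Effect": "Allow",
            "Principal": {"AWS": "*"},
            "Action": ["s3:GetObject"],
            "Resource": [f"arn:aws:s3:::{bucket}/images/*"],
        }],
    }
    client.set_bucket_policy(bucket, json.dumps(policy))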
@@ -565,7 +556,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
                     # 记录图片信息,包括URL和位置信息
                     image_info = {
                         "url": img_url,
-                        "position": processed_text_chunks # 使用当前处理的文本块数作为位置参考
+                        "position": processed_text_chunks,  # 使用当前处理的文本块数作为位置参考
                     }
                     image_info_list.append(image_info)
@@ -573,46 +564,42 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):

                 except Exception as e:
                     print(f"[Parser-ERROR] 上传图片 {img_path_abs} 失败: {e}")

         # 打印匹配总结信息
         print(f"[Parser-INFO] 共处理 {processed_text_chunks} 个文本块。")
         if middle_block_idx < len(block_info_list):
             print(f"[Parser-WARNING] middle_data 中还有 {len(block_info_list) - middle_block_idx} 个提取的块信息未被使用。")

         # 4. 更新文本块的图像信息
         if image_info_list and chunk_ids_list:
             conn = None
             cursor = None
             try:
                 conn = _get_db_connection()
                 cursor = conn.cursor()

                 # 为每个文本块找到最近的图片
                 for i, chunk_id in enumerate(chunk_ids_list):
                     # 找到与当前文本块最近的图片
                     nearest_image = None

                     for img_info in image_info_list:
                         # 计算文本块与图片的"距离"
                         distance = abs(i - img_info["position"])  # 使用位置差作为距离度量
                         # 如果文本块与图片的距离间隔小于10个块,则认为块与图片是相关的
                         if distance < 10:
                             nearest_image = img_info

                     # 如果找到了最近的图片,则更新文本块的img_id
                     if nearest_image:
                         # 更新ES中的文档
-                        direct_update = {
-                            "doc": {
-                                "img_id": nearest_image["url"]
-                            }
-                        }
+                        direct_update = {"doc": {"img_id": nearest_image["url"]}}
                         es_client.update(index=index_name, id=chunk_id, body=direct_update, refresh=True)

                         index_name = f"ragflow_{tenant_id}"

                         print(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {nearest_image['url']}")

             except Exception as e:
                 print(f"[Parser-ERROR] 更新文本块图片关联失败: {e}")
             finally:
|
|||
|
||||
# 5. 更新最终状态
|
||||
process_duration = time.time() - start_time
|
||||
_update_document_progress(doc_id, progress=1.0, message="解析完成", status='1', run='3', chunk_count=chunk_count, process_duration=process_duration)
|
||||
_update_kb_chunk_count(kb_id, chunk_count) # 更新知识库总块数
|
||||
_create_task_record(doc_id, chunk_ids_list) # 创建task记录
|
||||
_update_document_progress(doc_id, progress=1.0, message="解析完成", status="1", run="3", chunk_count=chunk_count, process_duration=process_duration)
|
||||
_update_kb_chunk_count(kb_id, chunk_count) # 更新知识库总块数
|
||||
_create_task_record(doc_id, chunk_ids_list) # 创建task记录
|
||||
|
||||
update_progress(1.0, "解析完成")
|
||||
print(f"[Parser-INFO] 解析完成,文档ID: {doc_id}, 耗时: {process_duration:.2f}s, 块数: {chunk_count}")
|
||||
|
||||
return {"success": True, "chunk_count": chunk_count}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
process_duration = time.time() - start_time
|
||||
# error_message = f"解析失败: {str(e)}"
|
||||
print(f"[Parser-ERROR] 文档 {doc_id} 解析失败: {e}")
|
||||
error_message = f"解析失败: {e}"
|
||||
traceback.print_exc() # 打印详细错误堆栈
|
||||
traceback.print_exc() # 打印详细错误堆栈
|
||||
# 更新文档状态为失败
|
||||
_update_document_progress(doc_id, status='1', run='0', message=error_message, process_duration=process_duration) # status=1表示完成,run=0表示失败
|
||||
_update_document_progress(doc_id, status="1", run="0", message=error_message, process_duration=process_duration) # status=1表示完成,run=0表示失败
|
||||
# 不抛出异常,让调用者知道任务已结束(但失败)
|
||||
return {"success": False, "error": error_message}
|
||||
|
||||
|
@@ -651,4 +638,4 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
         if temp_image_dir and os.path.exists(temp_image_dir):
             shutil.rmtree(temp_image_dir, ignore_errors=True)
     except Exception as clean_e:
-        print(f"[Parser-WARNING] 清理临时文件失败: {clean_e}")
+        print(f"[Parser-WARNING] 清理临时文件失败: {clean_e}")