feat: 增加支持对csv格式文件的上传和解析

This commit is contained in:
zstar 2025-06-12 22:50:19 +08:00
parent dfb7867561
commit 0b1126b1c8
8 changed files with 122 additions and 79 deletions

View File

@ -1,11 +1,11 @@
from flask import jsonify, request, send_file, current_app
from io import BytesIO from io import BytesIO
from .. import files_bp
from flask import current_app, jsonify, request, send_file
from services.files.service import get_files_list, get_file_info, download_file_from_minio, delete_file, batch_delete_files, handle_chunk_upload, merge_chunks, upload_files_to_server from services.files.service import batch_delete_files, delete_file, download_file_from_minio, get_file_info, get_files_list, handle_chunk_upload, merge_chunks, upload_files_to_server
from services.files.utils import FileType from services.files.utils import FileType
from .. import files_bp
UPLOAD_FOLDER = "/data/uploads" UPLOAD_FOLDER = "/data/uploads"
ALLOWED_EXTENSIONS = {"txt", "pdf", "png", "jpg", "jpeg", "gif", "doc", "docx", "xls", "xlsx"} ALLOWED_EXTENSIONS = {"txt", "pdf", "png", "jpg", "jpeg", "gif", "doc", "docx", "xls", "xlsx"}

View File

@ -1,16 +1,18 @@
from peewee import * from peewee import * # noqa: F403
from .base_service import BaseService from .base_service import BaseService
from .models import Document from .models import Document
from .utils import get_uuid, StatusEnum from .utils import StatusEnum, get_uuid
class DocumentService(BaseService): class DocumentService(BaseService):
model = Document model = Document
@classmethod @classmethod
def create_document(cls, kb_id: str, name: str, location: str, size: int, file_type: str, created_by: str = None, parser_id: str = None, parser_config: dict = None) -> Document: def create_document(cls, kb_id: str, name: str, location: str, size: int, file_type: str, created_by: str = None, parser_id: str = None, parser_config: dict = None) -> Document:
""" """
创建文档记录 创建文档记录
Args: Args:
kb_id: 知识库ID kb_id: 知识库ID
name: 文件名 name: 文件名
@ -20,34 +22,34 @@ class DocumentService(BaseService):
created_by: 创建者ID created_by: 创建者ID
parser_id: 解析器ID parser_id: 解析器ID
parser_config: 解析器配置 parser_config: 解析器配置
Returns: Returns:
Document: 创建的文档对象 Document: 创建的文档对象
""" """
doc_id = get_uuid() doc_id = get_uuid()
# 构建基本文档数据 # 构建基本文档数据
doc_data = { doc_data = {
'id': doc_id, "id": doc_id,
'kb_id': kb_id, "kb_id": kb_id,
'name': name, "name": name,
'location': location, "location": location,
'size': size, "size": size,
'type': file_type, "type": file_type,
'created_by': created_by or 'system', "created_by": created_by or "system",
'parser_id': parser_id or '', "parser_id": parser_id or "",
'parser_config': parser_config or {"pages": [[1, 1000000]]}, "parser_config": parser_config or {"pages": [[1, 1000000]]},
'source_type': 'local', "source_type": "local",
'token_num': 0, "token_num": 0,
'chunk_num': 0, "chunk_num": 0,
'progress': 0, "progress": 0,
'progress_msg': '', "progress_msg": "",
'run': '0', # 未开始解析 "run": "0", # 未开始解析
'status': StatusEnum.VALID.value "status": StatusEnum.VALID.value,
} }
return cls.insert(doc_data) return cls.insert(doc_data)
@classmethod @classmethod
def get_by_kb_id(cls, kb_id: str) -> list[Document]: def get_by_kb_id(cls, kb_id: str) -> list[Document]:
return cls.query(kb_id=kb_id) return cls.query(kb_id=kb_id)

View File

@ -1,26 +1,20 @@
from peewee import * # noqa: F403 from peewee import * # noqa: F403
from .base_service import BaseService from .base_service import BaseService
from .models import File from .models import File
from .utils import FileType, get_uuid from .utils import FileType, get_uuid
class FileService(BaseService): class FileService(BaseService):
model = File model = File
@classmethod @classmethod
def create_file(cls, parent_id: str, name: str, location: str, size: int, file_type: str) -> File: def create_file(cls, parent_id: str, name: str, location: str, size: int, file_type: str) -> File:
return cls.insert({ return cls.insert({"parent_id": parent_id, "name": name, "location": location, "size": size, "type": file_type, "source_type": "knowledgebase"})
'parent_id': parent_id,
'name': name,
'location': location,
'size': size,
'type': file_type,
'source_type': 'knowledgebase'
})
@classmethod @classmethod
def get_parser(cls, file_type, filename, tenant_id): def get_parser(cls, file_type, filename, tenant_id):
"""获取适合文件类型的解析器ID""" """获取适合文件类型的解析器ID"""
# 这里可能需要根据实际情况调整
if file_type == FileType.PDF.value: if file_type == FileType.PDF.value:
return "pdf_parser" return "pdf_parser"
elif file_type == FileType.WORD.value: elif file_type == FileType.WORD.value:
@ -40,7 +34,7 @@ class FileService(BaseService):
def get_by_parent_id(cls, parent_id: str) -> list[File]: def get_by_parent_id(cls, parent_id: str) -> list[File]:
return cls.query(parent_id=parent_id) return cls.query(parent_id=parent_id)
@classmethod @classmethod
def generate_bucket_name(cls): def generate_bucket_name(cls):
"""生成随机存储桶名称""" """生成随机存储桶名称"""
return f"kb-{get_uuid()}" return f"kb-{get_uuid()}"

View File

@ -1,13 +1,14 @@
import os import os
import shutil
import re import re
import shutil
import tempfile import tempfile
from dotenv import load_dotenv
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from database import get_db_connection, get_minio_client, get_redis_connection
from .utils import FileType, FileSource, get_uuid
from database import get_db_connection, get_minio_client, get_redis_connection
from dotenv import load_dotenv
from .utils import FileSource, FileType, get_uuid
# 加载环境变量 # 加载环境变量
load_dotenv("../../docker/.env") load_dotenv("../../docker/.env")
@ -18,7 +19,7 @@ CHUNK_EXPIRY_SECONDS = 3600 * 24 # 分块24小时过期
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
UPLOAD_FOLDER = os.path.join(temp_dir, "uploads") UPLOAD_FOLDER = os.path.join(temp_dir, "uploads")
ALLOWED_EXTENSIONS = {"pdf", "doc", "docx", "ppt", "pptx", "xls", "xlsx", "jpg", "jpeg", "png", "bmp", "txt", "md", "html"} ALLOWED_EXTENSIONS = {"pdf", "doc", "docx", "ppt", "pptx", "xls", "xlsx", "jpg", "jpeg", "png", "bmp", "txt", "md", "html", "csv"}
def allowed_file(filename): def allowed_file(filename):
@ -36,7 +37,7 @@ def filename_type(filename):
return FileType.PDF.value return FileType.PDF.value
elif ext in [".doc", ".docx"]: elif ext in [".doc", ".docx"]:
return FileType.WORD.value return FileType.WORD.value
elif ext in [".xls", ".xlsx"]: elif ext in [".xls", ".xlsx", ".csv"]:
return FileType.EXCEL.value return FileType.EXCEL.value
elif ext in [".ppt", ".pptx"]: elif ext in [".ppt", ".pptx"]:
return FileType.PPT.value return FileType.PPT.value

View File

@ -1,9 +1,9 @@
import uuid import uuid
from strenum import StrEnum
from enum import Enum from enum import Enum
from strenum import StrEnum
# 参考api.db
class FileType(StrEnum): class FileType(StrEnum):
FOLDER = "folder" FOLDER = "folder"
PDF = "pdf" PDF = "pdf"
@ -15,15 +15,18 @@ class FileType(StrEnum):
HTML = "html" HTML = "html"
OTHER = "other" OTHER = "other"
class FileSource(StrEnum): class FileSource(StrEnum):
LOCAL = "" LOCAL = ""
KNOWLEDGEBASE = "knowledgebase" KNOWLEDGEBASE = "knowledgebase"
S3 = "s3" S3 = "s3"
class StatusEnum(Enum): class StatusEnum(Enum):
VALID = "1" VALID = "1"
INVALID = "0" INVALID = "0"
# 参考api.utils # 参考api.utils
def get_uuid(): def get_uuid():
return uuid.uuid1().hex return uuid.uuid1().hex

View File

@ -20,7 +20,7 @@ from magic_pdf.data.read_api import read_local_images, read_local_office
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from . import logger from . import logger
from .excel_parser import parse_excel from .excel_parser import parse_excel_file
from .rag_tokenizer import RagTokenizer from .rag_tokenizer import RagTokenizer
from .utils import _create_task_record, _update_document_progress, _update_kb_chunk_count, generate_uuid, get_bbox_from_block from .utils import _create_task_record, _update_document_progress, _update_kb_chunk_count, generate_uuid, get_bbox_from_block
@ -196,7 +196,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
update_progress(0.8, "提取内容") update_progress(0.8, "提取内容")
# 处理内容列表 # 处理内容列表
content_list = parse_excel(temp_file_path) content_list = parse_excel_file(temp_file_path)
elif file_type.endswith("visual"): elif file_type.endswith("visual"):
update_progress(0.3, "使用MinerU解析器") update_progress(0.3, "使用MinerU解析器")

View File

@ -1,24 +1,69 @@
import os
import pandas as pd import pandas as pd
def parse_excel(file_path): def parse_excel_file(file_path):
all_sheets = pd.read_excel(file_path, sheet_name=None) """
通用表格解析函数，支持 Excel (.xlsx/.xls) 和 CSV 文件
返回统一格式的数据块列表
"""
blocks = [] blocks = []
for sheet_name, df in all_sheets.items(): # 根据文件扩展名选择读取方式
df = df.ffill() file_ext = os.path.splitext(file_path)[1].lower()
headers = df.columns.tolist()
for _, row in df.iterrows(): try:
html_table = "<html><body><table><tr>{}</tr><tr>{}</tr></table></body></html>".format("".join(f"<td>{col}</td>" for col in headers), "".join(f"<td>{row[col]}</td>" for col in headers)) if file_ext in (".xlsx", ".xls"):
block = {"type": "table", "img_path": "", "table_caption": [f"Sheet: {sheet_name}"], "table_footnote": [], "table_body": f"{html_table}", "page_idx": 0} # 处理Excel文件多sheet
blocks.append(block) all_sheets = pd.read_excel(file_path, sheet_name=None)
for sheet_name, df in all_sheets.items():
blocks.extend(_process_dataframe(df, sheet_name))
elif file_ext == ".csv":
# 处理CSV文件单sheet
df = pd.read_csv(file_path)
blocks.extend(_process_dataframe(df, "CSV"))
else:
raise ValueError(f"Unsupported file format: {file_ext}")
except Exception as e:
raise ValueError(f"Failed to parse file {file_path}: {str(e)}")
return blocks
def _process_dataframe(df, sheet_name):
"""处理单个DataFrame，生成统一格式的数据块"""
df = df.ffill()
headers = df.columns.tolist()
blocks = []
for _, row in df.iterrows():
html_table = "<html><body><table><tr>{}</tr><tr>{}</tr></table></body></html>".format("".join(f"<td>{col}</td>" for col in headers), "".join(f"<td>{row[col]}</td>" for col in headers))
block = {"type": "table", "img_path": "", "table_caption": "", "table_footnote": [], "table_body": html_table, "page_idx": 0}
blocks.append(block)
return blocks return blocks
if __name__ == "__main__": if __name__ == "__main__":
file_path = "test_excel.xls" # 测试示例
parse_excel_result = parse_excel(file_path) excel_path = "test.xlsx"
print(parse_excel_result) csv_path = "test.csv"
try:
# 测试Excel解析
excel_blocks = parse_excel_file(excel_path)
print(f"Excel解析结果{len(excel_blocks)}条):")
print(excel_blocks[:1]) # 打印第一条示例
# 测试CSV解析
csv_blocks = parse_excel_file(csv_path)
print(f"\nCSV解析结果{len(csv_blocks)}条):")
print(csv_blocks[:1])
except Exception as e:
print(f"错误: {str(e)}")

View File

@ -1,8 +1,9 @@
import uuid
import base64 import base64
from flask import jsonify import uuid
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 from Cryptodome.Cipher import PKCS1_v1_5
from Cryptodome.PublicKey import RSA
from flask import jsonify
from werkzeug.security import generate_password_hash from werkzeug.security import generate_password_hash
@ -10,6 +11,7 @@ from werkzeug.security import generate_password_hash
def generate_uuid(): def generate_uuid():
return str(uuid.uuid4()).replace("-", "") return str(uuid.uuid4()).replace("-", "")
# RSA 加密密码 # RSA 加密密码
def rsa_psw(password: str) -> str: def rsa_psw(password: str) -> str:
pub_key = """-----BEGIN PUBLIC KEY----- pub_key = """-----BEGIN PUBLIC KEY-----
@ -21,26 +23,22 @@ def rsa_psw(password: str) -> str:
encrypted_data = cipher.encrypt(base64.b64encode(password.encode())) encrypted_data = cipher.encrypt(base64.b64encode(password.encode()))
return base64.b64encode(encrypted_data).decode() return base64.b64encode(encrypted_data).decode()
# 加密密码 # 加密密码
def encrypt_password(raw_password: str) -> str: def encrypt_password(raw_password: str) -> str:
base64_password = base64.b64encode(raw_password.encode()).decode() base64_password = base64.b64encode(raw_password.encode()).decode()
return generate_password_hash(base64_password) return generate_password_hash(base64_password)
# 标准响应格式 # 标准响应格式
def success_response(data=None, message="操作成功", code=0): def success_response(data=None, message="操作成功", code=0):
return jsonify({ return jsonify({"code": code, "message": message, "data": data})
"code": code,
"message": message,
"data": data
})
# 错误响应格式 # 错误响应格式
def error_response(message="操作失败", code=500, details=None): def error_response(message="操作失败", code=500, details=None):
"""标准错误响应格式""" """标准错误响应格式"""
response = { response = {"code": code, "message": message}
"code": code,
"message": message
}
if details: if details:
response["details"] = details response["details"] = details
return jsonify(response), code if code >= 400 else 500 return jsonify(response), code if code >= 400 else 500