feat: 增加支持对csv格式文件的上传和解析
This commit is contained in:
parent
dfb7867561
commit
0b1126b1c8
|
@ -1,11 +1,11 @@
|
||||||
from flask import jsonify, request, send_file, current_app
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from .. import files_bp
|
|
||||||
|
|
||||||
|
from flask import current_app, jsonify, request, send_file
|
||||||
from services.files.service import get_files_list, get_file_info, download_file_from_minio, delete_file, batch_delete_files, handle_chunk_upload, merge_chunks, upload_files_to_server
|
from services.files.service import batch_delete_files, delete_file, download_file_from_minio, get_file_info, get_files_list, handle_chunk_upload, merge_chunks, upload_files_to_server
|
||||||
from services.files.utils import FileType
|
from services.files.utils import FileType
|
||||||
|
|
||||||
|
from .. import files_bp
|
||||||
|
|
||||||
UPLOAD_FOLDER = "/data/uploads"
|
UPLOAD_FOLDER = "/data/uploads"
|
||||||
ALLOWED_EXTENSIONS = {"txt", "pdf", "png", "jpg", "jpeg", "gif", "doc", "docx", "xls", "xlsx"}
|
ALLOWED_EXTENSIONS = {"txt", "pdf", "png", "jpg", "jpeg", "gif", "doc", "docx", "xls", "xlsx"}
|
||||||
|
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
from peewee import *
|
from peewee import * # noqa: F403
|
||||||
|
|
||||||
from .base_service import BaseService
|
from .base_service import BaseService
|
||||||
from .models import Document
|
from .models import Document
|
||||||
from .utils import get_uuid, StatusEnum
|
from .utils import StatusEnum, get_uuid
|
||||||
|
|
||||||
|
|
||||||
class DocumentService(BaseService):
|
class DocumentService(BaseService):
|
||||||
model = Document
|
model = Document
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_document(cls, kb_id: str, name: str, location: str, size: int, file_type: str, created_by: str = None, parser_id: str = None, parser_config: dict = None) -> Document:
|
def create_document(cls, kb_id: str, name: str, location: str, size: int, file_type: str, created_by: str = None, parser_id: str = None, parser_config: dict = None) -> Document:
|
||||||
"""
|
"""
|
||||||
创建文档记录
|
创建文档记录
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
kb_id: 知识库ID
|
kb_id: 知识库ID
|
||||||
name: 文件名
|
name: 文件名
|
||||||
|
@ -20,34 +22,34 @@ class DocumentService(BaseService):
|
||||||
created_by: 创建者ID
|
created_by: 创建者ID
|
||||||
parser_id: 解析器ID
|
parser_id: 解析器ID
|
||||||
parser_config: 解析器配置
|
parser_config: 解析器配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Document: 创建的文档对象
|
Document: 创建的文档对象
|
||||||
"""
|
"""
|
||||||
doc_id = get_uuid()
|
doc_id = get_uuid()
|
||||||
|
|
||||||
# 构建基本文档数据
|
# 构建基本文档数据
|
||||||
doc_data = {
|
doc_data = {
|
||||||
'id': doc_id,
|
"id": doc_id,
|
||||||
'kb_id': kb_id,
|
"kb_id": kb_id,
|
||||||
'name': name,
|
"name": name,
|
||||||
'location': location,
|
"location": location,
|
||||||
'size': size,
|
"size": size,
|
||||||
'type': file_type,
|
"type": file_type,
|
||||||
'created_by': created_by or 'system',
|
"created_by": created_by or "system",
|
||||||
'parser_id': parser_id or '',
|
"parser_id": parser_id or "",
|
||||||
'parser_config': parser_config or {"pages": [[1, 1000000]]},
|
"parser_config": parser_config or {"pages": [[1, 1000000]]},
|
||||||
'source_type': 'local',
|
"source_type": "local",
|
||||||
'token_num': 0,
|
"token_num": 0,
|
||||||
'chunk_num': 0,
|
"chunk_num": 0,
|
||||||
'progress': 0,
|
"progress": 0,
|
||||||
'progress_msg': '',
|
"progress_msg": "",
|
||||||
'run': '0', # 未开始解析
|
"run": "0", # 未开始解析
|
||||||
'status': StatusEnum.VALID.value
|
"status": StatusEnum.VALID.value,
|
||||||
}
|
}
|
||||||
|
|
||||||
return cls.insert(doc_data)
|
return cls.insert(doc_data)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_by_kb_id(cls, kb_id: str) -> list[Document]:
|
def get_by_kb_id(cls, kb_id: str) -> list[Document]:
|
||||||
return cls.query(kb_id=kb_id)
|
return cls.query(kb_id=kb_id)
|
||||||
|
|
|
@ -1,26 +1,20 @@
|
||||||
from peewee import * # noqa: F403
|
from peewee import * # noqa: F403
|
||||||
|
|
||||||
from .base_service import BaseService
|
from .base_service import BaseService
|
||||||
from .models import File
|
from .models import File
|
||||||
from .utils import FileType, get_uuid
|
from .utils import FileType, get_uuid
|
||||||
|
|
||||||
|
|
||||||
class FileService(BaseService):
|
class FileService(BaseService):
|
||||||
model = File
|
model = File
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_file(cls, parent_id: str, name: str, location: str, size: int, file_type: str) -> File:
|
def create_file(cls, parent_id: str, name: str, location: str, size: int, file_type: str) -> File:
|
||||||
return cls.insert({
|
return cls.insert({"parent_id": parent_id, "name": name, "location": location, "size": size, "type": file_type, "source_type": "knowledgebase"})
|
||||||
'parent_id': parent_id,
|
|
||||||
'name': name,
|
|
||||||
'location': location,
|
|
||||||
'size': size,
|
|
||||||
'type': file_type,
|
|
||||||
'source_type': 'knowledgebase'
|
|
||||||
})
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_parser(cls, file_type, filename, tenant_id):
|
def get_parser(cls, file_type, filename, tenant_id):
|
||||||
"""获取适合文件类型的解析器ID"""
|
"""获取适合文件类型的解析器ID"""
|
||||||
# 这里可能需要根据实际情况调整
|
|
||||||
if file_type == FileType.PDF.value:
|
if file_type == FileType.PDF.value:
|
||||||
return "pdf_parser"
|
return "pdf_parser"
|
||||||
elif file_type == FileType.WORD.value:
|
elif file_type == FileType.WORD.value:
|
||||||
|
@ -40,7 +34,7 @@ class FileService(BaseService):
|
||||||
def get_by_parent_id(cls, parent_id: str) -> list[File]:
|
def get_by_parent_id(cls, parent_id: str) -> list[File]:
|
||||||
return cls.query(parent_id=parent_id)
|
return cls.query(parent_id=parent_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_bucket_name(cls):
|
def generate_bucket_name(cls):
|
||||||
"""生成随机存储桶名称"""
|
"""生成随机存储桶名称"""
|
||||||
return f"kb-{get_uuid()}"
|
return f"kb-{get_uuid()}"
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from dotenv import load_dotenv
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from database import get_db_connection, get_minio_client, get_redis_connection
|
|
||||||
from .utils import FileType, FileSource, get_uuid
|
|
||||||
|
|
||||||
|
from database import get_db_connection, get_minio_client, get_redis_connection
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from .utils import FileSource, FileType, get_uuid
|
||||||
|
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
load_dotenv("../../docker/.env")
|
load_dotenv("../../docker/.env")
|
||||||
|
@ -18,7 +19,7 @@ CHUNK_EXPIRY_SECONDS = 3600 * 24 # 分块24小时过期
|
||||||
|
|
||||||
temp_dir = tempfile.gettempdir()
|
temp_dir = tempfile.gettempdir()
|
||||||
UPLOAD_FOLDER = os.path.join(temp_dir, "uploads")
|
UPLOAD_FOLDER = os.path.join(temp_dir, "uploads")
|
||||||
ALLOWED_EXTENSIONS = {"pdf", "doc", "docx", "ppt", "pptx", "xls", "xlsx", "jpg", "jpeg", "png", "bmp", "txt", "md", "html"}
|
ALLOWED_EXTENSIONS = {"pdf", "doc", "docx", "ppt", "pptx", "xls", "xlsx", "jpg", "jpeg", "png", "bmp", "txt", "md", "html", "csv"}
|
||||||
|
|
||||||
|
|
||||||
def allowed_file(filename):
|
def allowed_file(filename):
|
||||||
|
@ -36,7 +37,7 @@ def filename_type(filename):
|
||||||
return FileType.PDF.value
|
return FileType.PDF.value
|
||||||
elif ext in [".doc", ".docx"]:
|
elif ext in [".doc", ".docx"]:
|
||||||
return FileType.WORD.value
|
return FileType.WORD.value
|
||||||
elif ext in [".xls", ".xlsx"]:
|
elif ext in [".xls", ".xlsx", ".csv"]:
|
||||||
return FileType.EXCEL.value
|
return FileType.EXCEL.value
|
||||||
elif ext in [".ppt", ".pptx"]:
|
elif ext in [".ppt", ".pptx"]:
|
||||||
return FileType.PPT.value
|
return FileType.PPT.value
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
import uuid
|
import uuid
|
||||||
from strenum import StrEnum
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
from strenum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
# 参考:api.db
|
|
||||||
class FileType(StrEnum):
|
class FileType(StrEnum):
|
||||||
FOLDER = "folder"
|
FOLDER = "folder"
|
||||||
PDF = "pdf"
|
PDF = "pdf"
|
||||||
|
@ -15,15 +15,18 @@ class FileType(StrEnum):
|
||||||
HTML = "html"
|
HTML = "html"
|
||||||
OTHER = "other"
|
OTHER = "other"
|
||||||
|
|
||||||
|
|
||||||
class FileSource(StrEnum):
|
class FileSource(StrEnum):
|
||||||
LOCAL = ""
|
LOCAL = ""
|
||||||
KNOWLEDGEBASE = "knowledgebase"
|
KNOWLEDGEBASE = "knowledgebase"
|
||||||
S3 = "s3"
|
S3 = "s3"
|
||||||
|
|
||||||
|
|
||||||
class StatusEnum(Enum):
|
class StatusEnum(Enum):
|
||||||
VALID = "1"
|
VALID = "1"
|
||||||
INVALID = "0"
|
INVALID = "0"
|
||||||
|
|
||||||
|
|
||||||
# 参考:api.utils
|
# 参考:api.utils
|
||||||
def get_uuid():
|
def get_uuid():
|
||||||
return uuid.uuid1().hex
|
return uuid.uuid1().hex
|
||||||
|
|
|
@ -20,7 +20,7 @@ from magic_pdf.data.read_api import read_local_images, read_local_office
|
||||||
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
||||||
|
|
||||||
from . import logger
|
from . import logger
|
||||||
from .excel_parser import parse_excel
|
from .excel_parser import parse_excel_file
|
||||||
from .rag_tokenizer import RagTokenizer
|
from .rag_tokenizer import RagTokenizer
|
||||||
from .utils import _create_task_record, _update_document_progress, _update_kb_chunk_count, generate_uuid, get_bbox_from_block
|
from .utils import _create_task_record, _update_document_progress, _update_kb_chunk_count, generate_uuid, get_bbox_from_block
|
||||||
|
|
||||||
|
@ -196,7 +196,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
|
||||||
|
|
||||||
update_progress(0.8, "提取内容")
|
update_progress(0.8, "提取内容")
|
||||||
# 处理内容列表
|
# 处理内容列表
|
||||||
content_list = parse_excel(temp_file_path)
|
content_list = parse_excel_file(temp_file_path)
|
||||||
|
|
||||||
elif file_type.endswith("visual"):
|
elif file_type.endswith("visual"):
|
||||||
update_progress(0.3, "使用MinerU解析器")
|
update_progress(0.3, "使用MinerU解析器")
|
||||||
|
|
|
@ -1,24 +1,69 @@
|
||||||
|
import os
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
def parse_excel(file_path):
|
def parse_excel_file(file_path):
|
||||||
all_sheets = pd.read_excel(file_path, sheet_name=None)
|
"""
|
||||||
|
通用表格解析函数,支持 Excel (.xlsx/.xls) 和 CSV 文件
|
||||||
|
返回统一格式的数据块列表
|
||||||
|
"""
|
||||||
blocks = []
|
blocks = []
|
||||||
|
|
||||||
for sheet_name, df in all_sheets.items():
|
# 根据文件扩展名选择读取方式
|
||||||
df = df.ffill()
|
file_ext = os.path.splitext(file_path)[1].lower()
|
||||||
headers = df.columns.tolist()
|
|
||||||
|
|
||||||
for _, row in df.iterrows():
|
try:
|
||||||
html_table = "<html><body><table><tr>{}</tr><tr>{}</tr></table></body></html>".format("".join(f"<td>{col}</td>" for col in headers), "".join(f"<td>{row[col]}</td>" for col in headers))
|
if file_ext in (".xlsx", ".xls"):
|
||||||
block = {"type": "table", "img_path": "", "table_caption": [f"Sheet: {sheet_name}"], "table_footnote": [], "table_body": f"{html_table}", "page_idx": 0}
|
# 处理Excel文件(多sheet)
|
||||||
blocks.append(block)
|
all_sheets = pd.read_excel(file_path, sheet_name=None)
|
||||||
|
for sheet_name, df in all_sheets.items():
|
||||||
|
blocks.extend(_process_dataframe(df, sheet_name))
|
||||||
|
|
||||||
|
elif file_ext == ".csv":
|
||||||
|
# 处理CSV文件(单sheet)
|
||||||
|
df = pd.read_csv(file_path)
|
||||||
|
blocks.extend(_process_dataframe(df, "CSV"))
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported file format: {file_ext}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to parse file {file_path}: {str(e)}")
|
||||||
|
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
|
def _process_dataframe(df, sheet_name):
|
||||||
|
"""处理单个DataFrame,生成统一格式的数据块"""
|
||||||
|
df = df.ffill()
|
||||||
|
headers = df.columns.tolist()
|
||||||
|
blocks = []
|
||||||
|
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
html_table = "<html><body><table><tr>{}</tr><tr>{}</tr></table></body></html>".format("".join(f"<td>{col}</td>" for col in headers), "".join(f"<td>{row[col]}</td>" for col in headers))
|
||||||
|
|
||||||
|
block = {"type": "table", "img_path": "", "table_caption": "", "table_footnote": [], "table_body": html_table, "page_idx": 0}
|
||||||
|
blocks.append(block)
|
||||||
|
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
file_path = "test_excel.xls"
|
# 测试示例
|
||||||
parse_excel_result = parse_excel(file_path)
|
excel_path = "test.xlsx"
|
||||||
print(parse_excel_result)
|
csv_path = "test.csv"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 测试Excel解析
|
||||||
|
excel_blocks = parse_excel_file(excel_path)
|
||||||
|
print(f"Excel解析结果(共{len(excel_blocks)}条):")
|
||||||
|
print(excel_blocks[:1]) # 打印第一条示例
|
||||||
|
|
||||||
|
# 测试CSV解析
|
||||||
|
csv_blocks = parse_excel_file(csv_path)
|
||||||
|
print(f"\nCSV解析结果(共{len(csv_blocks)}条):")
|
||||||
|
print(csv_blocks[:1])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"错误: {str(e)}")
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import uuid
|
|
||||||
import base64
|
import base64
|
||||||
from flask import jsonify
|
import uuid
|
||||||
from Cryptodome.PublicKey import RSA
|
|
||||||
from Cryptodome.Cipher import PKCS1_v1_5
|
from Cryptodome.Cipher import PKCS1_v1_5
|
||||||
|
from Cryptodome.PublicKey import RSA
|
||||||
|
from flask import jsonify
|
||||||
from werkzeug.security import generate_password_hash
|
from werkzeug.security import generate_password_hash
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,6 +11,7 @@ from werkzeug.security import generate_password_hash
|
||||||
def generate_uuid():
|
def generate_uuid():
|
||||||
return str(uuid.uuid4()).replace("-", "")
|
return str(uuid.uuid4()).replace("-", "")
|
||||||
|
|
||||||
|
|
||||||
# RSA 加密密码
|
# RSA 加密密码
|
||||||
def rsa_psw(password: str) -> str:
|
def rsa_psw(password: str) -> str:
|
||||||
pub_key = """-----BEGIN PUBLIC KEY-----
|
pub_key = """-----BEGIN PUBLIC KEY-----
|
||||||
|
@ -21,26 +23,22 @@ def rsa_psw(password: str) -> str:
|
||||||
encrypted_data = cipher.encrypt(base64.b64encode(password.encode()))
|
encrypted_data = cipher.encrypt(base64.b64encode(password.encode()))
|
||||||
return base64.b64encode(encrypted_data).decode()
|
return base64.b64encode(encrypted_data).decode()
|
||||||
|
|
||||||
|
|
||||||
# 加密密码
|
# 加密密码
|
||||||
def encrypt_password(raw_password: str) -> str:
|
def encrypt_password(raw_password: str) -> str:
|
||||||
base64_password = base64.b64encode(raw_password.encode()).decode()
|
base64_password = base64.b64encode(raw_password.encode()).decode()
|
||||||
return generate_password_hash(base64_password)
|
return generate_password_hash(base64_password)
|
||||||
|
|
||||||
|
|
||||||
# 标准响应格式
|
# 标准响应格式
|
||||||
def success_response(data=None, message="操作成功", code=0):
|
def success_response(data=None, message="操作成功", code=0):
|
||||||
return jsonify({
|
return jsonify({"code": code, "message": message, "data": data})
|
||||||
"code": code,
|
|
||||||
"message": message,
|
|
||||||
"data": data
|
|
||||||
})
|
|
||||||
|
|
||||||
# 错误响应格式
|
# 错误响应格式
|
||||||
def error_response(message="操作失败", code=500, details=None):
|
def error_response(message="操作失败", code=500, details=None):
|
||||||
"""标准错误响应格式"""
|
"""标准错误响应格式"""
|
||||||
response = {
|
response = {"code": code, "message": message}
|
||||||
"code": code,
|
|
||||||
"message": message
|
|
||||||
}
|
|
||||||
if details:
|
if details:
|
||||||
response["details"] = details
|
response["details"] = details
|
||||||
return jsonify(response), code if code >= 400 else 500
|
return jsonify(response), code if code >= 400 else 500
|
||||||
|
|
Loading…
Reference in New Issue