From 9689a2efd79d8039efc32ee118ecbce35bd8e35e Mon Sep 17 00:00:00 2001 From: zstar <65890619+zstar1003@users.noreply.github.com> Date: Sat, 12 Apr 2025 00:42:19 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E5=8A=9F=E8=83=BD=E5=B9=B6=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E7=AE=A1=E7=90=86=E6=9C=8D=E5=8A=A1=20(#21)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在管理界面添加文件上传功能,支持多文件上传 - 实现文件上传到服务器的核心逻辑,包括文件类型检查、文件名处理、文件存储等 - 完善文件管理服务,包括文件、文档及其关联关系的数据库操作 - 添加文件类型枚举和工具函数,支持多种文件格式 - 更新前端页面,添加文件上传对话框和上传按钮 --- management/server/check_tables.py | 55 +++++ management/server/routes/files/routes.py | 28 ++- .../server/services/files/base_service.py | 25 +++ .../server/services/files/document_service.py | 53 +++++ .../services/files/file2document_service.py | 21 ++ .../server/services/files/file_service.py | 47 ++++ .../server/services/files/models/__init__.py | 78 +++++++ management/server/services/files/service.py | 208 +++++++++++++++++- management/server/services/files/utils.py | 28 +++ management/web/src/common/apis/files/index.ts | 23 ++ management/web/src/pages/file/index.vue | 97 +++++++- management/web/types/auto/components.d.ts | 1 + 12 files changed, 657 insertions(+), 7 deletions(-) create mode 100644 management/server/check_tables.py create mode 100644 management/server/services/files/base_service.py create mode 100644 management/server/services/files/document_service.py create mode 100644 management/server/services/files/file2document_service.py create mode 100644 management/server/services/files/file_service.py create mode 100644 management/server/services/files/models/__init__.py create mode 100644 management/server/services/files/utils.py diff --git a/management/server/check_tables.py b/management/server/check_tables.py new file mode 100644 index 0000000..a56ea61 --- /dev/null +++ b/management/server/check_tables.py @@ -0,0 +1,55 @@ +import os +import mysql.connector +from dotenv import load_dotenv + +# 加载环境变量 +load_dotenv("../../docker/.env") + +# 数据库连接配置 +DB_CONFIG = { + "host": "localhost", + "port": int(os.getenv("MYSQL_PORT", "5455")), + "user": "root", + "password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"), + "database": "rag_flow" +} + +def get_db_connection(): + """创建数据库连接""" + return mysql.connector.connect(**DB_CONFIG) + +def get_all_tables(): + """获取数据库中所有表的名称""" + try: + # 连接数据库 + conn = get_db_connection() + cursor = conn.cursor() + + # 查询所有表名 + cursor.execute("SHOW TABLES") + tables = cursor.fetchall() + + print(f"数据库 {DB_CONFIG['database']} 中的表:") + if tables: + for i, table in enumerate(tables, 1): + print(f"{i}. {table[0]}") + else: + print("数据库中没有表") + + # 检查是否存在特定表 + important_tables = ['document', 'file', 'file2document'] + print("\n检查重要表是否存在:") + for table in important_tables: + cursor.execute(f"SHOW TABLES LIKE '{table}'") + exists = cursor.fetchone() is not None + status = "✓ 存在" if exists else "✗ 不存在" + print(f"{table}: {status}") + + cursor.close() + conn.close() + + except mysql.connector.Error as e: + print(f"数据库连接或查询出错: {e}") + +if __name__ == "__main__": + get_all_tables() \ No newline at end of file diff --git a/management/server/routes/files/routes.py b/management/server/routes/files/routes.py index ba4bc54..5023870 100644 --- a/management/server/routes/files/routes.py +++ b/management/server/routes/files/routes.py @@ -1,15 +1,39 @@ +import os from flask import jsonify, request, send_file, current_app from io import BytesIO from .. import files_bp +from flask import request, jsonify +from werkzeug.utils import secure_filename + + from services.files.service import ( get_files_list, get_file_info, download_file_from_minio, delete_file, batch_delete_files, - get_minio_client + get_minio_client, + upload_files_to_server ) +UPLOAD_FOLDER = '/data/uploads' +ALLOWED_EXTENSIONS = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif', 'doc', 'docx', 'xls', 'xlsx'} + +def allowed_file(filename): + return '.' in filename and \ + filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + +@files_bp.route('/upload', methods=['POST']) +def upload_file(): + if 'files' not in request.files: + return jsonify({'code': 400, 'message': '未选择文件'}), 400 + + files = request.files.getlist('files') + upload_result = upload_files_to_server(files) + + return jsonify(upload_result) + + @files_bp.route('', methods=['GET', 'OPTIONS']) def get_files(): """获取文件列表的API端点""" @@ -120,7 +144,7 @@ def download_file(file_id): "message": "文件下载失败", "details": str(e) }), 500 - + @files_bp.route('/', methods=['DELETE', 'OPTIONS']) def delete_file_route(file_id): """删除文件的API端点""" diff --git a/management/server/services/files/base_service.py b/management/server/services/files/base_service.py new file mode 100644 index 0000000..1ad0942 --- /dev/null +++ b/management/server/services/files/base_service.py @@ -0,0 +1,25 @@ +from peewee import Model +from typing import Type, TypeVar, Dict, Any + +T = TypeVar('T', bound=Model) + +class BaseService: + model: Type[T] + + @classmethod + def get_by_id(cls, id: str) -> T: + return cls.model.get_by_id(id) + + @classmethod + def insert(cls, data: Dict[str, Any]) -> T: + return cls.model.create(**data) + + @classmethod + def delete_by_id(cls, id: str) -> int: + return cls.model.delete().where(cls.model.id == id).execute() + + @classmethod + def query(cls, **kwargs) -> list[T]: + return list(cls.model.select().where(*[ + getattr(cls.model, k) == v for k, v in kwargs.items() + ])) \ No newline at end of file diff --git a/management/server/services/files/document_service.py b/management/server/services/files/document_service.py new file mode 100644 index 0000000..721f018 --- /dev/null +++ b/management/server/services/files/document_service.py @@ -0,0 +1,53 @@ +from peewee import * +from .base_service import BaseService +from .models import Document +from .utils import get_uuid, StatusEnum + +class DocumentService(BaseService): + model = Document + + @classmethod + def create_document(cls, kb_id: str, name: str, location: str, size: int, file_type: str, created_by: str = None, parser_id: str = None, parser_config: dict = None) -> Document: + """ + 创建文档记录 + + Args: + kb_id: 知识库ID + name: 文件名 + location: 存储位置 + size: 文件大小 + file_type: 文件类型 + created_by: 创建者ID + parser_id: 解析器ID + parser_config: 解析器配置 + + Returns: + Document: 创建的文档对象 + """ + doc_id = get_uuid() + + # 构建基本文档数据 + doc_data = { + 'id': doc_id, + 'kb_id': kb_id, + 'name': name, + 'location': location, + 'size': size, + 'type': file_type, + 'created_by': created_by or 'system', + 'parser_id': parser_id or '', + 'parser_config': parser_config or {"pages": [[1, 1000000]]}, + 'source_type': 'local', + 'token_num': 0, + 'chunk_num': 0, + 'progress': 0, + 'progress_msg': '', + 'run': '0', # 未开始解析 + 'status': StatusEnum.VALID.value + } + + return cls.insert(doc_data) + + @classmethod + def get_by_kb_id(cls, kb_id: str) -> list[Document]: + return cls.query(kb_id=kb_id) \ No newline at end of file diff --git a/management/server/services/files/file2document_service.py b/management/server/services/files/file2document_service.py new file mode 100644 index 0000000..f78d277 --- /dev/null +++ b/management/server/services/files/file2document_service.py @@ -0,0 +1,21 @@ +from peewee import * +from .base_service import BaseService +from .models import File2Document + +class File2DocumentService(BaseService): + model = File2Document + + @classmethod + def create_mapping(cls, file_id: str, document_id: str) -> File2Document: + return cls.insert({ + 'file_id': file_id, + 'document_id': document_id + }) + + @classmethod + def get_by_document_id(cls, document_id: str) -> list[File2Document]: + return cls.query(document_id=document_id) + + @classmethod + def get_by_file_id(cls, file_id: str) -> list[File2Document]: + return cls.query(file_id=file_id) \ No newline at end of file diff --git a/management/server/services/files/file_service.py b/management/server/services/files/file_service.py new file mode 100644 index 0000000..3b27e47 --- /dev/null +++ b/management/server/services/files/file_service.py @@ -0,0 +1,47 @@ +from peewee import * +from .base_service import BaseService +from .models import File +from .utils import FileType, get_uuid + +class FileService(BaseService): + model = File + + @classmethod + def create_file(cls, parent_id: str, name: str, location: str, size: int, file_type: str) -> File: + return cls.insert({ + 'parent_id': parent_id, + 'name': name, + 'location': location, + 'size': size, + 'type': file_type, + 'source_type': 'knowledgebase' + }) + + @classmethod + @classmethod + def get_parser(cls, file_type, filename, tenant_id): + """获取适合文件类型的解析器ID""" + # 这里可能需要根据实际情况调整 + if file_type == FileType.PDF.value: + return "pdf_parser" + elif file_type == FileType.WORD.value: + return "word_parser" + elif file_type == FileType.EXCEL.value: + return "excel_parser" + elif file_type == FileType.PPT.value: + return "ppt_parser" + elif file_type == FileType.VISUAL.value: + return "image_parser" + elif file_type == FileType.TEXT.value: # 添加对文本文件的支持 + return "text_parser" + else: + return "default_parser" + + @classmethod + def get_by_parent_id(cls, parent_id: str) -> list[File]: + return cls.query(parent_id=parent_id) + + @classmethod + def generate_bucket_name(cls): + """生成随机存储桶名称""" + return f"kb-{get_uuid()}" \ No newline at end of file diff --git a/management/server/services/files/models/__init__.py b/management/server/services/files/models/__init__.py new file mode 100644 index 0000000..527a924 --- /dev/null +++ b/management/server/services/files/models/__init__.py @@ -0,0 +1,78 @@ +from peewee import * +import os +from datetime import datetime + +# 数据库连接配置 +DB_CONFIG = { + "host": "localhost", + "port": int(os.getenv("MYSQL_PORT", "5455")), + "user": "root", + "password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"), + "database": "rag_flow" +} + +# 使用MySQL数据库 +db = MySQLDatabase( + DB_CONFIG["database"], + host=DB_CONFIG["host"], + port=DB_CONFIG["port"], + user=DB_CONFIG["user"], + password=DB_CONFIG["password"] +) + +class BaseModel(Model): + # 添加共同的时间戳字段 + create_time = BigIntegerField(null=True) + create_date = CharField(null=True) + update_time = BigIntegerField(null=True) + update_date = CharField(null=True) + + class Meta: + database = db + +class Document(BaseModel): + id = CharField(primary_key=True) + thumbnail = TextField(null=True) + kb_id = CharField(index=True) + parser_id = CharField(null=True, index=True) + parser_config = TextField(null=True) # JSONField在SQLite中用TextField替代 + source_type = CharField(default="local", index=True) + type = CharField(index=True) + created_by = CharField(null=True, index=True) + name = CharField(null=True, index=True) + location = CharField(null=True) + size = IntegerField(default=0) + token_num = IntegerField(default=0) + chunk_num = IntegerField(default=0) + progress = FloatField(default=0) + progress_msg = TextField(null=True, default="") + process_begin_at = DateTimeField(null=True) + process_duation = FloatField(default=0) + meta_fields = TextField(null=True) # JSONField + run = CharField(default="0") + status = CharField(default="1") + + class Meta: + db_table = "document" + +class File(BaseModel): + id = CharField(primary_key=True) + parent_id = CharField(index=True) + tenant_id = CharField(null=True, index=True) + created_by = CharField(null=True, index=True) + name = CharField(index=True) + location = CharField(null=True) + size = IntegerField(default=0) + type = CharField(index=True) + source_type = CharField(default="", index=True) + + class Meta: + db_table = "file" + +class File2Document(BaseModel): + id = CharField(primary_key=True) + file_id = CharField(index=True) + document_id = CharField(index=True) + + class Meta: + db_table = "file2document" \ No newline at end of file diff --git a/management/server/services/files/service.py b/management/server/services/files/service.py index 17de3c8..062e8a0 100644 --- a/management/server/services/files/service.py +++ b/management/server/services/files/service.py @@ -1,15 +1,26 @@ import os import mysql.connector +import re from io import BytesIO from minio import Minio from dotenv import load_dotenv +from werkzeug.utils import secure_filename +from datetime import datetime +from .utils import FileType, FileSource, StatusEnum, get_uuid +from .document_service import DocumentService +from .file_service import FileService +from .file2document_service import File2DocumentService + # 加载环境变量 load_dotenv("../../docker/.env") +UPLOAD_FOLDER = '/data/uploads' +ALLOWED_EXTENSIONS = {'pdf', 'doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx', 'jpg', 'jpeg', 'png', 'txt', 'md'} + # 数据库连接配置 DB_CONFIG = { - "host": "localhost", # 如果在Docker外运行,使用localhost + "host": "localhost", "port": int(os.getenv("MYSQL_PORT", "5455")), "user": "root", "password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"), @@ -24,6 +35,31 @@ MINIO_CONFIG = { "secure": False } + +def allowed_file(filename): + """Check if the file extension is allowed""" + return '.' in filename and \ + filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + +def filename_type(filename): + """根据文件名确定文件类型""" + ext = os.path.splitext(filename)[1].lower() + + if ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']: + return FileType.VISUAL.value + elif ext in ['.pdf']: + return FileType.PDF.value + elif ext in ['.doc', '.docx']: + return FileType.WORD.value + elif ext in ['.xls', '.xlsx']: + return FileType.EXCEL.value + elif ext in ['.ppt', '.pptx']: + return FileType.PPT.value + elif ext in ['.txt', '.md']: # 添加对 txt 和 md 文件的支持 + return FileType.TEXT.value + + return FileType.OTHER.value + def get_minio_client(): """创建MinIO客户端""" return Minio( @@ -397,4 +433,172 @@ def batch_delete_files(file_ids): conn.close() except Exception as e: - raise e \ No newline at end of file + raise e + +def upload_files_to_server(files, kb_id=None, user_id=None): + """处理文件上传到服务器的核心逻辑""" + results = [] + + for file in files: + if file.filename == '': + continue + + if file and allowed_file(file.filename): + # 为每个文件生成独立的存储桶名称 + file_bucket_id = FileService.generate_bucket_name() + original_filename = file.filename + # 修复文件名处理逻辑,保留中文字符 + name, ext = os.path.splitext(original_filename) + + # 保留中文字符,但替换不安全字符 + # 只替换文件系统不安全的字符,保留中文和其他Unicode字符 + safe_name = re.sub(r'[\\/:*?"<>|]', '_', name) + + # 如果处理后文件名为空,则使用随机字符串 + if not safe_name or safe_name.strip() == '': + safe_name = f"file_{get_uuid()[:8]}" + + filename = safe_name + ext.lower() + filepath = os.path.join(UPLOAD_FOLDER, filename) + + try: + # 1. 保存文件到本地临时目录 + os.makedirs(UPLOAD_FOLDER, exist_ok=True) + file.save(filepath) + print(f"文件已保存到临时目录: {filepath}") + print(f"原始文件名: {original_filename}, 处理后文件名: {filename}, 扩展名: {ext[1:]}") # 修改打印信息 + + # 2. 获取文件类型 - 使用修复后的文件名 + filetype = filename_type(filename) + if filetype == FileType.OTHER.value: + raise RuntimeError("不支持的文件类型") + + # 3. 生成唯一存储位置 + minio_client = get_minio_client() + location = filename + + # 确保bucket存在(使用文件独立的bucket) + if not minio_client.bucket_exists(file_bucket_id): + minio_client.make_bucket(file_bucket_id) + print(f"创建MinIO存储桶: {file_bucket_id}") + + # 4. 上传到MinIO(使用文件独立的bucket) + with open(filepath, 'rb') as file_data: + minio_client.put_object( + bucket_name=file_bucket_id, + object_name=location, + data=file_data, + length=os.path.getsize(filepath) + ) + print(f"文件已上传到MinIO: {file_bucket_id}/{location}") + + # 5. 创建缩略图(如果是图片/PDF等) + thumbnail_location = '' + if filetype in [FileType.VISUAL.value, FileType.PDF.value]: + try: + thumbnail_location = f'thumbnail_{get_uuid()}.png' + except Exception as e: + print(f"生成缩略图失败: {str(e)}") + + # 6. 创建数据库记录 + doc_id = get_uuid() + current_time = int(datetime.now().timestamp()) + current_date = datetime.now().strftime('%Y-%m-%d') + + doc = { + "id": doc_id, + "kb_id": file_bucket_id, # 使用文件独立的bucket_id + "parser_id": FileService.get_parser(filetype, filename, ""), + "parser_config": {"pages": [[1, 1000000]]}, + "source_type": "local", + "created_by": user_id or 'system', + "type": filetype, + "name": filename, + "location": location, + "size": os.path.getsize(filepath), + "thumbnail": thumbnail_location, + "token_num": 0, + "chunk_num": 0, + "progress": 0, + "progress_msg": "", + "run": "0", + "status": StatusEnum.VALID.value, + "create_time": current_time, + "create_date": current_date, + "update_time": current_time, + "update_date": current_date + } + + # 7. 保存文档记录 (添加事务处理) + conn = get_db_connection() + try: + cursor = conn.cursor() + DocumentService.insert(doc) + print(f"文档记录已保存到MySQL: {doc_id}") + + # 8. 创建文件记录和关联 + file_record = { + "id": get_uuid(), + "parent_id": file_bucket_id, # 使用文件独立的bucket_id + "tenant_id": user_id or 'system', + "created_by": user_id or 'system', + "name": filename, + "type": filetype, + "size": doc["size"], + "location": location, + "source_type": FileSource.KNOWLEDGEBASE.value, + "create_time": current_time, + "create_date": current_date, + "update_time": current_time, + "update_date": current_date + } + FileService.insert(file_record) + print(f"文件记录已保存到MySQL: {file_record['id']}") + + # 9. 创建文件-文档关联 + File2DocumentService.insert({ + "id": get_uuid(), + "file_id": file_record["id"], + "document_id": doc_id, + "create_time": current_time, + "create_date": current_date, + "update_time": current_time, + "update_date": current_date + }) + print(f"关联记录已保存到MySQL: {file_record['id']} -> {doc_id}") + + conn.commit() + + results.append({ + 'id': doc_id, + 'name': filename, + 'size': doc["size"], + 'type': filetype, + 'status': 'success' + }) + + except Exception as e: + conn.rollback() + print(f"数据库操作失败: {str(e)}") + raise + finally: + cursor.close() + conn.close() + + except Exception as e: + results.append({ + 'name': filename, + 'error': str(e), + 'status': 'failed' + }) + print(f"文件上传过程中出错: {filename}, 错误: {str(e)}") + finally: + # 删除临时文件 + if os.path.exists(filepath): + os.remove(filepath) + + return { + 'code': 0, + 'data': results, + 'message': f'成功上传 {len([r for r in results if r["status"] == "success"])}/{len(files)} 个文件' + } \ No newline at end of file diff --git a/management/server/services/files/utils.py b/management/server/services/files/utils.py new file mode 100644 index 0000000..aea4d1c --- /dev/null +++ b/management/server/services/files/utils.py @@ -0,0 +1,28 @@ +import uuid +from strenum import StrEnum +from enum import Enum + + +# 参考:api.db +class FileType(StrEnum): + FOLDER = "folder" + PDF = "pdf" + WORD = "word" + EXCEL = "excel" + PPT = "ppt" + VISUAL = "visual" + TEXT = "txt" + OTHER = "other" + +class FileSource(StrEnum): + LOCAL = "" + KNOWLEDGEBASE = "knowledgebase" + S3 = "s3" + +class StatusEnum(Enum): + VALID = "1" + INVALID = "0" + +# 参考:api.utils +def get_uuid(): + return uuid.uuid1().hex \ No newline at end of file diff --git a/management/web/src/common/apis/files/index.ts b/management/web/src/common/apis/files/index.ts index 0f9f8e7..8910213 100644 --- a/management/web/src/common/apis/files/index.ts +++ b/management/web/src/common/apis/files/index.ts @@ -91,3 +91,26 @@ export function batchDeleteFilesApi(fileIds: string[]) { data: { ids: fileIds } }) } + +/** + * 上传文件 + */ +export function uploadFileApi(formData: FormData) { + return request<{ + code: number + data: Array<{ + name: string + size: number + type: string + status: string + }> + message: string + }>({ + url: "/api/v1/files/upload", + method: "post", + data: formData, + headers: { + "Content-Type": "multipart/form-data" + } + }) +} diff --git a/management/web/src/pages/file/index.vue b/management/web/src/pages/file/index.vue index f9def3b..3c88ca2 100644 --- a/management/web/src/pages/file/index.vue +++ b/management/web/src/pages/file/index.vue @@ -1,9 +1,10 @@