feat: 添加文件上传功能并完善文件管理服务 (#21)

- 在管理界面添加文件上传功能,支持多文件上传
- 实现文件上传到服务器的核心逻辑,包括文件类型检查、文件名处理、文件存储等
- 完善文件管理服务,包括文件、文档及其关联关系的数据库操作
- 添加文件类型枚举和工具函数,支持多种文件格式
- 更新前端页面,添加文件上传对话框和上传按钮
This commit is contained in:
zstar 2025-04-12 00:42:19 +08:00 committed by GitHub
parent 5d900c3883
commit 9689a2efd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 657 additions and 7 deletions

View File

@ -0,0 +1,55 @@
import os
import mysql.connector
from dotenv import load_dotenv
# 加载环境变量
load_dotenv("../../docker/.env")
# 数据库连接配置
DB_CONFIG = {
"host": "localhost",
"port": int(os.getenv("MYSQL_PORT", "5455")),
"user": "root",
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
"database": "rag_flow"
}
def get_db_connection():
    """Open and return a new MySQL connection built from ``DB_CONFIG``."""
    connection = mysql.connector.connect(**DB_CONFIG)
    return connection
def get_all_tables():
    """Print every table in the configured database and check the key tables.

    Runs ``SHOW TABLES`` and prints the numbered result, then reports
    whether each of ``document``, ``file`` and ``file2document`` exists.
    MySQL errors are printed instead of being raised. The cursor and the
    connection are released even when a query fails.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()

        # List all table names.
        cursor.execute("SHOW TABLES")
        tables = cursor.fetchall()
        print(f"数据库 {DB_CONFIG['database']} 中的表:")
        if tables:
            for i, table in enumerate(tables, 1):
                print(f"{i}. {table[0]}")
        else:
            print("数据库中没有表")

        # Check that the tables the file service relies on are present.
        important_tables = ['document', 'file', 'file2document']
        print("\n检查重要表是否存在:")
        for table in important_tables:
            # Bound parameter instead of string interpolation into the SQL.
            cursor.execute("SHOW TABLES LIKE %s", (table,))
            exists = cursor.fetchone() is not None
            status = "✓ 存在" if exists else "✗ 不存在"
            print(f"{table}: {status}")
    except mysql.connector.Error as e:
        print(f"数据库连接或查询出错: {e}")
    finally:
        # Either handle may be missing if connecting or cursor creation failed.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
if __name__ == "__main__":
get_all_tables()

View File

@ -1,15 +1,39 @@
import os
from flask import jsonify, request, send_file, current_app
from io import BytesIO
from .. import files_bp
from flask import request, jsonify
from werkzeug.utils import secure_filename
from services.files.service import (
get_files_list,
get_file_info,
download_file_from_minio,
delete_file,
batch_delete_files,
get_minio_client
get_minio_client,
upload_files_to_server
)
UPLOAD_FOLDER = '/data/uploads'
ALLOWED_EXTENSIONS = {
    'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif', 'doc', 'docx', 'xls', 'xlsx',
}


def allowed_file(filename):
    """Return True when ``filename`` has an extension in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot, compared
    case-insensitively. Names without a dot are rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
@files_bp.route('/upload', methods=['POST'])
def upload_file():
    """API endpoint that accepts one or more files under the ``files`` field.

    Responds with HTTP 400 when the request carries no ``files`` part;
    otherwise returns the result of ``upload_files_to_server`` as JSON.
    """
    if 'files' in request.files:
        uploaded = request.files.getlist('files')
        return jsonify(upload_files_to_server(uploaded))

    return jsonify({'code': 400, 'message': '未选择文件'}), 400
@files_bp.route('', methods=['GET', 'OPTIONS'])
def get_files():
"""获取文件列表的API端点"""
@ -120,7 +144,7 @@ def download_file(file_id):
"message": "文件下载失败",
"details": str(e)
}), 500
@files_bp.route('/<string:file_id>', methods=['DELETE', 'OPTIONS'])
def delete_file_route(file_id):
"""删除文件的API端点"""

View File

@ -0,0 +1,25 @@
from peewee import Model
from typing import Type, TypeVar, Dict, Any
T = TypeVar('T', bound=Model)
class BaseService:
    """Thin CRUD helper around a peewee model.

    Subclasses set ``model`` to the model class they operate on.
    """

    model: Type[T]

    @classmethod
    def get_by_id(cls, id: str) -> T:
        """Return the row whose primary key is ``id`` via ``model.get_by_id``."""
        return cls.model.get_by_id(id)

    @classmethod
    def insert(cls, data: Dict[str, Any]) -> T:
        """Create a row from ``data`` (field name -> value) and return it."""
        return cls.model.create(**data)

    @classmethod
    def delete_by_id(cls, id: str) -> int:
        """Delete rows whose ``id`` equals ``id``; return the delete query result."""
        return cls.model.delete().where(cls.model.id == id).execute()

    @classmethod
    def query(cls, **kwargs) -> list[T]:
        """Return rows matching every ``field=value`` pair in ``kwargs``.

        With no filters, all rows are returned. ``where()`` is only applied
        when there is at least one condition, because calling it with zero
        expressions fails inside peewee.
        """
        selection = cls.model.select()
        conditions = [getattr(cls.model, k) == v for k, v in kwargs.items()]
        if conditions:
            selection = selection.where(*conditions)
        return list(selection)

View File

@ -0,0 +1,53 @@
from peewee import *
from .base_service import BaseService
from .models import Document
from .utils import get_uuid, StatusEnum
class DocumentService(BaseService):
    """Service helpers for rows of the ``Document`` model."""

    model = Document

    @classmethod
    def create_document(cls, kb_id: str, name: str, location: str, size: int, file_type: str, created_by: str = None, parser_id: str = None, parser_config: dict = None) -> Document:
        """Create a document record.

        Args:
            kb_id: Knowledge base ID.
            name: File name.
            location: Storage location.
            size: File size.
            file_type: File type.
            created_by: Creator ID; ``'system'`` when falsy.
            parser_id: Parser ID; empty string when falsy.
            parser_config: Parser configuration; a default page range is
                used when falsy.

        Returns:
            Document: The created document object.
        """
        import json  # local import: only needed to serialise parser_config

        config = parser_config or {"pages": [[1, 1000000]]}
        if not isinstance(config, str):
            # The column is a text field, which would otherwise store the
            # dict's Python repr (single quotes) rather than valid JSON.
            config = json.dumps(config)

        doc_data = {
            'id': get_uuid(),
            'kb_id': kb_id,
            'name': name,
            'location': location,
            'size': size,
            'type': file_type,
            'created_by': created_by or 'system',
            'parser_id': parser_id or '',
            'parser_config': config,
            'source_type': 'local',
            'token_num': 0,
            'chunk_num': 0,
            'progress': 0,
            'progress_msg': '',
            'run': '0',  # parsing not started yet
            'status': StatusEnum.VALID.value,
        }
        return cls.insert(doc_data)

    @classmethod
    def get_by_kb_id(cls, kb_id: str) -> list[Document]:
        """Return all documents whose ``kb_id`` equals ``kb_id``."""
        return cls.query(kb_id=kb_id)

View File

@ -0,0 +1,21 @@
from peewee import *
from .base_service import BaseService
from .models import File2Document
class File2DocumentService(BaseService):
    """Service helpers for the file-to-document association rows."""

    model = File2Document

    @classmethod
    def create_mapping(cls, file_id: str, document_id: str) -> File2Document:
        """Create and return a row linking ``file_id`` to ``document_id``.

        A primary key is generated here because the model's ``id`` column
        is a plain character field with no default value.
        """
        import uuid  # local import; same scheme as the package's get_uuid helper

        return cls.insert({
            'id': uuid.uuid1().hex,
            'file_id': file_id,
            'document_id': document_id,
        })

    @classmethod
    def get_by_document_id(cls, document_id: str) -> list[File2Document]:
        """Return all mappings that reference ``document_id``."""
        return cls.query(document_id=document_id)

    @classmethod
    def get_by_file_id(cls, file_id: str) -> list[File2Document]:
        """Return all mappings that reference ``file_id``."""
        return cls.query(file_id=file_id)

View File

@ -0,0 +1,47 @@
from peewee import *
from .base_service import BaseService
from .models import File
from .utils import FileType, get_uuid
class FileService(BaseService):
    """Service helpers for rows of the ``File`` model."""

    model = File

    @classmethod
    def create_file(cls, parent_id: str, name: str, location: str, size: int, file_type: str) -> File:
        """Create and return a file row with source type ``knowledgebase``.

        A primary key is generated here because the model's ``id`` column
        is a plain character field with no default value.
        """
        return cls.insert({
            'id': get_uuid(),
            'parent_id': parent_id,
            'name': name,
            'location': location,
            'size': size,
            'type': file_type,
            'source_type': 'knowledgebase',
        })

    # A single @classmethod: stacking two only works on some Python versions.
    @classmethod
    def get_parser(cls, file_type, filename, tenant_id):
        """Return the parser ID that suits ``file_type``.

        ``filename`` and ``tenant_id`` are accepted but not used. Unknown
        types map to ``"default_parser"``.
        """
        # The parser names may need adjusting to the actual deployment.
        parser_by_type = {
            FileType.PDF.value: "pdf_parser",
            FileType.WORD.value: "word_parser",
            FileType.EXCEL.value: "excel_parser",
            FileType.PPT.value: "ppt_parser",
            FileType.VISUAL.value: "image_parser",
            FileType.TEXT.value: "text_parser",  # plain-text files
        }
        return parser_by_type.get(file_type, "default_parser")

    @classmethod
    def get_by_parent_id(cls, parent_id: str) -> list[File]:
        """Return all files whose ``parent_id`` equals ``parent_id``."""
        return cls.query(parent_id=parent_id)

    @classmethod
    def generate_bucket_name(cls):
        """Return a new storage bucket name of the form ``kb-<uuid hex>``."""
        return f"kb-{get_uuid()}"

View File

@ -0,0 +1,78 @@
from peewee import *
import os
from datetime import datetime
# 数据库连接配置
DB_CONFIG = {
"host": "localhost",
"port": int(os.getenv("MYSQL_PORT", "5455")),
"user": "root",
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
"database": "rag_flow"
}
# 使用MySQL数据库
db = MySQLDatabase(
DB_CONFIG["database"],
host=DB_CONFIG["host"],
port=DB_CONFIG["port"],
user=DB_CONFIG["user"],
password=DB_CONFIG["password"]
)
class BaseModel(Model):
    """Common base for the models below: shared timestamps and database."""

    # Shared timestamp columns; all of them are nullable.
    create_time = BigIntegerField(null=True)
    create_date = CharField(null=True)
    update_time = BigIntegerField(null=True)
    update_date = CharField(null=True)

    class Meta:
        # Bind every subclass to the module-level MySQL database.
        database = db
class Document(BaseModel):
    """Peewee model mapped to the ``document`` table."""

    id = CharField(primary_key=True)
    thumbnail = TextField(null=True)
    kb_id = CharField(index=True)
    parser_id = CharField(null=True, index=True)
    parser_config = TextField(null=True)  # JSON payload kept in a text column
    source_type = CharField(default="local", index=True)
    type = CharField(index=True)
    created_by = CharField(null=True, index=True)
    name = CharField(null=True, index=True)
    location = CharField(null=True)
    size = IntegerField(default=0)
    token_num = IntegerField(default=0)
    chunk_num = IntegerField(default=0)
    progress = FloatField(default=0)
    progress_msg = TextField(null=True, default="")
    process_begin_at = DateTimeField(null=True)
    # NOTE(review): spelling presumably matches the existing column name;
    # confirm against the schema before renaming.
    process_duation = FloatField(default=0)
    meta_fields = TextField(null=True)  # JSON payload kept in a text column
    run = CharField(default="0")
    status = CharField(default="1")

    class Meta:
        # ``table_name`` replaces the deprecated ``db_table`` option.
        table_name = "document"
class File(BaseModel):
    """Peewee model mapped to the ``file`` table."""

    id = CharField(primary_key=True)
    parent_id = CharField(index=True)
    tenant_id = CharField(null=True, index=True)
    created_by = CharField(null=True, index=True)
    name = CharField(index=True)
    location = CharField(null=True)
    size = IntegerField(default=0)
    type = CharField(index=True)
    source_type = CharField(default="", index=True)

    class Meta:
        # ``table_name`` replaces the deprecated ``db_table`` option.
        table_name = "file"
class File2Document(BaseModel):
    """Peewee model mapped to the ``file2document`` association table."""

    id = CharField(primary_key=True)
    file_id = CharField(index=True)
    document_id = CharField(index=True)

    class Meta:
        # ``table_name`` replaces the deprecated ``db_table`` option.
        table_name = "file2document"

View File

@ -1,15 +1,26 @@
import os
import mysql.connector
import re
from io import BytesIO
from minio import Minio
from dotenv import load_dotenv
from werkzeug.utils import secure_filename
from datetime import datetime
from .utils import FileType, FileSource, StatusEnum, get_uuid
from .document_service import DocumentService
from .file_service import FileService
from .file2document_service import File2DocumentService
# 加载环境变量
load_dotenv("../../docker/.env")
UPLOAD_FOLDER = '/data/uploads'
ALLOWED_EXTENSIONS = {'pdf', 'doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx', 'jpg', 'jpeg', 'png', 'txt', 'md'}
# 数据库连接配置
DB_CONFIG = {
"host": "localhost", # 如果在Docker外运行使用localhost
"host": "localhost",
"port": int(os.getenv("MYSQL_PORT", "5455")),
"user": "root",
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
@ -24,6 +35,31 @@ MINIO_CONFIG = {
"secure": False
}
def allowed_file(filename):
    """Return True when the extension of ``filename`` is in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot, compared
    case-insensitively. Names without a dot are rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
def filename_type(filename):
    """Determine the file type value from the extension of ``filename``.

    The extension is compared case-insensitively; anything unrecognised
    yields ``FileType.OTHER.value``.
    """
    suffix = os.path.splitext(filename)[1].lower()
    extension_groups = (
        (('.jpg', '.jpeg', '.png', '.gif', '.bmp'), FileType.VISUAL),
        (('.pdf',), FileType.PDF),
        (('.doc', '.docx'), FileType.WORD),
        (('.xls', '.xlsx'), FileType.EXCEL),
        (('.ppt', '.pptx'), FileType.PPT),
        (('.txt', '.md'), FileType.TEXT),  # plain text and markdown
    )
    for extensions, kind in extension_groups:
        if suffix in extensions:
            return kind.value
    return FileType.OTHER.value
def get_minio_client():
"""创建MinIO客户端"""
return Minio(
@ -397,4 +433,172 @@ def batch_delete_files(file_ids):
conn.close()
except Exception as e:
raise e
raise e
def _build_safe_filename(original_filename):
    """Return ``(safe_filename, original_extension)`` for an uploaded name.

    Characters that are unsafe in file systems are replaced with ``_``
    while Chinese and other Unicode characters are kept, and the extension
    is lower-cased. A random ``file_xxxxxxxx`` stem is used when nothing
    usable remains.
    """
    name, ext = os.path.splitext(original_filename)
    safe_name = re.sub(r'[\\/:*?"<>|]', '_', name)
    if not safe_name.strip():
        safe_name = f"file_{get_uuid()[:8]}"
    return safe_name + ext.lower(), ext


def upload_files_to_server(files, kb_id=None, user_id=None):
    """Core logic for uploading files to the server.

    For every accepted file this saves it to a uniquely named temporary
    file under ``UPLOAD_FOLDER``, uploads it to its own MinIO bucket, and
    inserts the document, file and file-to-document rows inside one
    database transaction. The temporary file is always removed.

    Entries with an empty file name or a disallowed extension are skipped
    without a result entry.

    Args:
        files: Sized collection of uploaded file objects exposing
            ``filename`` and ``save(path)``.
        kb_id: Not used by this function; kept for interface compatibility.
        user_id: Recorded as creator and tenant; ``'system'`` when falsy.

    Returns:
        dict: ``code`` (always 0), ``data`` (one result per processed
        file, with ``status`` of ``'success'`` or ``'failed'``) and a
        summary ``message``.
    """
    import json  # local import: only needed to serialise parser_config

    results = []
    owner = user_id or 'system'

    for file in files:
        if file.filename == '':
            continue
        if not (file and allowed_file(file.filename)):
            continue

        # Each file gets its own bucket name.
        file_bucket_id = FileService.generate_bucket_name()
        original_filename = file.filename
        filename, ext = _build_safe_filename(original_filename)
        # The temp path is independent of the user's file name so that
        # concurrent uploads of the same name cannot overwrite (or delete)
        # each other's temporary file.
        filepath = os.path.join(UPLOAD_FOLDER, f"{get_uuid()}{ext.lower()}")

        try:
            # 1. Save the upload to the local temporary directory.
            os.makedirs(UPLOAD_FOLDER, exist_ok=True)
            file.save(filepath)
            print(f"文件已保存到临时目录: {filepath}")
            print(f"原始文件名: {original_filename}, 处理后文件名: {filename}, 扩展名: {ext[1:]}")

            # 2. Determine the file type from the sanitised name.
            filetype = filename_type(filename)
            if filetype == FileType.OTHER.value:
                raise RuntimeError("不支持的文件类型")

            file_size = os.path.getsize(filepath)

            # 3. Make sure the per-file bucket exists.
            minio_client = get_minio_client()
            location = filename
            if not minio_client.bucket_exists(file_bucket_id):
                minio_client.make_bucket(file_bucket_id)
                print(f"创建MinIO存储桶: {file_bucket_id}")

            # 4. Upload to MinIO.
            with open(filepath, 'rb') as file_data:
                minio_client.put_object(
                    bucket_name=file_bucket_id,
                    object_name=location,
                    data=file_data,
                    length=file_size,
                )
            print(f"文件已上传到MinIO: {file_bucket_id}/{location}")

            # 5. Reserve a thumbnail name for images/PDFs. Only the name is
            #    recorded here; no thumbnail is rendered by this function.
            thumbnail_location = ''
            if filetype in [FileType.VISUAL.value, FileType.PDF.value]:
                thumbnail_location = f'thumbnail_{get_uuid()}.png'

            # 6. Build the database records.
            doc_id = get_uuid()
            now = datetime.now()
            current_time = int(now.timestamp())
            current_date = now.strftime('%Y-%m-%d')
            timestamps = {
                "create_time": current_time,
                "create_date": current_date,
                "update_time": current_time,
                "update_date": current_date,
            }

            doc = {
                "id": doc_id,
                "kb_id": file_bucket_id,
                "parser_id": FileService.get_parser(filetype, filename, ""),
                # Serialised explicitly: the column is a text field, which
                # would otherwise store the dict's Python repr.
                "parser_config": json.dumps({"pages": [[1, 1000000]]}),
                "source_type": "local",
                "created_by": owner,
                "type": filetype,
                "name": filename,
                "location": location,
                "size": file_size,
                "thumbnail": thumbnail_location,
                "token_num": 0,
                "chunk_num": 0,
                "progress": 0,
                "progress_msg": "",
                "run": "0",
                "status": StatusEnum.VALID.value,
                **timestamps,
            }
            file_record = {
                "id": get_uuid(),
                "parent_id": file_bucket_id,
                "tenant_id": owner,
                "created_by": owner,
                "name": filename,
                "type": filetype,
                "size": file_size,
                "location": location,
                "source_type": FileSource.KNOWLEDGEBASE.value,
                **timestamps,
            }

            # 7-9. Insert all three rows atomically. The transaction must be
            #      opened on the database the service models are bound to;
            #      committing or rolling back a separate raw connection has
            #      no effect on these inserts.
            database = DocumentService.model._meta.database
            try:
                with database.atomic():
                    DocumentService.insert(doc)
                    print(f"文档记录已保存到MySQL: {doc_id}")

                    FileService.insert(file_record)
                    print(f"文件记录已保存到MySQL: {file_record['id']}")

                    File2DocumentService.insert({
                        "id": get_uuid(),
                        "file_id": file_record["id"],
                        "document_id": doc_id,
                        **timestamps,
                    })
                    print(f"关联记录已保存到MySQL: {file_record['id']} -> {doc_id}")
            except Exception as e:
                print(f"数据库操作失败: {str(e)}")
                raise

            results.append({
                'id': doc_id,
                'name': filename,
                'size': file_size,
                'type': filetype,
                'status': 'success'
            })

        except Exception as e:
            # Best effort per file: record the failure and continue.
            results.append({
                'name': filename,
                'error': str(e),
                'status': 'failed'
            })
            print(f"文件上传过程中出错: {filename}, 错误: {str(e)}")
        finally:
            # Remove the temporary file.
            if os.path.exists(filepath):
                os.remove(filepath)

    return {
        'code': 0,
        'data': results,
        'message': f'成功上传 {len([r for r in results if r["status"] == "success"])}/{len(files)} 个文件'
    }

View File

@ -0,0 +1,28 @@
import uuid
from strenum import StrEnum
from enum import Enum
# Modeled on api.db
class FileType(StrEnum):
    """String tags for the kinds of file the service distinguishes."""

    FOLDER = "folder"
    PDF = "pdf"
    WORD = "word"
    EXCEL = "excel"
    PPT = "ppt"
    VISUAL = "visual"
    TEXT = "txt"
    OTHER = "other"
class FileSource(StrEnum):
    """String tags for where a file record originates."""

    LOCAL = ""  # local source is represented by the empty string
    KNOWLEDGEBASE = "knowledgebase"
    S3 = "s3"
class StatusEnum(Enum):
    """Validity flag values; a plain Enum, so read the string via ``.value``."""

    VALID = "1"
    INVALID = "0"
# Modeled on api.utils
def get_uuid():
    """Return a new version-1 UUID as a 32-character hex string."""
    identifier = uuid.uuid1()
    return identifier.hex

View File

@ -91,3 +91,26 @@ export function batchDeleteFilesApi(fileIds: string[]) {
data: { ids: fileIds }
})
}
/**
 * Upload one or more files.
 *
 * @param formData Multipart form data; the server reads the files from the "files" field.
 * @returns The server response: a per-file result list plus a summary message.
 */
export function uploadFileApi(formData: FormData) {
  return request<{
    code: number
    data: Array<{
      // Present on successful entries only.
      id?: string
      name: string
      size?: number
      type?: string
      // "success" or "failed"
      status: string
      // Present on failed entries only.
      error?: string
    }>
    message: string
  }>({
    url: "/api/v1/files/upload",
    method: "post",
    data: formData,
    headers: {
      "Content-Type": "multipart/form-data"
    }
  })
}

View File

@ -1,9 +1,10 @@
<script lang="ts" setup>
import type { FormInstance } from "element-plus"
import { batchDeleteFilesApi, deleteFileApi, getFileListApi } from "@@/apis/files"
import type { FormInstance, UploadUserFile } from "element-plus"
import { batchDeleteFilesApi, deleteFileApi, getFileListApi, uploadFileApi } from "@@/apis/files"
import { usePagination } from "@@/composables/usePagination"
import { Delete, Download, Refresh, Search } from "@element-plus/icons-vue"
import { Delete, Download, Refresh, Search, Upload } from "@element-plus/icons-vue"
import { ElMessage, ElMessageBox } from "element-plus"
import { ref } from "vue"
import "element-plus/dist/index.css"
import "element-plus/theme-chalk/el-message-box.css"
import "element-plus/theme-chalk/el-message.css"
@ -15,6 +16,9 @@ defineOptions({
const loading = ref<boolean>(false)
const { paginationData, handleCurrentChange, handleSizeChange } = usePagination()
const uploadDialogVisible = ref(false)
const uploadFileList = ref<UploadUserFile[]>([])
const uploadLoading = ref(false)
//
interface FileData {
@ -68,6 +72,37 @@ function resetSearch() {
handleSearch()
}
// Open the upload dialog.
function handleUpload() {
  uploadDialogVisible.value = true
}
// Send the selected files to the server and report the outcome.
async function submitUpload() {
  // Only entries that still carry the underlying File object can be sent.
  const pending = uploadFileList.value.filter(file => file.raw)
  if (pending.length === 0) {
    // Nothing to send: avoid posting an empty form.
    ElMessage.warning("请先选择要上传的文件")
    return
  }

  uploadLoading.value = true
  try {
    const formData = new FormData()
    pending.forEach((file) => {
      if (file.raw) {
        formData.append("files", file.raw)
      }
    })
    const res = await uploadFileApi(formData)

    // The request can succeed while individual files are marked as failed,
    // so inspect the per-file results before announcing success.
    const failed = (res?.data ?? []).filter(item => item.status !== "success")
    if (failed.length > 0) {
      ElMessage.warning(`部分文件上传失败: ${failed.map(item => item.name).join(", ")}`)
    } else {
      ElMessage.success("文件上传成功")
    }

    getTableData()
    uploadDialogVisible.value = false
    uploadFileList.value = []
  } catch (error: unknown) {
    let errorMessage = "上传失败"
    if (error instanceof Error) {
      errorMessage += `: ${error.message}`
    }
    ElMessage.error(errorMessage)
  } finally {
    uploadLoading.value = false
  }
}
//
async function handleDownload(row: FileData) {
const loadingInstance = ElLoading.service({
@ -274,6 +309,13 @@ onActivated(() => {
<el-card v-loading="loading" shadow="never">
<div class="toolbar-wrapper">
<div>
<el-button
type="primary"
:icon="Upload"
@click="handleUpload"
>
上传文件
</el-button>
<el-button
type="danger"
:icon="Delete"
@ -284,6 +326,38 @@ onActivated(() => {
</el-button>
</div>
</div>
<!-- 上传对话框 -->
<el-dialog
v-model="uploadDialogVisible"
title="上传文件"
width="30%"
>
<el-upload
v-model:file-list="uploadFileList"
multiple
:auto-upload="false"
drag
>
<el-icon class="el-icon--upload">
<Upload />
</el-icon>
<div class="el-upload__text">
拖拽文件到此处或<em>点击上传</em>
</div>
</el-upload>
<template #footer>
<el-button @click="uploadDialogVisible = false">
取消
</el-button>
<el-button
type="primary"
:loading="uploadLoading"
@click="submitUpload"
>
确认上传
</el-button>
</template>
</el-dialog>
<div class="table-wrapper">
<el-table :data="tableData" @selection-change="handleSelectionChange">
<el-table-column type="selection" width="50" align="center" />
@ -434,4 +508,21 @@ onActivated(() => {
.delete-confirm-dialog .el-message-box__status {
display: none !important;
}
.toolbar-wrapper {
display: flex;
justify-content: space-between;
margin-bottom: 20px;
.el-button {
margin-right: 10px;
}
}
.upload-dialog {
.el-upload-dragger {
width: 100%;
padding: 20px;
}
}
</style>

View File

@ -48,6 +48,7 @@ declare module 'vue' {
ElTabs: typeof import('element-plus/es')['ElTabs']
ElTag: typeof import('element-plus/es')['ElTag']
ElTooltip: typeof import('element-plus/es')['ElTooltip']
ElUpload: typeof import('element-plus/es')['ElUpload']
RouterLink: typeof import('vue-router')['RouterLink']
RouterView: typeof import('vue-router')['RouterView']
}