From 07054fa7c38a56dd32f6f04233509b027e425346 Mon Sep 17 00:00:00 2001 From: zstar <65890619+zstar1003@users.noreply.github.com> Date: Sat, 12 Apr 2025 16:40:35 +0800 Subject: [PATCH] =?UTF-8?q?refactor(database):=20=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E9=85=8D=E7=BD=AE=E5=B9=B6=E7=A7=BB?= =?UTF-8?q?=E9=99=A4=E9=87=8D=E5=A4=8D=E4=BB=A3=E7=A0=81=20(#22)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将数据库配置从各个服务文件中移除,统一到 `database.py` 中,减少代码重复。 --- management/server/database.py | 71 +++++++++++- .../server/services/files/file_service.py | 1 - .../server/services/files/models/__init__.py | 10 +- management/server/services/files/service.py | 107 ++++++++++++++---- management/server/services/teams/service.py | 14 +-- management/server/services/tenants/service.py | 6 +- management/server/services/users/service.py | 10 +- 7 files changed, 167 insertions(+), 52 deletions(-) diff --git a/management/server/database.py b/management/server/database.py index 01ea2f9..4e1f9f8 100644 --- a/management/server/database.py +++ b/management/server/database.py @@ -2,6 +2,11 @@ import mysql.connector import os from utils import generate_uuid, encrypt_password from datetime import datetime +from minio import Minio +from dotenv import load_dotenv + +# 加载环境变量 +load_dotenv("../../docker/.env") # 检测是否在Docker容器中运行 def is_running_in_docker(): @@ -16,12 +21,70 @@ def is_running_in_docker(): # 根据运行环境选择合适的主机地址 DB_HOST = 'host.docker.internal' if is_running_in_docker() else 'localhost' +MINIO_HOST = 'host.docker.internal' if is_running_in_docker() else 'localhost' # 数据库连接配置 -db_config = { +DB_CONFIG = { "host": DB_HOST, - "port": 5455, + "port": int(os.getenv("MYSQL_PORT", "5455")), "user": "root", - "password": "infini_rag_flow", + "password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"), "database": "rag_flow", -} \ No newline at end of file +} + +# MinIO连接配置 +MINIO_CONFIG = { + "endpoint": f"{MINIO_HOST}:{os.getenv('MINIO_PORT', '9000')}", + "access_key": os.getenv("MINIO_USER", "rag_flow"), + "secret_key": os.getenv("MINIO_PASSWORD", "infini_rag_flow"), + "secure": False +} + +def get_db_connection(): + """创建MySQL数据库连接""" + try: + conn = mysql.connector.connect(**DB_CONFIG) + return conn + except Exception as e: + print(f"MySQL连接失败: {str(e)}") + raise e + +def get_minio_client(): + """创建MinIO客户端连接""" + try: + minio_client = Minio( + endpoint=MINIO_CONFIG["endpoint"], + access_key=MINIO_CONFIG["access_key"], + secret_key=MINIO_CONFIG["secret_key"], + secure=MINIO_CONFIG["secure"] + ) + return minio_client + except Exception as e: + print(f"MinIO连接失败: {str(e)}") + raise e + +def test_connections(): + """测试数据库和MinIO连接""" + try: + # 测试MySQL连接 + db_conn = get_db_connection() + cursor = db_conn.cursor() + cursor.execute("SELECT 1") + cursor.fetchone() + cursor.close() + db_conn.close() + print("MySQL连接测试成功") + + # 测试MinIO连接 + minio_client = get_minio_client() + buckets = minio_client.list_buckets() + print(f"MinIO连接测试成功,共有 {len(buckets)} 个存储桶") + + return True + except Exception as e: + print(f"连接测试失败: {str(e)}") + return False + +if __name__ == "__main__": + # 如果直接运行此文件,则测试连接 + test_connections() \ No newline at end of file diff --git a/management/server/services/files/file_service.py b/management/server/services/files/file_service.py index 3b27e47..98a25b0 100644 --- a/management/server/services/files/file_service.py +++ b/management/server/services/files/file_service.py @@ -17,7 +17,6 @@ class FileService(BaseService): 'source_type': 'knowledgebase' }) - @classmethod @classmethod def get_parser(cls, file_type, filename, tenant_id): """获取适合文件类型的解析器ID""" diff --git a/management/server/services/files/models/__init__.py b/management/server/services/files/models/__init__.py index 527a924..7915417 100644 --- a/management/server/services/files/models/__init__.py +++ b/management/server/services/files/models/__init__.py @@ -1,15 +1,7 @@ from peewee import * import os from datetime import datetime - -# 数据库连接配置 -DB_CONFIG = { - "host": "localhost", - "port": int(os.getenv("MYSQL_PORT", "5455")), - "user": "root", - "password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"), - "database": "rag_flow" -} +from database import DB_CONFIG # 使用MySQL数据库 db = MySQLDatabase( diff --git a/management/server/services/files/service.py b/management/server/services/files/service.py index 062e8a0..0218684 100644 --- a/management/server/services/files/service.py +++ b/management/server/services/files/service.py @@ -10,7 +10,7 @@ from .utils import FileType, FileSource, StatusEnum, get_uuid from .document_service import DocumentService from .file_service import FileService from .file2document_service import File2DocumentService - +from database import DB_CONFIG, MINIO_CONFIG # 加载环境变量 load_dotenv("../../docker/.env") @@ -18,24 +18,6 @@ load_dotenv("../../docker/.env") UPLOAD_FOLDER = '/data/uploads' ALLOWED_EXTENSIONS = {'pdf', 'doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx', 'jpg', 'jpeg', 'png', 'txt', 'md'} -# 数据库连接配置 -DB_CONFIG = { - "host": "localhost", - "port": int(os.getenv("MYSQL_PORT", "5455")), - "user": "root", - "password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"), - "database": "rag_flow" -} - -# MinIO连接配置 -MINIO_CONFIG = { - "endpoint": "localhost:" + os.getenv("MINIO_PORT", "9000"), - "access_key": os.getenv("MINIO_USER", "rag_flow"), - "secret_key": os.getenv("MINIO_PASSWORD", "infini_rag_flow"), - "secure": False -} - - def allowed_file(filename): """Check if the file extension is allowed""" return '.' in filename and \ @@ -435,8 +417,87 @@ def batch_delete_files(file_ids): except Exception as e: raise e -def upload_files_to_server(files, kb_id=None, user_id=None): +def upload_files_to_server(files, kb_id=None, user_id=None, parent_id=None): """处理文件上传到服务器的核心逻辑""" + if user_id is None: + try: + conn = get_db_connection() + cursor = conn.cursor(dictionary=True) + + # 查询创建时间最早的用户ID + query_earliest_user = """ + SELECT id FROM user + WHERE create_time = (SELECT MIN(create_time) FROM user) + LIMIT 1 + """ + cursor.execute(query_earliest_user) + earliest_user = cursor.fetchone() + + if earliest_user: + user_id = earliest_user['id'] + print(f"使用创建时间最早的用户ID: {user_id}") + else: + user_id = 'system' + print("未找到用户, 使用默认用户ID: system") + + cursor.close() + conn.close() + except Exception as e: + print(f"查询最早用户ID失败: {str(e)}") + user_id = 'system' + + # 如果没有指定parent_id,则获取用户的根文件夹ID + if parent_id is None: + try: + conn = get_db_connection() + cursor = conn.cursor(dictionary=True) + + # 查询用户的根文件夹 + query_root_folder = """ + SELECT id FROM file + WHERE tenant_id = %s AND parent_id = id + LIMIT 1 + """ + cursor.execute(query_root_folder, (user_id,)) + root_folder = cursor.fetchone() + + if root_folder: + parent_id = root_folder['id'] + print(f"使用用户根文件夹ID: {parent_id}") + else: + # 如果没有找到根文件夹,创建一个 + root_id = get_uuid() + # 修改时间格式,包含时分秒 + current_time = int(datetime.now().timestamp()) + current_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + + root_folder = { + "id": root_id, + "parent_id": root_id, # 根文件夹的parent_id指向自己 + "tenant_id": user_id, + "created_by": user_id, + "name": "/", + "type": FileType.FOLDER.value, + "size": 0, + "location": "", + "source_type": FileSource.LOCAL.value, + "create_time": current_time, + "create_date": current_date, + "update_time": current_time, + "update_date": current_date + } + + FileService.insert(root_folder) + parent_id = root_id + print(f"创建并使用新的根文件夹ID: {parent_id}") + + cursor.close() + conn.close() + except Exception as e: + print(f"查询根文件夹ID失败: {str(e)}") + # 如果无法获取根文件夹,使用file_bucket_id作为备选 + parent_id = None + results = [] for file in files: @@ -450,7 +511,6 @@ def upload_files_to_server(files, kb_id=None, user_id=None): # 修复文件名处理逻辑,保留中文字符 name, ext = os.path.splitext(original_filename) - # 保留中文字符,但替换不安全字符 # 只替换文件系统不安全的字符,保留中文和其他Unicode字符 safe_name = re.sub(r'[\\/:*?"<>|]', '_', name) @@ -502,8 +562,9 @@ def upload_files_to_server(files, kb_id=None, user_id=None): # 6. 创建数据库记录 doc_id = get_uuid() + # 修改时间格式,包含时分秒 current_time = int(datetime.now().timestamp()) - current_date = datetime.now().strftime('%Y-%m-%d') + current_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S') doc = { "id": doc_id, @@ -539,7 +600,7 @@ def upload_files_to_server(files, kb_id=None, user_id=None): # 8. 创建文件记录和关联 file_record = { "id": get_uuid(), - "parent_id": file_bucket_id, # 使用文件独立的bucket_id + "parent_id": parent_id or file_bucket_id, # 优先使用指定的parent_id "tenant_id": user_id or 'system', "created_by": user_id or 'system', "name": filename, diff --git a/management/server/services/teams/service.py b/management/server/services/teams/service.py index f4f88bb..d5fbbbe 100644 --- a/management/server/services/teams/service.py +++ b/management/server/services/teams/service.py @@ -1,12 +1,12 @@ import mysql.connector from datetime import datetime from utils import generate_uuid -from database import db_config +from database import DB_CONFIG def get_teams_with_pagination(current_page, page_size, name=''): """查询团队信息,支持分页和条件筛选""" try: - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor(dictionary=True) # 构建WHERE子句和参数 @@ -78,7 +78,7 @@ def get_teams_with_pagination(current_page, page_size, name=''): def get_team_by_id(team_id): """根据ID获取团队详情""" try: - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor(dictionary=True) query = """ @@ -110,7 +110,7 @@ def get_team_by_id(team_id): def delete_team(team_id): """删除指定ID的团队""" try: - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor() # 删除团队成员关联 @@ -136,7 +136,7 @@ def delete_team(team_id): def get_team_members(team_id): """获取团队成员列表""" try: - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor(dictionary=True) query = """ @@ -174,7 +174,7 @@ def get_team_members(team_id): def add_team_member(team_id, user_id, role="member"): """添加团队成员""" try: - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor() # 检查用户是否已经是团队成员 @@ -229,7 +229,7 @@ def add_team_member(team_id, user_id, role="member"): def remove_team_member(team_id, user_id): """移除团队成员""" try: - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor() # 检查是否是团队的唯一所有者 diff --git a/management/server/services/tenants/service.py b/management/server/services/tenants/service.py index 999bef8..1073516 100644 --- a/management/server/services/tenants/service.py +++ b/management/server/services/tenants/service.py @@ -1,11 +1,11 @@ import mysql.connector from datetime import datetime -from database import db_config +from database import DB_CONFIG def get_tenants_with_pagination(current_page, page_size, username=''): """查询租户信息,支持分页和条件筛选""" try: - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor(dictionary=True) # 构建WHERE子句和参数 @@ -83,7 +83,7 @@ def get_tenants_with_pagination(current_page, page_size, username=''): def update_tenant(tenant_id, tenant_data): """更新租户信息""" try: - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor() # 更新租户表 diff --git a/management/server/services/users/service.py b/management/server/services/users/service.py index f3a72b0..d8e871c 100644 --- a/management/server/services/users/service.py +++ b/management/server/services/users/service.py @@ -1,13 +1,13 @@ import mysql.connector from datetime import datetime from utils import generate_uuid, encrypt_password -from database import db_config +from database import DB_CONFIG def get_users_with_pagination(current_page, page_size, username='', email=''): """查询用户信息,支持分页和条件筛选""" try: # 建立数据库连接 - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor(dictionary=True) # 构建WHERE子句和参数 @@ -68,7 +68,7 @@ def get_users_with_pagination(current_page, page_size, username='', email=''): def delete_user(user_id): """删除指定ID的用户""" try: - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor() # 删除 user 表中的用户记录 @@ -99,7 +99,7 @@ def delete_user(user_id): def create_user(user_data): """创建新用户,并加入最早用户的团队,并使用相同的模型配置""" try: - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor(dictionary=True) # 检查用户表是否为空 @@ -269,7 +269,7 @@ def create_user(user_data): def update_user(user_id, user_data): """更新用户信息""" try: - conn = mysql.connector.connect(**db_config) + conn = mysql.connector.connect(**DB_CONFIG) cursor = conn.cursor() query = """