refactor(database): 统一数据库配置并移除重复代码 (#22)

将数据库配置从各个服务文件中移除,统一到 `database.py` 中,减少代码重复。
This commit is contained in:
zstar 2025-04-12 16:40:35 +08:00 committed by GitHub
parent 9689a2efd7
commit 07054fa7c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 167 additions and 52 deletions

View File

@ -2,6 +2,11 @@ import mysql.connector
import os
from utils import generate_uuid, encrypt_password
from datetime import datetime
from minio import Minio
from dotenv import load_dotenv
# 加载环境变量
load_dotenv("../../docker/.env")
# 检测是否在Docker容器中运行
def is_running_in_docker():
@ -16,12 +21,70 @@ def is_running_in_docker():
# 根据运行环境选择合适的主机地址
DB_HOST = 'host.docker.internal' if is_running_in_docker() else 'localhost'
MINIO_HOST = 'host.docker.internal' if is_running_in_docker() else 'localhost'
# 数据库连接配置
db_config = {
DB_CONFIG = {
"host": DB_HOST,
"port": 5455,
"port": int(os.getenv("MYSQL_PORT", "5455")),
"user": "root",
"password": "infini_rag_flow",
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
"database": "rag_flow",
}
}
# MinIO连接配置
MINIO_CONFIG = {
"endpoint": f"{MINIO_HOST}:{os.getenv('MINIO_PORT', '9000')}",
"access_key": os.getenv("MINIO_USER", "rag_flow"),
"secret_key": os.getenv("MINIO_PASSWORD", "infini_rag_flow"),
"secure": False
}
def get_db_connection():
"""创建MySQL数据库连接"""
try:
conn = mysql.connector.connect(**DB_CONFIG)
return conn
except Exception as e:
print(f"MySQL连接失败: {str(e)}")
raise e
def get_minio_client():
"""创建MinIO客户端连接"""
try:
minio_client = Minio(
endpoint=MINIO_CONFIG["endpoint"],
access_key=MINIO_CONFIG["access_key"],
secret_key=MINIO_CONFIG["secret_key"],
secure=MINIO_CONFIG["secure"]
)
return minio_client
except Exception as e:
print(f"MinIO连接失败: {str(e)}")
raise e
def test_connections():
"""测试数据库和MinIO连接"""
try:
# 测试MySQL连接
db_conn = get_db_connection()
cursor = db_conn.cursor()
cursor.execute("SELECT 1")
cursor.fetchone()
cursor.close()
db_conn.close()
print("MySQL连接测试成功")
# 测试MinIO连接
minio_client = get_minio_client()
buckets = minio_client.list_buckets()
print(f"MinIO连接测试成功共有 {len(buckets)} 个存储桶")
return True
except Exception as e:
print(f"连接测试失败: {str(e)}")
return False
if __name__ == "__main__":
# 如果直接运行此文件,则测试连接
test_connections()

View File

@ -17,7 +17,6 @@ class FileService(BaseService):
'source_type': 'knowledgebase'
})
@classmethod
@classmethod
def get_parser(cls, file_type, filename, tenant_id):
"""获取适合文件类型的解析器ID"""

View File

@ -1,15 +1,7 @@
from peewee import *
import os
from datetime import datetime
# 数据库连接配置
DB_CONFIG = {
"host": "localhost",
"port": int(os.getenv("MYSQL_PORT", "5455")),
"user": "root",
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
"database": "rag_flow"
}
from database import DB_CONFIG
# 使用MySQL数据库
db = MySQLDatabase(

View File

@ -10,7 +10,7 @@ from .utils import FileType, FileSource, StatusEnum, get_uuid
from .document_service import DocumentService
from .file_service import FileService
from .file2document_service import File2DocumentService
from database import DB_CONFIG, MINIO_CONFIG
# 加载环境变量
load_dotenv("../../docker/.env")
@ -18,24 +18,6 @@ load_dotenv("../../docker/.env")
UPLOAD_FOLDER = '/data/uploads'
ALLOWED_EXTENSIONS = {'pdf', 'doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx', 'jpg', 'jpeg', 'png', 'txt', 'md'}
# 数据库连接配置
DB_CONFIG = {
"host": "localhost",
"port": int(os.getenv("MYSQL_PORT", "5455")),
"user": "root",
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
"database": "rag_flow"
}
# MinIO连接配置
MINIO_CONFIG = {
"endpoint": "localhost:" + os.getenv("MINIO_PORT", "9000"),
"access_key": os.getenv("MINIO_USER", "rag_flow"),
"secret_key": os.getenv("MINIO_PASSWORD", "infini_rag_flow"),
"secure": False
}
def allowed_file(filename):
"""Check if the file extension is allowed"""
return '.' in filename and \
@ -435,8 +417,87 @@ def batch_delete_files(file_ids):
except Exception as e:
raise e
def upload_files_to_server(files, kb_id=None, user_id=None):
def upload_files_to_server(files, kb_id=None, user_id=None, parent_id=None):
"""处理文件上传到服务器的核心逻辑"""
if user_id is None:
try:
conn = get_db_connection()
cursor = conn.cursor(dictionary=True)
# 查询创建时间最早的用户ID
query_earliest_user = """
SELECT id FROM user
WHERE create_time = (SELECT MIN(create_time) FROM user)
LIMIT 1
"""
cursor.execute(query_earliest_user)
earliest_user = cursor.fetchone()
if earliest_user:
user_id = earliest_user['id']
print(f"使用创建时间最早的用户ID: {user_id}")
else:
user_id = 'system'
print("未找到用户, 使用默认用户ID: system")
cursor.close()
conn.close()
except Exception as e:
print(f"查询最早用户ID失败: {str(e)}")
user_id = 'system'
# 如果没有指定parent_id则获取用户的根文件夹ID
if parent_id is None:
try:
conn = get_db_connection()
cursor = conn.cursor(dictionary=True)
# 查询用户的根文件夹
query_root_folder = """
SELECT id FROM file
WHERE tenant_id = %s AND parent_id = id
LIMIT 1
"""
cursor.execute(query_root_folder, (user_id,))
root_folder = cursor.fetchone()
if root_folder:
parent_id = root_folder['id']
print(f"使用用户根文件夹ID: {parent_id}")
else:
# 如果没有找到根文件夹,创建一个
root_id = get_uuid()
# 修改时间格式,包含时分秒
current_time = int(datetime.now().timestamp())
current_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
root_folder = {
"id": root_id,
"parent_id": root_id, # 根文件夹的parent_id指向自己
"tenant_id": user_id,
"created_by": user_id,
"name": "/",
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
"source_type": FileSource.LOCAL.value,
"create_time": current_time,
"create_date": current_date,
"update_time": current_time,
"update_date": current_date
}
FileService.insert(root_folder)
parent_id = root_id
print(f"创建并使用新的根文件夹ID: {parent_id}")
cursor.close()
conn.close()
except Exception as e:
print(f"查询根文件夹ID失败: {str(e)}")
# 如果无法获取根文件夹使用file_bucket_id作为备选
parent_id = None
results = []
for file in files:
@ -450,7 +511,6 @@ def upload_files_to_server(files, kb_id=None, user_id=None):
# 修复文件名处理逻辑,保留中文字符
name, ext = os.path.splitext(original_filename)
# 保留中文字符,但替换不安全字符
# 只替换文件系统不安全的字符保留中文和其他Unicode字符
safe_name = re.sub(r'[\\/:*?"<>|]', '_', name)
@ -502,8 +562,9 @@ def upload_files_to_server(files, kb_id=None, user_id=None):
# 6. 创建数据库记录
doc_id = get_uuid()
# 修改时间格式,包含时分秒
current_time = int(datetime.now().timestamp())
current_date = datetime.now().strftime('%Y-%m-%d')
current_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
doc = {
"id": doc_id,
@ -539,7 +600,7 @@ def upload_files_to_server(files, kb_id=None, user_id=None):
# 8. 创建文件记录和关联
file_record = {
"id": get_uuid(),
"parent_id": file_bucket_id, # 使用文件独立的bucket_id
"parent_id": parent_id or file_bucket_id, # 优先使用指定的parent_id
"tenant_id": user_id or 'system',
"created_by": user_id or 'system',
"name": filename,

View File

@ -1,12 +1,12 @@
import mysql.connector
from datetime import datetime
from utils import generate_uuid
from database import db_config
from database import DB_CONFIG
def get_teams_with_pagination(current_page, page_size, name=''):
"""查询团队信息,支持分页和条件筛选"""
try:
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor(dictionary=True)
# 构建WHERE子句和参数
@ -78,7 +78,7 @@ def get_teams_with_pagination(current_page, page_size, name=''):
def get_team_by_id(team_id):
"""根据ID获取团队详情"""
try:
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor(dictionary=True)
query = """
@ -110,7 +110,7 @@ def get_team_by_id(team_id):
def delete_team(team_id):
"""删除指定ID的团队"""
try:
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor()
# 删除团队成员关联
@ -136,7 +136,7 @@ def delete_team(team_id):
def get_team_members(team_id):
"""获取团队成员列表"""
try:
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor(dictionary=True)
query = """
@ -174,7 +174,7 @@ def get_team_members(team_id):
def add_team_member(team_id, user_id, role="member"):
"""添加团队成员"""
try:
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor()
# 检查用户是否已经是团队成员
@ -229,7 +229,7 @@ def add_team_member(team_id, user_id, role="member"):
def remove_team_member(team_id, user_id):
"""移除团队成员"""
try:
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor()
# 检查是否是团队的唯一所有者

View File

@ -1,11 +1,11 @@
import mysql.connector
from datetime import datetime
from database import db_config
from database import DB_CONFIG
def get_tenants_with_pagination(current_page, page_size, username=''):
"""查询租户信息,支持分页和条件筛选"""
try:
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor(dictionary=True)
# 构建WHERE子句和参数
@ -83,7 +83,7 @@ def get_tenants_with_pagination(current_page, page_size, username=''):
def update_tenant(tenant_id, tenant_data):
"""更新租户信息"""
try:
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor()
# 更新租户表

View File

@ -1,13 +1,13 @@
import mysql.connector
from datetime import datetime
from utils import generate_uuid, encrypt_password
from database import db_config
from database import DB_CONFIG
def get_users_with_pagination(current_page, page_size, username='', email=''):
"""查询用户信息,支持分页和条件筛选"""
try:
# 建立数据库连接
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor(dictionary=True)
# 构建WHERE子句和参数
@ -68,7 +68,7 @@ def get_users_with_pagination(current_page, page_size, username='', email=''):
def delete_user(user_id):
"""删除指定ID的用户"""
try:
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor()
# 删除 user 表中的用户记录
@ -99,7 +99,7 @@ def delete_user(user_id):
def create_user(user_data):
"""创建新用户,并加入最早用户的团队,并使用相同的模型配置"""
try:
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor(dictionary=True)
# 检查用户表是否为空
@ -269,7 +269,7 @@ def create_user(user_data):
def update_user(user_id, user_data):
"""更新用户信息"""
try:
conn = mysql.connector.connect(**db_config)
conn = mysql.connector.connect(**DB_CONFIG)
cursor = conn.cursor()
query = """