refactor(database): 统一数据库配置并移除重复代码 (#22)
将数据库配置从各个服务文件中移除,统一到 `database.py` 中,减少代码重复。
This commit is contained in:
parent
9689a2efd7
commit
07054fa7c3
|
@ -2,6 +2,11 @@ import mysql.connector
|
|||
import os
|
||||
from utils import generate_uuid, encrypt_password
|
||||
from datetime import datetime
|
||||
from minio import Minio
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv("../../docker/.env")
|
||||
|
||||
# 检测是否在Docker容器中运行
|
||||
def is_running_in_docker():
|
||||
|
@ -16,12 +21,70 @@ def is_running_in_docker():
|
|||
|
||||
# 根据运行环境选择合适的主机地址
|
||||
DB_HOST = 'host.docker.internal' if is_running_in_docker() else 'localhost'
|
||||
MINIO_HOST = 'host.docker.internal' if is_running_in_docker() else 'localhost'
|
||||
|
||||
# 数据库连接配置
|
||||
db_config = {
|
||||
DB_CONFIG = {
|
||||
"host": DB_HOST,
|
||||
"port": 5455,
|
||||
"port": int(os.getenv("MYSQL_PORT", "5455")),
|
||||
"user": "root",
|
||||
"password": "infini_rag_flow",
|
||||
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
|
||||
"database": "rag_flow",
|
||||
}
|
||||
}
|
||||
|
||||
# MinIO连接配置
|
||||
MINIO_CONFIG = {
|
||||
"endpoint": f"{MINIO_HOST}:{os.getenv('MINIO_PORT', '9000')}",
|
||||
"access_key": os.getenv("MINIO_USER", "rag_flow"),
|
||||
"secret_key": os.getenv("MINIO_PASSWORD", "infini_rag_flow"),
|
||||
"secure": False
|
||||
}
|
||||
|
||||
def get_db_connection():
|
||||
"""创建MySQL数据库连接"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
return conn
|
||||
except Exception as e:
|
||||
print(f"MySQL连接失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def get_minio_client():
|
||||
"""创建MinIO客户端连接"""
|
||||
try:
|
||||
minio_client = Minio(
|
||||
endpoint=MINIO_CONFIG["endpoint"],
|
||||
access_key=MINIO_CONFIG["access_key"],
|
||||
secret_key=MINIO_CONFIG["secret_key"],
|
||||
secure=MINIO_CONFIG["secure"]
|
||||
)
|
||||
return minio_client
|
||||
except Exception as e:
|
||||
print(f"MinIO连接失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def test_connections():
|
||||
"""测试数据库和MinIO连接"""
|
||||
try:
|
||||
# 测试MySQL连接
|
||||
db_conn = get_db_connection()
|
||||
cursor = db_conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.fetchone()
|
||||
cursor.close()
|
||||
db_conn.close()
|
||||
print("MySQL连接测试成功")
|
||||
|
||||
# 测试MinIO连接
|
||||
minio_client = get_minio_client()
|
||||
buckets = minio_client.list_buckets()
|
||||
print(f"MinIO连接测试成功,共有 {len(buckets)} 个存储桶")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"连接测试失败: {str(e)}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 如果直接运行此文件,则测试连接
|
||||
test_connections()
|
|
@ -17,7 +17,6 @@ class FileService(BaseService):
|
|||
'source_type': 'knowledgebase'
|
||||
})
|
||||
|
||||
@classmethod
|
||||
@classmethod
|
||||
def get_parser(cls, file_type, filename, tenant_id):
|
||||
"""获取适合文件类型的解析器ID"""
|
||||
|
|
|
@ -1,15 +1,7 @@
|
|||
from peewee import *
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
# 数据库连接配置
|
||||
DB_CONFIG = {
|
||||
"host": "localhost",
|
||||
"port": int(os.getenv("MYSQL_PORT", "5455")),
|
||||
"user": "root",
|
||||
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
|
||||
"database": "rag_flow"
|
||||
}
|
||||
from database import DB_CONFIG
|
||||
|
||||
# 使用MySQL数据库
|
||||
db = MySQLDatabase(
|
||||
|
|
|
@ -10,7 +10,7 @@ from .utils import FileType, FileSource, StatusEnum, get_uuid
|
|||
from .document_service import DocumentService
|
||||
from .file_service import FileService
|
||||
from .file2document_service import File2DocumentService
|
||||
|
||||
from database import DB_CONFIG, MINIO_CONFIG
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv("../../docker/.env")
|
||||
|
@ -18,24 +18,6 @@ load_dotenv("../../docker/.env")
|
|||
UPLOAD_FOLDER = '/data/uploads'
|
||||
ALLOWED_EXTENSIONS = {'pdf', 'doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx', 'jpg', 'jpeg', 'png', 'txt', 'md'}
|
||||
|
||||
# 数据库连接配置
|
||||
DB_CONFIG = {
|
||||
"host": "localhost",
|
||||
"port": int(os.getenv("MYSQL_PORT", "5455")),
|
||||
"user": "root",
|
||||
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
|
||||
"database": "rag_flow"
|
||||
}
|
||||
|
||||
# MinIO连接配置
|
||||
MINIO_CONFIG = {
|
||||
"endpoint": "localhost:" + os.getenv("MINIO_PORT", "9000"),
|
||||
"access_key": os.getenv("MINIO_USER", "rag_flow"),
|
||||
"secret_key": os.getenv("MINIO_PASSWORD", "infini_rag_flow"),
|
||||
"secure": False
|
||||
}
|
||||
|
||||
|
||||
def allowed_file(filename):
|
||||
"""Check if the file extension is allowed"""
|
||||
return '.' in filename and \
|
||||
|
@ -435,8 +417,87 @@ def batch_delete_files(file_ids):
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def upload_files_to_server(files, kb_id=None, user_id=None):
|
||||
def upload_files_to_server(files, kb_id=None, user_id=None, parent_id=None):
|
||||
"""处理文件上传到服务器的核心逻辑"""
|
||||
if user_id is None:
|
||||
try:
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询创建时间最早的用户ID
|
||||
query_earliest_user = """
|
||||
SELECT id FROM user
|
||||
WHERE create_time = (SELECT MIN(create_time) FROM user)
|
||||
LIMIT 1
|
||||
"""
|
||||
cursor.execute(query_earliest_user)
|
||||
earliest_user = cursor.fetchone()
|
||||
|
||||
if earliest_user:
|
||||
user_id = earliest_user['id']
|
||||
print(f"使用创建时间最早的用户ID: {user_id}")
|
||||
else:
|
||||
user_id = 'system'
|
||||
print("未找到用户, 使用默认用户ID: system")
|
||||
|
||||
cursor.close()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"查询最早用户ID失败: {str(e)}")
|
||||
user_id = 'system'
|
||||
|
||||
# 如果没有指定parent_id,则获取用户的根文件夹ID
|
||||
if parent_id is None:
|
||||
try:
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询用户的根文件夹
|
||||
query_root_folder = """
|
||||
SELECT id FROM file
|
||||
WHERE tenant_id = %s AND parent_id = id
|
||||
LIMIT 1
|
||||
"""
|
||||
cursor.execute(query_root_folder, (user_id,))
|
||||
root_folder = cursor.fetchone()
|
||||
|
||||
if root_folder:
|
||||
parent_id = root_folder['id']
|
||||
print(f"使用用户根文件夹ID: {parent_id}")
|
||||
else:
|
||||
# 如果没有找到根文件夹,创建一个
|
||||
root_id = get_uuid()
|
||||
# 修改时间格式,包含时分秒
|
||||
current_time = int(datetime.now().timestamp())
|
||||
current_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
root_folder = {
|
||||
"id": root_id,
|
||||
"parent_id": root_id, # 根文件夹的parent_id指向自己
|
||||
"tenant_id": user_id,
|
||||
"created_by": user_id,
|
||||
"name": "/",
|
||||
"type": FileType.FOLDER.value,
|
||||
"size": 0,
|
||||
"location": "",
|
||||
"source_type": FileSource.LOCAL.value,
|
||||
"create_time": current_time,
|
||||
"create_date": current_date,
|
||||
"update_time": current_time,
|
||||
"update_date": current_date
|
||||
}
|
||||
|
||||
FileService.insert(root_folder)
|
||||
parent_id = root_id
|
||||
print(f"创建并使用新的根文件夹ID: {parent_id}")
|
||||
|
||||
cursor.close()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"查询根文件夹ID失败: {str(e)}")
|
||||
# 如果无法获取根文件夹,使用file_bucket_id作为备选
|
||||
parent_id = None
|
||||
|
||||
results = []
|
||||
|
||||
for file in files:
|
||||
|
@ -450,7 +511,6 @@ def upload_files_to_server(files, kb_id=None, user_id=None):
|
|||
# 修复文件名处理逻辑,保留中文字符
|
||||
name, ext = os.path.splitext(original_filename)
|
||||
|
||||
# 保留中文字符,但替换不安全字符
|
||||
# 只替换文件系统不安全的字符,保留中文和其他Unicode字符
|
||||
safe_name = re.sub(r'[\\/:*?"<>|]', '_', name)
|
||||
|
||||
|
@ -502,8 +562,9 @@ def upload_files_to_server(files, kb_id=None, user_id=None):
|
|||
|
||||
# 6. 创建数据库记录
|
||||
doc_id = get_uuid()
|
||||
# 修改时间格式,包含时分秒
|
||||
current_time = int(datetime.now().timestamp())
|
||||
current_date = datetime.now().strftime('%Y-%m-%d')
|
||||
current_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
doc = {
|
||||
"id": doc_id,
|
||||
|
@ -539,7 +600,7 @@ def upload_files_to_server(files, kb_id=None, user_id=None):
|
|||
# 8. 创建文件记录和关联
|
||||
file_record = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": file_bucket_id, # 使用文件独立的bucket_id
|
||||
"parent_id": parent_id or file_bucket_id, # 优先使用指定的parent_id
|
||||
"tenant_id": user_id or 'system',
|
||||
"created_by": user_id or 'system',
|
||||
"name": filename,
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import mysql.connector
|
||||
from datetime import datetime
|
||||
from utils import generate_uuid
|
||||
from database import db_config
|
||||
from database import DB_CONFIG
|
||||
|
||||
def get_teams_with_pagination(current_page, page_size, name=''):
|
||||
"""查询团队信息,支持分页和条件筛选"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 构建WHERE子句和参数
|
||||
|
@ -78,7 +78,7 @@ def get_teams_with_pagination(current_page, page_size, name=''):
|
|||
def get_team_by_id(team_id):
|
||||
"""根据ID获取团队详情"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = """
|
||||
|
@ -110,7 +110,7 @@ def get_team_by_id(team_id):
|
|||
def delete_team(team_id):
|
||||
"""删除指定ID的团队"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 删除团队成员关联
|
||||
|
@ -136,7 +136,7 @@ def delete_team(team_id):
|
|||
def get_team_members(team_id):
|
||||
"""获取团队成员列表"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = """
|
||||
|
@ -174,7 +174,7 @@ def get_team_members(team_id):
|
|||
def add_team_member(team_id, user_id, role="member"):
|
||||
"""添加团队成员"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 检查用户是否已经是团队成员
|
||||
|
@ -229,7 +229,7 @@ def add_team_member(team_id, user_id, role="member"):
|
|||
def remove_team_member(team_id, user_id):
|
||||
"""移除团队成员"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 检查是否是团队的唯一所有者
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import mysql.connector
|
||||
from datetime import datetime
|
||||
from database import db_config
|
||||
from database import DB_CONFIG
|
||||
|
||||
def get_tenants_with_pagination(current_page, page_size, username=''):
|
||||
"""查询租户信息,支持分页和条件筛选"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 构建WHERE子句和参数
|
||||
|
@ -83,7 +83,7 @@ def get_tenants_with_pagination(current_page, page_size, username=''):
|
|||
def update_tenant(tenant_id, tenant_data):
|
||||
"""更新租户信息"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 更新租户表
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import mysql.connector
|
||||
from datetime import datetime
|
||||
from utils import generate_uuid, encrypt_password
|
||||
from database import db_config
|
||||
from database import DB_CONFIG
|
||||
|
||||
def get_users_with_pagination(current_page, page_size, username='', email=''):
|
||||
"""查询用户信息,支持分页和条件筛选"""
|
||||
try:
|
||||
# 建立数据库连接
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 构建WHERE子句和参数
|
||||
|
@ -68,7 +68,7 @@ def get_users_with_pagination(current_page, page_size, username='', email=''):
|
|||
def delete_user(user_id):
|
||||
"""删除指定ID的用户"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 删除 user 表中的用户记录
|
||||
|
@ -99,7 +99,7 @@ def delete_user(user_id):
|
|||
def create_user(user_data):
|
||||
"""创建新用户,并加入最早用户的团队,并使用相同的模型配置"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查用户表是否为空
|
||||
|
@ -269,7 +269,7 @@ def create_user(user_data):
|
|||
def update_user(user_id, user_data):
|
||||
"""更新用户信息"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**db_config)
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor()
|
||||
|
||||
query = """
|
||||
|
|
Loading…
Reference in New Issue