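# RAGflow/management/server/database.py
#
# (Residue from the web source viewer, kept as comments so the module parses:
# 269 lines · 9.4 KiB · Python · Raw | Normal View | History;
# blame timestamp 2025-03-28 22:45:42 +08:00.)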
import os
from datetime import datetime

import mysql.connector

from utils import generate_uuid, encrypt_password
# Database connection settings.
# Every value can be overridden through an environment variable; the defaults
# are the previously hard-coded values, so existing setups keep working.
# When this service runs inside a Docker container and MySQL runs on the host,
# set MYSQL_HOST=host.docker.internal.
db_config = {
    "host": os.getenv("MYSQL_HOST", "localhost"),
    # mysql.connector expects an int port; environment values are strings.
    "port": int(os.getenv("MYSQL_PORT", "5455")),
    "user": os.getenv("MYSQL_USER", "root"),
    "password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
    "database": os.getenv("MYSQL_DBNAME", "rag_flow"),
}
def get_users_with_pagination(current_page, page_size, username='', email=''):
    """Return one page of users, optionally filtered by nickname and email.

    Args:
        current_page: 1-based page number; values below 1 are treated as page 1.
        page_size: Maximum number of users per page (used as the SQL ``LIMIT``).
        username: Optional substring matched against ``user.nickname`` via LIKE.
        email: Optional substring matched against ``user.email`` via LIKE.

    Returns:
        ``(users, total)``:
        - ``users`` is a list of dicts with keys ``id``, ``username``,
          ``email``, ``createTime`` and ``updateTime``. Timestamps are
          ``"%Y-%m-%d %H:%M:%S"`` strings, or ``""`` when the column is empty.
        - ``total`` is the number of users matching the filters across all
          pages.
        On a database error ``([], 0)`` is returned.
    """
    # The WHERE clause is assembled only from fixed fragments; user input is
    # always passed as a bound parameter, never interpolated into the SQL text.
    where_clauses = []
    params = []
    if username:
        where_clauses.append("nickname LIKE %s")
        params.append(f"%{username}%")
    if email:
        where_clauses.append("email LIKE %s")
        params.append(f"%{email}%")
    where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"

    # A page number below 1 would yield a negative OFFSET, which MySQL rejects.
    offset = (max(current_page, 1) - 1) * page_size

    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor(dictionary=True)

        # Total number of matching rows, independent of pagination.
        cursor.execute(f"SELECT COUNT(*) as total FROM user WHERE {where_sql}", params)
        total = cursor.fetchone()['total']

        # NOTE(review): ids come from generate_uuid(); unless those are
        # time-ordered, "ORDER BY id DESC" is not newest-first — confirm intent.
        query = f"""
            SELECT id, nickname, email, create_date, update_date, status, is_superuser
            FROM user
            WHERE {where_sql}
            ORDER BY id DESC
            LIMIT %s OFFSET %s
        """
        cursor.execute(query, params + [page_size, offset])
        rows = cursor.fetchall()
    except mysql.connector.Error as err:
        print(f"数据库错误: {err}")
        return [], 0
    finally:
        # Release database resources on success and on error alike.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

    def _fmt(value):
        # Empty/NULL timestamps are rendered as an empty string.
        return value.strftime("%Y-%m-%d %H:%M:%S") if value else ""

    users = [
        {
            "id": row["id"],
            "username": row["nickname"],
            "email": row["email"],
            "createTime": _fmt(row["create_date"]),
            "updateTime": _fmt(row["update_date"]),
        }
        for row in rows
    ]
    return users, total
def delete_user(user_id):
    """Delete a user together with its tenant-related rows, in one transaction.

    Deletes:
    - the ``user`` row with ``id = user_id``;
    - the user's ``user_tenant`` memberships (``user_id = user_id``);
    - the ``tenant`` row whose id equals the user id;
    - that tenant's ``tenant_llm`` rows.

    Args:
        user_id: Id of the ``user`` row to delete.

    Returns:
        True if the deletions were committed. False on a database error; in
        that case the transaction is rolled back.
    """
    # Every statement takes the user id as its only parameter. The tenant rows
    # are matched on the user id because create_user inserts each user's own
    # tenant with id = user_id.
    # NOTE(review): user_tenant rows of OTHER users whose tenant_id equals this
    # user's id are not removed here — confirm whether they should be.
    statements = (
        "DELETE FROM user WHERE id = %s",
        "DELETE FROM user_tenant WHERE user_id = %s",
        "DELETE FROM tenant WHERE id = %s",
        "DELETE FROM tenant_llm WHERE tenant_id = %s",
    )

    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        for sql in statements:
            cursor.execute(sql, (user_id,))
        conn.commit()
        return True
    except mysql.connector.Error as err:
        if conn is not None:
            try:
                # Undo any deletes already executed in this transaction.
                conn.rollback()
            except mysql.connector.Error:
                pass  # best effort: the original error is reported below
        print(f"删除用户错误: {err}")
        return False
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
def create_user(user_data):
    """Create a user, join it to the earliest tenant's team and copy that tenant's models.

    The earliest-created tenant acts as a template. The following rows are
    inserted in one transaction:
    - a ``user`` row;
    - a personal ``tenant`` row with ``id = user id``. Its default model ids,
      ``parser_ids`` and ``credit`` are copied from the template tenant.
    - two ``user_tenant`` rows: ``owner`` of the personal tenant and
      ``normal`` member of the template tenant;
    - a copy of every ``tenant_llm`` row of the template tenant for the new
      tenant, with ``used_tokens`` reset to 0.

    Args:
        user_data: Mapping providing ``username``, ``email`` and ``password``.

    Returns:
        True on success. False when no tenant exists to use as a template, or
        on a database error (the transaction is then rolled back).
    """
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor(dictionary=True)

        template_tenant = _fetch_earliest_tenant(cursor)
        if template_tenant is None:
            # No tenant exists yet: there is no team to join and nothing to copy.
            print("创建用户错误: 未找到可复制配置的租户")
            return False
        template_llms = _fetch_tenant_llms(cursor, template_tenant["id"])

        user_id = generate_uuid()
        username = user_data.get("username")
        email = user_data.get("email")
        encrypted_password = encrypt_password(user_data.get("password"))

        now = datetime.now()
        create_time = int(now.timestamp() * 1000)  # milliseconds since the epoch
        current_date = now.strftime("%Y-%m-%d %H:%M:%S")

        # 1. The user row.
        cursor.execute(
            """
            INSERT INTO user (
                id, create_time, create_date, update_time, update_date, access_token,
                nickname, password, email, avatar, language, color_schema, timezone,
                last_login_time, is_authenticated, is_active, is_anonymous, login_channel,
                status, is_superuser
            ) VALUES (
                %s, %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s
            )
            """,
            (
                user_id, create_time, current_date, create_time, current_date, None,
                username, encrypted_password, email, None, "Chinese", "Bright", "UTC+8 Asia/Shanghai",
                current_date, 1, 1, 0, "password",
                1, 0,
            ),
        )

        # 2. The user's personal tenant, inheriting the template's defaults.
        #    Values are passed through unchanged so NULL columns stay NULL
        #    instead of being stored as the string "None".
        cursor.execute(
            """
            INSERT INTO tenant (
                id, create_time, create_date, update_time, update_date, name,
                public_key, llm_id, embd_id, asr_id, img2txt_id, rerank_id, tts_id,
                parser_ids, credit, status
            ) VALUES (
                %s, %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s, %s, %s,
                %s, %s, %s
            )
            """,
            (
                user_id, create_time, current_date, create_time, current_date, username + "'s Kingdom",
                None, template_tenant["llm_id"], template_tenant["embd_id"],
                template_tenant["asr_id"], template_tenant["img2txt_id"],
                template_tenant["rerank_id"], template_tenant["tts_id"],
                template_tenant["parser_ids"], template_tenant["credit"], 1,
            ),
        )

        # 3. Memberships: owner of the personal tenant, normal member of the template tenant.
        user_tenant_insert_query = """
            INSERT INTO user_tenant (
                id, create_time, create_date, update_time, update_date, user_id,
                tenant_id, role, invited_by, status
            ) VALUES (
                %s, %s, %s, %s, %s, %s,
                %s, %s, %s, %s
            )
        """
        memberships = (
            (user_id, "owner", user_id),
            (template_tenant["id"], "normal", template_tenant["id"]),
        )
        for tenant_id, role, invited_by in memberships:
            cursor.execute(
                user_tenant_insert_query,
                (
                    generate_uuid(), create_time, current_date, create_time, current_date, user_id,
                    tenant_id, role, invited_by, 1,
                ),
            )

        # 4. Copy every model configuration of the template tenant; usage starts at 0.
        tenant_llm_insert_query = """
            INSERT INTO tenant_llm (
                create_time, create_date, update_time, update_date, tenant_id,
                llm_factory, model_type, llm_name, api_key, api_base, max_tokens, used_tokens
            ) VALUES (
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s, %s, %s
            )
        """
        for llm in template_llms:
            cursor.execute(
                tenant_llm_insert_query,
                (
                    create_time, current_date, create_time, current_date, user_id,
                    llm["llm_factory"], llm["model_type"], llm["llm_name"],
                    llm["api_key"], llm["api_base"], llm["max_tokens"], 0,
                ),
            )

        conn.commit()
        return True
    except mysql.connector.Error as err:
        if conn is not None:
            try:
                # Discard the partially inserted user/tenant rows.
                conn.rollback()
            except mysql.connector.Error:
                pass  # best effort: the original error is reported below
        print(f"创建用户错误: {err}")
        return False
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


def _fetch_earliest_tenant(cursor):
    """Return the earliest-created ``tenant`` row as a dict, or None if there is none.

    ``cursor`` must be a dictionary cursor.
    """
    cursor.execute(
        """
        SELECT id, llm_id, embd_id, asr_id, img2txt_id, rerank_id, tts_id, parser_ids, credit
        FROM tenant
        WHERE create_time = (SELECT MIN(create_time) FROM tenant)
        LIMIT 1
        """
    )
    return cursor.fetchone()


def _fetch_tenant_llms(cursor, tenant_id):
    """Return all ``tenant_llm`` rows of ``tenant_id`` as a list of dicts (dictionary cursor)."""
    cursor.execute(
        """
        SELECT llm_factory, model_type, llm_name, api_key, api_base, max_tokens
        FROM tenant_llm
        WHERE tenant_id = %s
        """,
        (tenant_id,),
    )
    return cursor.fetchall()
def update_user(user_id, user_data):
    """Update a user's nickname.

    Args:
        user_id: Id of the ``user`` row to update.
        user_data: Mapping whose ``username`` value becomes the new
            ``nickname``. A missing key yields ``None``, which is sent to the
            database unchanged.

    Returns:
        True if the UPDATE was committed (also when no row matched
        ``user_id``), False on a database error.
    """
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute(
            "UPDATE user SET nickname = %s WHERE id = %s",
            (user_data.get("username"), user_id),
        )
        conn.commit()
        return True
    except mysql.connector.Error as err:
        print(f"更新用户错误: {err}")
        return False
    finally:
        # Release database resources on success and on error alike.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()