# RAGflow/management/server/services/knowledgebases/service.py
# (source-viewer metadata: 994 lines, 38 KiB, Python)

import mysql.connector
import json
import threading
import requests
import traceback
from datetime import datetime
from utils import generate_uuid
from database import DB_CONFIG
# 解析相关模块
from .document_parser import perform_parse, _update_document_progress
class KnowledgebaseService:
    """Service layer for knowledge bases, their documents and the system embedding config.

    Every public method opens its own MySQL connection through
    ``_get_db_connection`` and releases it (via ``_release``) before returning,
    including on error paths.
    """

    # Format of the ``*_date`` string columns written and returned by this service.
    _DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
    # Description shown to callers when a knowledge base has none.
    _EMPTY_DESCRIPTION = "暂无描述"
    # Embedding model id assigned to every newly created knowledge base.
    _DEFAULT_EMBD_ID = 'bge-m3:latest@Ollama'
    # Owner id used when no user row can be found.
    _FALLBACK_USER_ID = "system"
    # Selects a user whose create_time equals MIN(create_time). MIN() ignores NULLs, so this
    # can pick a different user than _get_earliest_user_tenant_id (ORDER BY ... ASC, where
    # MySQL sorts NULLs first) if some create_time values are NULL.
    _EARLIEST_USER_QUERY = """
        SELECT id FROM user
        WHERE create_time = (SELECT MIN(create_time) FROM user)
        LIMIT 1
    """

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _get_db_connection(cls):
        """Open and return a new MySQL connection built from ``DB_CONFIG``."""
        return mysql.connector.connect(**DB_CONFIG)

    @staticmethod
    def _release(cursor, conn):
        """Close ``cursor`` and then ``conn``; either may be ``None``.

        The connection is closed even if closing the cursor raises, and only
        while ``conn.is_connected()`` still reports it as open.
        """
        try:
            if cursor is not None:
                cursor.close()
        finally:
            if conn is not None and conn.is_connected():
                conn.close()

    @staticmethod
    def _in_placeholders(count):
        """Return ``count`` comma-separated ``%s`` placeholders for a SQL ``IN (...)`` list."""
        return ','.join(['%s'] * count)

    @staticmethod
    def _dedupe(values):
        """Return ``values`` as a list with duplicates removed, keeping first-seen order."""
        return list(dict.fromkeys(values))

    @classmethod
    def _normalize_create_date(cls, value):
        """Convert a ``create_date`` column value into the string returned to callers.

        ``datetime`` values are formatted with ``_DATE_FORMAT``; strings are kept
        only if they already parse with ``_DATE_FORMAT`` (otherwise ``""``);
        any other value (e.g. ``None``) is returned unchanged.
        """
        if isinstance(value, datetime):
            return value.strftime(cls._DATE_FORMAT)
        if isinstance(value, str):
            try:
                datetime.strptime(value, cls._DATE_FORMAT)
            except ValueError:
                return ""
        return value

    @classmethod
    def _normalize_kb_row(cls, row):
        """Fill in the default description and normalize ``create_date`` of a knowledge-base row in place.

        Returns the same ``row`` object for convenience.
        """
        if not row.get('description'):
            row['description'] = cls._EMPTY_DESCRIPTION
        if row.get('create_date'):
            row['create_date'] = cls._normalize_create_date(row['create_date'])
        return row

    @staticmethod
    def _default_parser_config():
        """Return the JSON-encoded default ``parser_config`` for new knowledge bases and documents."""
        return json.dumps({
            "layout_recognize": "DeepDOC",
            "chunk_token_num": 512,
            "delimiter": "\n!?;。;!?",
            "auto_keywords": 0,
            "auto_questions": 0,
            "html4excel": False,
            "raptor": {"use_raptor": False},
            "graphrag": {"use_graphrag": False},
        })

    @classmethod
    def _lookup_earliest_user_id(cls):
        """Return the id of the user selected by ``_EARLIEST_USER_QUERY``, or ``None`` if there is none.

        Database errors propagate to the caller.
        """
        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor(dictionary=True)
            cursor.execute(cls._EARLIEST_USER_QUERY)
            row = cursor.fetchone()
            return row['id'] if row else None
        finally:
            cls._release(cursor, conn)

    @staticmethod
    def _is_valid_embedding_response(payload):
        """Return True if ``payload`` looks like a non-empty embeddings response.

        Two shapes are accepted: an object ``{"data": [{"embedding": [...]}, ...]}``
        whose first embedding is non-empty, or a bare list of vectors
        ``[[...], ...]`` whose first vector is non-empty.
        """
        if isinstance(payload, dict):
            data = payload.get("data")
            if isinstance(data, list) and data:
                first = data[0]
                return isinstance(first, dict) and bool(first.get("embedding"))
            return False
        if isinstance(payload, list) and payload:
            first = payload[0]
            return isinstance(first, list) and len(first) > 0
        return False

    # ------------------------------------------------------------------
    # Knowledge bases
    # ------------------------------------------------------------------

    @classmethod
    def get_knowledgebase_list(cls, page=1, size=10, name=''):
        """Return one page of knowledge bases, optionally filtered by name.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            size: page size.
            name: substring matched against ``knowledgebase.name`` with ``LIKE``.

        Returns:
            ``{'list': [row, ...], 'total': int}``; ``total`` counts all matching
            rows, not just the returned page. Empty descriptions are replaced with
            a placeholder and ``create_date`` is rendered as a string.
        """
        where_clause = ""
        filter_params = []
        if name:
            where_clause = " WHERE k.name LIKE %s"
            filter_params.append(f"%{name}%")

        # NOTE(review): there is no ORDER BY, so MySQL does not guarantee a stable
        # row order across pages.
        list_query = """
            SELECT
                k.id,
                k.name,
                k.description,
                k.create_date,
                k.update_date,
                k.doc_num,
                k.language,
                k.permission
            FROM knowledgebase k
        """ + where_clause + " LIMIT %s OFFSET %s"
        count_query = "SELECT COUNT(*) AS total FROM knowledgebase k" + where_clause
        offset = max(page - 1, 0) * size

        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor(dictionary=True)
            cursor.execute(list_query, filter_params + [size, offset])
            rows = cursor.fetchall()
            cursor.execute(count_query, filter_params)
            total = cursor.fetchone()['total']
        finally:
            cls._release(cursor, conn)

        for row in rows:
            cls._normalize_kb_row(row)
        return {
            'list': rows,
            'total': total
        }

    @classmethod
    def get_knowledgebase_detail(cls, kb_id):
        """Return the detail row of one knowledge base, or ``None`` if it does not exist.

        The row has ``id, name, description, create_date, update_date, doc_num``;
        an empty description is replaced with a placeholder and ``create_date``
        is rendered as a string.
        """
        query = """
            SELECT
                k.id,
                k.name,
                k.description,
                k.create_date,
                k.update_date,
                k.doc_num
            FROM knowledgebase k
            WHERE k.id = %s
        """
        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor(dictionary=True)
            cursor.execute(query, (kb_id,))
            row = cursor.fetchone()
        finally:
            cls._release(cursor, conn)

        if row:
            cls._normalize_kb_row(row)
        return row

    @classmethod
    def _check_name_exists(cls, name):
        """Return True if a knowledge base with exactly this ``name`` already exists."""
        query = """
            SELECT COUNT(*) as count
            FROM knowledgebase
            WHERE name = %s
        """
        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor()
            cursor.execute(query, (name,))
            (count,) = cursor.fetchone()
        finally:
            cls._release(cursor, conn)
        return count > 0

    @classmethod
    def create_knowledgebase(cls, **data):
        """Create a knowledge base and return its detail row.

        Args:
            **data: ``name`` (required); optional ``language`` (default ``'Chinese'``),
                ``description`` (default ``''``) and ``permission`` (default ``'me'``).

        Returns:
            The row from ``get_knowledgebase_detail`` for the new knowledge base.

        Raises:
            Exception: message prefixed with ``"创建知识库失败: "`` for any failure,
                including a duplicate name.
        """
        try:
            if cls._check_name_exists(data['name']):
                raise Exception("知识库名称已存在")
            kb_id = cls._insert_knowledgebase(data, cls._resolve_kb_owner())
            return cls.get_knowledgebase_detail(kb_id)
        except Exception as e:
            print(f"创建知识库失败: {str(e)}")
            raise Exception(f"创建知识库失败: {str(e)}") from e

    @classmethod
    def _resolve_kb_owner(cls):
        """Return the id used as both ``tenant_id`` and ``created_by`` of a new knowledge base.

        Uses the earliest user; falls back to ``_FALLBACK_USER_ID`` when there is
        no user or the lookup itself fails (best-effort, never raises).
        """
        try:
            owner_id = cls._lookup_earliest_user_id()
        except Exception as e:
            print(f"获取用户ID失败: {str(e)},使用默认值")
            return cls._FALLBACK_USER_ID
        if owner_id is None:
            owner_id = cls._FALLBACK_USER_ID
            print(f"未找到用户, 使用默认值作为tenant_id和created_by: {owner_id}")
        else:
            print(f"使用创建时间最早的用户ID作为tenant_id和created_by: {owner_id}")
        return owner_id

    @classmethod
    def _insert_knowledgebase(cls, data, owner_id):
        """Insert a knowledge-base row built from ``data`` (see ``create_knowledgebase``) and return its new id."""
        now = datetime.now()
        date_str = now.strftime(cls._DATE_FORMAT)
        timestamp_ms = int(now.timestamp() * 1000)  # millisecond timestamp
        kb_id = generate_uuid()

        query = """
            INSERT INTO knowledgebase (
                id, create_time, create_date, update_time, update_date,
                avatar, tenant_id, name, language, description,
                embd_id, permission, created_by, doc_num, token_num,
                chunk_num, similarity_threshold, vector_similarity_weight, parser_id, parser_config,
                pagerank, status
            ) VALUES (
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s
            )
        """
        params = (
            kb_id,                              # id
            timestamp_ms,                       # create_time
            date_str,                           # create_date
            timestamp_ms,                       # update_time
            date_str,                           # update_date
            None,                               # avatar
            owner_id,                           # tenant_id
            data['name'],                       # name
            data.get('language', 'Chinese'),    # language
            data.get('description', ''),        # description
            cls._DEFAULT_EMBD_ID,               # embd_id
            data.get('permission', 'me'),       # permission
            owner_id,                           # created_by
            0,                                  # doc_num
            0,                                  # token_num
            0,                                  # chunk_num
            0.7,                                # similarity_threshold
            0.3,                                # vector_similarity_weight
            'naive',                            # parser_id
            cls._default_parser_config(),       # parser_config
            0,                                  # pagerank
            '1',                                # status
        )

        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor()
            cursor.execute(query, params)
            conn.commit()
        finally:
            cls._release(cursor, conn)
        return kb_id

    @classmethod
    def update_knowledgebase(cls, kb_id, **data):
        """Update the name and/or description of a knowledge base.

        Args:
            kb_id: id of the knowledge base.
            **data: optional ``name`` (applied only when truthy) and
                ``description`` (applied whenever the key is present).
                ``update_date`` is always refreshed.

        Returns:
            The updated detail row, or ``None`` if ``kb_id`` does not exist.

        Raises:
            Exception: message prefixed with ``"更新知识库失败: "``, including when
                the new name is already used by another knowledge base.
        """
        try:
            kb = cls.get_knowledgebase_detail(kb_id)
            if not kb:
                return None

            new_name = data.get('name')
            # Validate before opening the write connection so a name clash cannot leak it.
            if new_name and new_name != kb['name'] and cls._check_name_exists(new_name):
                raise Exception("知识库名称已存在")

            # Only fixed column names go into the SET list; values stay parameterized.
            assignments = []
            params = []
            if new_name:
                assignments.append("name = %s")
                params.append(new_name)
            if 'description' in data:
                assignments.append("description = %s")
                params.append(data['description'])
            assignments.append("update_date = %s")
            params.append(datetime.now().strftime(cls._DATE_FORMAT))
            params.append(kb_id)
            query = f"UPDATE knowledgebase SET {', '.join(assignments)} WHERE id = %s"

            conn = cursor = None
            try:
                conn = cls._get_db_connection()
                cursor = conn.cursor()
                cursor.execute(query, params)
                conn.commit()
            finally:
                cls._release(cursor, conn)

            return cls.get_knowledgebase_detail(kb_id)
        except Exception as e:
            print(f"更新知识库失败: {str(e)}")
            raise Exception(f"更新知识库失败: {str(e)}") from e

    @classmethod
    def delete_knowledgebase(cls, kb_id):
        """Delete one knowledge-base row.

        Only the ``knowledgebase`` row is deleted here; rows in other tables that
        reference ``kb_id`` are not touched by this method.

        Returns:
            True on success.

        Raises:
            Exception: message prefixed with ``"删除知识库失败: "``, including when
                the knowledge base does not exist.
        """
        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor()
            cursor.execute("SELECT id FROM knowledgebase WHERE id = %s", (kb_id,))
            if not cursor.fetchone():
                raise Exception("知识库不存在")
            cursor.execute("DELETE FROM knowledgebase WHERE id = %s", (kb_id,))
            conn.commit()
            return True
        except Exception as e:
            print(f"删除知识库失败: {str(e)}")
            raise Exception(f"删除知识库失败: {str(e)}") from e
        finally:
            cls._release(cursor, conn)

    @classmethod
    def batch_delete_knowledgebase(cls, kb_ids):
        """Delete several knowledge bases, all-or-nothing.

        Duplicate ids are collapsed; an empty input deletes nothing.

        Returns:
            The number of distinct knowledge bases deleted.

        Raises:
            Exception: message prefixed with ``"批量删除知识库失败: "``; if any id is
                missing, nothing is deleted and the missing ids are listed.
        """
        conn = cursor = None
        try:
            unique_ids = cls._dedupe(kb_ids)
            if not unique_ids:
                return 0
            placeholders = cls._in_placeholders(len(unique_ids))

            conn = cls._get_db_connection()
            cursor = conn.cursor()
            cursor.execute(f"SELECT id FROM knowledgebase WHERE id IN ({placeholders})", unique_ids)
            existing = {row[0] for row in cursor.fetchall()}
            missing = [kb_id for kb_id in unique_ids if kb_id not in existing]
            if missing:
                raise Exception(f"以下知识库不存在: {', '.join(map(str, missing))}")

            cursor.execute(f"DELETE FROM knowledgebase WHERE id IN ({placeholders})", unique_ids)
            conn.commit()
            return len(unique_ids)
        except Exception as e:
            print(f"批量删除知识库失败: {str(e)}")
            raise Exception(f"批量删除知识库失败: {str(e)}") from e
        finally:
            cls._release(cursor, conn)

    # ------------------------------------------------------------------
    # Documents
    # ------------------------------------------------------------------

    @classmethod
    def get_knowledgebase_documents(cls, kb_id, page=1, size=10, name=''):
        """Return one page of a knowledge base's documents, newest first.

        Args:
            kb_id: knowledge base id.
            page: 1-based page number; values below 1 are treated as 1.
            size: page size.
            name: optional substring matched against ``document.name`` with ``LIKE``.

        Returns:
            ``{'list': [row, ...], 'total': int}`` with ``create_date`` rendered as a string.

        Raises:
            Exception: message prefixed with ``"获取知识库文档列表失败: "``, including
                when the knowledge base does not exist.
        """
        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor(dictionary=True)

            cursor.execute("SELECT id FROM knowledgebase WHERE id = %s", (kb_id,))
            if not cursor.fetchone():
                raise Exception("知识库不存在")

            name_clause = ""
            filter_params = [kb_id]
            if name:
                name_clause = " AND d.name LIKE %s"
                filter_params.append(f"%{name}%")

            list_query = """
                SELECT
                    d.id,
                    d.name,
                    d.chunk_num,
                    d.create_date,
                    d.status,
                    d.run,
                    d.progress,
                    d.parser_id,
                    d.parser_config,
                    d.meta_fields
                FROM document d
                WHERE d.kb_id = %s
            """ + name_clause + " ORDER BY d.create_date DESC LIMIT %s OFFSET %s"
            cursor.execute(list_query, filter_params + [size, max(page - 1, 0) * size])
            rows = cursor.fetchall()
            for row in rows:
                if row.get('create_date'):
                    row['create_date'] = cls._normalize_create_date(row['create_date'])

            count_query = "SELECT COUNT(*) AS total FROM document d WHERE d.kb_id = %s" + name_clause
            cursor.execute(count_query, filter_params)
            total = cursor.fetchone()['total']

            return {
                'list': rows,
                'total': total
            }
        except Exception as e:
            print(f"获取知识库文档列表失败: {str(e)}")
            raise Exception(f"获取知识库文档列表失败: {str(e)}") from e
        finally:
            cls._release(cursor, conn)

    @classmethod
    def add_documents_to_knowledgebase(cls, kb_id, file_ids, created_by=None):
        """Create document rows (plus file2document mappings) for files in a knowledge base.

        Files already linked to a document in this knowledge base are skipped.
        Duplicate ``file_ids`` are collapsed; an empty list adds nothing.

        Args:
            kb_id: target knowledge base id.
            file_ids: ids of rows in the ``file`` table.
            created_by: creator id; defaults to the earliest user, or ``'system'``
                if there is none.

        Returns:
            ``{"added_count": int}``.

        Raises:
            Exception: message prefixed with ``"添加文档到知识库失败: "`` if the
                knowledge base or any file is missing, or on database errors.
        """
        conn = cursor = None
        try:
            print(f"[DEBUG] 开始添加文档,参数: kb_id={kb_id}, file_ids={file_ids}")
            if created_by is None:
                created_by = cls._lookup_earliest_user_id()
                if created_by is not None:
                    print(f"使用创建时间最早的用户ID: {created_by}")
                else:
                    created_by = cls._FALLBACK_USER_ID
                    print("未找到用户, 使用默认用户ID: system")

            kb = cls.get_knowledgebase_detail(kb_id)
            print(f"[DEBUG] 知识库检查结果: {kb}")
            if not kb:
                print(f"[ERROR] 知识库不存在: {kb_id}")
                raise Exception("知识库不存在")

            unique_file_ids = cls._dedupe(file_ids)
            if not unique_file_ids:
                # An empty IN () list would be a SQL syntax error.
                return {"added_count": 0}

            conn = cls._get_db_connection()
            cursor = conn.cursor()

            file_query = (
                "SELECT id, name, location, size, type FROM file "
                f"WHERE id IN ({cls._in_placeholders(len(unique_file_ids))})"
            )
            print(f"[DEBUG] 执行文件查询SQL: {file_query}")
            print(f"[DEBUG] 查询参数: {unique_file_ids}")
            try:
                cursor.execute(file_query, unique_file_ids)
                files = cursor.fetchall()
                print(f"[DEBUG] 查询到的文件数据: {files}")
            except Exception as e:
                print(f"[ERROR] 文件查询失败: {str(e)}")
                raise

            if len(files) != len(unique_file_ids):
                print(f"部分文件不存在: 期望={len(unique_file_ids)}, 实际={len(files)}")
                raise Exception("部分文件不存在")

            parser_config = cls._default_parser_config()
            added_count = sum(
                1 for file_row in files
                if cls._insert_document(cursor, kb_id, file_row, created_by, parser_config)
            )

            if added_count > 0:
                # Commit the new documents first so that a failure of the best-effort
                # counter update below cannot discard them.
                conn.commit()
                try:
                    cursor.execute(
                        """
                        UPDATE knowledgebase
                        SET doc_num = doc_num + %s,
                            update_date = %s
                        WHERE id = %s
                        """,
                        (added_count, datetime.now().strftime(cls._DATE_FORMAT), kb_id),
                    )
                    conn.commit()
                except Exception as e:
                    print(f"[WARNING] 更新知识库文档数量失败,但文档已添加: {str(e)}")

            return {
                "added_count": added_count
            }
        except Exception as e:
            print(f"[ERROR] 添加文档失败: {str(e)}")
            print(f"[ERROR] 错误类型: {type(e)}")
            print(f"[ERROR] 堆栈信息: {traceback.format_exc()}")
            raise Exception(f"添加文档到知识库失败: {str(e)}") from e
        finally:
            cls._release(cursor, conn)

    @classmethod
    def _insert_document(cls, cursor, kb_id, file_row, created_by, parser_config):
        """Insert a document and its file2document mapping for one ``file`` row.

        Args:
            cursor: open tuple cursor; the caller owns the transaction and commits.
            kb_id: target knowledge base id.
            file_row: ``(id, name, location, size, type)`` tuple from the ``file`` table.
            created_by: user id recorded as the document's creator.
            parser_config: JSON string stored as the document's ``parser_config``.

        Returns:
            True if rows were inserted, False if the file is already linked to a
            document in this knowledge base.
        """
        file_id, file_name, file_location, file_size, file_type = file_row
        print(f"处理文件: id={file_id}, name={file_name}")

        cursor.execute(
            """
            SELECT COUNT(*)
            FROM document d
            JOIN file2document f2d ON d.id = f2d.document_id
            WHERE d.kb_id = %s AND f2d.file_id = %s
            """,
            (kb_id, file_id),
        )
        if cursor.fetchone()[0] > 0:
            return False

        doc_id = generate_uuid()
        now = datetime.now()
        create_time = int(now.timestamp() * 1000)  # millisecond timestamp
        current_date = now.strftime(cls._DATE_FORMAT)

        # NOTE(review): `process_duation` presumably matches the table's column name
        # as spelled in the schema — verify before renaming.
        cursor.execute(
            """
            INSERT INTO document (
                id, create_time, create_date, update_time, update_date,
                thumbnail, kb_id, parser_id, parser_config, source_type,
                type, created_by, name, location, size,
                token_num, chunk_num, progress, progress_msg, process_begin_at,
                process_duation, meta_fields, run, status
            ) VALUES (
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s
            )
            """,
            [
                doc_id, create_time, current_date, create_time, current_date,  # id .. update_date
                None, kb_id, "naive", parser_config, "local",                   # thumbnail .. source_type
                file_type, created_by, file_name, file_location, file_size,     # type .. size
                0, 0, 0.0, None, None,                                          # token_num .. process_begin_at
                0.0, None, '0', '1',                                            # process_duation .. status
            ],
        )

        cursor.execute(
            """
            INSERT INTO file2document (
                id, create_time, create_date, update_time, update_date,
                file_id, document_id
            ) VALUES (
                %s, %s, %s, %s, %s,
                %s, %s
            )
            """,
            [generate_uuid(), create_time, current_date, create_time, current_date, file_id, doc_id],
        )
        return True

    @classmethod
    def delete_document(cls, doc_id):
        """Delete a document, its file2document mappings, and decrement its knowledge base's ``doc_num``.

        Returns:
            True on success.

        Raises:
            Exception: message prefixed with ``"删除文档失败: "``, including when the
                document does not exist.
        """
        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor()

            cursor.execute("SELECT kb_id FROM document WHERE id = %s", (doc_id,))
            row = cursor.fetchone()
            if not row:
                raise Exception("文档不存在")
            kb_id = row[0]

            cursor.execute("DELETE FROM file2document WHERE document_id = %s", (doc_id,))
            cursor.execute("DELETE FROM document WHERE id = %s", (doc_id,))
            cursor.execute(
                """
                UPDATE knowledgebase
                SET doc_num = doc_num - 1,
                    update_date = %s
                WHERE id = %s
                """,
                (datetime.now().strftime(cls._DATE_FORMAT), kb_id),
            )
            conn.commit()
            return True
        except Exception as e:
            print(f"[ERROR] 删除文档失败: {str(e)}")
            raise Exception(f"删除文档失败: {str(e)}") from e
        finally:
            cls._release(cursor, conn)

    # ------------------------------------------------------------------
    # Parsing
    # ------------------------------------------------------------------

    @classmethod
    def _load_parse_inputs(cls, doc_id):
        """Fetch the document row and its source file's ``parent_id`` row for parsing.

        Returns:
            ``(doc_info, file_info)`` dictionaries.

        Raises:
            Exception: if the document, its file2document mapping, or the file row is missing.
        """
        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor(dictionary=True)

            cursor.execute(
                """
                SELECT d.id, d.name, d.location, d.type, d.kb_id, d.parser_id, d.parser_config, d.created_by
                FROM document d
                WHERE d.id = %s
                """,
                (doc_id,),
            )
            doc_info = cursor.fetchone()
            if not doc_info:
                raise Exception("文档不存在")

            # Only the first mapping is used; LIMIT 1 also avoids leaving unread rows on
            # the cursor, which can make the next execute() fail with an unbuffered cursor.
            cursor.execute("SELECT file_id FROM file2document WHERE document_id = %s LIMIT 1", (doc_id,))
            mapping = cursor.fetchone()
            if not mapping:
                raise Exception("无法找到文件到文档的映射关系")

            # The file's parent_id is what perform_parse needs (used as the bucket name).
            cursor.execute("SELECT parent_id FROM file WHERE id = %s", (mapping['file_id'],))
            file_info = cursor.fetchone()
            if not file_info:
                raise Exception("无法找到文件记录")

            return doc_info, file_info
        finally:
            cls._release(cursor, conn)

    @classmethod
    def parse_document(cls, doc_id):
        """Parse a document synchronously.

        Loads the document/file rows (the connection is released before parsing),
        marks the document as in progress, then calls ``perform_parse`` with the
        system embedding config.

        Returns:
            Whatever ``perform_parse`` returns, or
            ``{"success": False, "error": "文档解析失败: ..."}`` on any failure
            (after a best-effort attempt to record the failure on the document).
        """
        try:
            doc_info, file_info = cls._load_parse_inputs(doc_id)
            _update_document_progress(doc_id, status='2', run='1', progress=0.0, message='开始解析')
            embedding_config = cls.get_system_embedding_config()
            return perform_parse(doc_id, doc_info, file_info, embedding_config)
        except Exception as e:
            print(f"文档解析启动或执行过程中出错 (Doc ID: {doc_id}): {str(e)}")
            try:
                _update_document_progress(doc_id, status='1', run='0', message=f"解析失败: {str(e)}")
            except Exception as update_err:
                print(f"更新文档失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}")
            return {"success": False, "error": f"文档解析失败: {str(e)}"}

    @classmethod
    def async_parse_document(cls, doc_id):
        """Run ``parse_document`` on a daemon background thread and return immediately.

        Returns:
            A submission receipt using ``doc_id`` as the task id.

        Raises:
            Exception: message prefixed with ``"启动异步解析任务失败: "`` if the thread
                cannot be started (after a best-effort attempt to mark the document failed).
        """
        try:
            # Daemon thread: it does not keep the interpreter alive on shutdown.
            worker = threading.Thread(target=cls.parse_document, args=(doc_id,), daemon=True)
            worker.start()
            return {
                "task_id": doc_id,
                "status": "processing",
                "message": "文档解析任务已提交到后台处理"
            }
        except Exception as e:
            print(f"启动异步解析任务失败 (Doc ID: {doc_id}): {str(e)}")
            try:
                _update_document_progress(doc_id, status='1', run='0', message=f"启动解析失败: {str(e)}")
            except Exception as update_err:
                print(f"更新文档启动失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}")
            raise Exception(f"启动异步解析任务失败: {str(e)}") from e

    @classmethod
    def get_document_parse_progress(cls, doc_id):
        """Return the parse progress of a document.

        Returns:
            ``{"progress": float, "message": str, "status": ..., "running": ...}``;
            ``progress`` falls back to 0.0 when missing or non-numeric and
            ``message`` to ``""``. On a missing document or any error, a dict
            with a single ``"error"`` key is returned instead (never raises).
        """
        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor(dictionary=True)
            cursor.execute(
                """
                SELECT progress, progress_msg, status, run
                FROM document
                WHERE id = %s
                """,
                (doc_id,),
            )
            row = cursor.fetchone()
            if not row:
                return {"error": "文档不存在"}

            progress_value = 0.0
            if row.get("progress") is not None:
                try:
                    progress_value = float(row["progress"])
                except (ValueError, TypeError):
                    progress_value = 0.0

            return {
                "progress": progress_value,
                "message": row.get("progress_msg") or "",
                "status": row.get("status", "0"),
                "running": row.get("run", "0"),
            }
        except Exception as e:
            print(f"获取文档进度失败 (Doc ID: {doc_id}): {str(e)}")
            return {"error": f"获取进度失败: {str(e)}"}
        finally:
            cls._release(cursor, conn)

    # ------------------------------------------------------------------
    # System embedding configuration
    # ------------------------------------------------------------------

    @classmethod
    def _get_earliest_user_tenant_id(cls):
        """Return the id of the first user by ``create_time`` ascending (used as tenant_id).

        Returns ``None`` (and logs) when there are no users or the query fails;
        never raises.
        """
        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor()
            cursor.execute("SELECT id FROM user ORDER BY create_time ASC LIMIT 1")
            row = cursor.fetchone()
            if row:
                return row[0]
            print("警告: 数据库中没有用户!")
            return None
        except Exception as e:
            print(f"查询最早用户时出错: {e}")
            traceback.print_exc()
            return None
        finally:
            cls._release(cursor, conn)

    @classmethod
    def _test_embedding_connection(cls, base_url, model_name, api_key):
        """Send one test embedding request to ``<base_url>/embeddings``.

        ``base_url`` gets an ``http://`` scheme if it has none and a trailing
        slash if it lacks one. ``api_key``, when given, is sent as a Bearer token.

        Returns:
            ``(True, "连接成功")`` if the endpoint answers HTTP 200 with a body
            accepted by ``_is_valid_embedding_response``;
            ``(False, "连接失败: 响应错误")`` for a request failure, non-200 status,
            non-JSON body or unexpected shape; ``(False, "测试时发生未知错误: ...")``
            for any other error. Never raises.
        """
        print(f"开始测试连接: base_url={base_url}, model_name={model_name}")
        failure = (False, "连接失败: 响应错误")
        try:
            headers = {"Content-Type": "application/json"}
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"
            payload = {"input": ["Test connection"], "model": model_name}

            if not base_url.startswith(('http://', 'https://')):
                base_url = 'http://' + base_url
            if not base_url.endswith('/'):
                base_url += '/'
            current_test_url = base_url + "embeddings"
            print(f"尝试请求 URL: {current_test_url}")

            try:
                response = requests.post(current_test_url, headers=headers, json=payload, timeout=15)
            except requests.RequestException as req_e:
                print(f"请求 {current_test_url} 失败: {req_e}")
                return failure
            print(f"请求 {current_test_url} 返回状态码: {response.status_code}")
            if response.status_code != 200:
                return failure

            try:
                res_json = response.json()
            except ValueError as json_e:
                print(f"解析 JSON 响应失败于 {current_test_url}: {json_e}")
                return failure

            if cls._is_valid_embedding_response(res_json):
                print(f"连接测试成功: {current_test_url}")
                return True, "连接成功"
            print(f"连接成功但响应格式不正确于 {current_test_url}")
            return failure
        except Exception as e:
            print(f"连接测试发生未知错误: {str(e)}")
            traceback.print_exc()
            return False, f"测试时发生未知错误: {str(e)}"

    @classmethod
    def get_system_embedding_config(cls):
        """Return the system-level (earliest user's) model config from ``tenant_llm``.

        Rows with ``model_type = 'embedding'`` are preferred; if the tenant has
        none, any of its rows is used. A ``"___<suffix>"`` part of ``llm_name``
        is stripped.

        Returns:
            ``{"llm_name": str, "api_key": str, "api_base": str}``; all values are
            ``""`` when no row exists or a column is NULL.

        Raises:
            Exception: if no user exists, or (prefixed ``"获取配置时数据库出错: "``)
                on database errors.
        """
        tenant_id = cls._get_earliest_user_tenant_id()
        if not tenant_id:
            raise Exception("无法找到系统基础用户")

        query = """
            SELECT llm_name, api_key, api_base
            FROM tenant_llm
            WHERE tenant_id = %s
            ORDER BY (model_type = 'embedding') DESC
            LIMIT 1
        """
        conn = cursor = None
        try:
            conn = cls._get_db_connection()
            cursor = conn.cursor(dictionary=True)
            cursor.execute(query, (tenant_id,))
            config = cursor.fetchone()
        except Exception as e:
            print(f"获取系统 Embedding 配置时出错: {e}")
            traceback.print_exc()
            raise Exception(f"获取配置时数据库出错: {e}") from e
        finally:
            cls._release(cursor, conn)

        if not config:
            return {
                "llm_name": "",
                "api_key": "",
                "api_base": ""
            }

        llm_name = config.get("llm_name") or ""
        if '___' in llm_name:
            llm_name = llm_name.split('___')[0]
        return {
            "llm_name": llm_name,
            "api_key": config.get("api_key") or "",
            "api_base": config.get("api_base") or ""
        }

    @classmethod
    def set_system_embedding_config(cls, llm_name, api_base, api_key):
        """Validate an embedding endpoint for the system (earliest) user.

        Only the connection test runs; nothing is written to ``tenant_llm``.

        Returns:
            ``(True, "连接成功: <detail>")`` if the test succeeds, otherwise
            ``(False, "连接测试失败: <detail>")``.

        Raises:
            Exception: if no user exists to act as the system tenant.
        """
        if not cls._get_earliest_user_tenant_id():
            raise Exception("无法找到系统基础用户")

        is_connected, message = cls._test_embedding_connection(
            base_url=api_base,
            model_name=llm_name,
            api_key=api_key,
        )
        if not is_connected:
            return False, f"连接测试失败: {message}"
        # TODO: persistence is intentionally disabled ("to avoid conflicts"). The removed
        # draft upserted a tenant_llm row (tenant_id, llm_factory="VLLM", model_type="embedding",
        # llm_name, api_key, api_base, ...) and referenced an undefined `max_tokens`;
        # recover it from version control before re-enabling.
        return True, f"连接成功: {message}"