# RAGflow/management/server/services/knowledgebases/service.py
#
# NOTE: this file intentionally contains non-ASCII (Chinese) text in
# runtime log/error messages.

import json
import threading
import time
import traceback
from datetime import datetime
import mysql.connector
import requests
from database import DB_CONFIG, get_es_client
from utils import generate_uuid
# 解析相关模块
from .document_parser import _update_document_progress, perform_parse
# In-memory registry of in-progress sequential batch parse tasks, keyed by kb_id.
# Shape: { kb_id: {"status": "running/completed/failed", "total": N, "current": M, "message": "...", "start_time": timestamp} }
# NOTE(review): plain dict mutated from background threads — presumably safe
# because CPython dict ops are atomic, but confirm if moving to multi-process.
SEQUENTIAL_BATCH_TASKS = {}
class KnowledgebaseService:
@classmethod
def _get_db_connection(cls):
    """Open and return a fresh MySQL connection built from DB_CONFIG."""
    connection = mysql.connector.connect(**DB_CONFIG)
    return connection
@classmethod
def get_knowledgebase_list(cls, page=1, size=10, name="", sort_by="create_time", sort_order="desc"):
    """Return a paginated, optionally filtered and sorted knowledgebase list.

    Args:
        page: 1-based page index.
        size: rows per page.
        name: optional substring filter on the knowledgebase name.
        sort_by: sort column; one of "name", "create_time", "create_date".
        sort_order: "asc" or "desc"; any other value falls back to "desc".

    Returns:
        {"list": [...rows...], "total": <count matching the name filter>}
    """
    conn = cls._get_db_connection()
    cursor = conn.cursor(dictionary=True)
    try:
        # Both the sort column and the direction are interpolated into the
        # SQL text, so whitelist them. The direction was previously taken
        # verbatim from the caller, which allowed SQL injection.
        valid_sort_fields = ["name", "create_time", "create_date"]
        if sort_by not in valid_sort_fields:
            sort_by = "create_time"
        direction = "ASC" if str(sort_order).lower() == "asc" else "DESC"
        sort_clause = f"ORDER BY k.{sort_by} {direction}"
        query = """
            SELECT
                k.id,
                k.name,
                k.description,
                k.create_date,
                k.update_date,
                k.doc_num,
                k.language,
                k.permission
            FROM knowledgebase k
        """
        params = []
        if name:
            query += " WHERE k.name LIKE %s"
            params.append(f"%{name}%")
        query += f" {sort_clause}"
        query += " LIMIT %s OFFSET %s"
        params.extend([size, (page - 1) * size])
        cursor.execute(query, params)
        results = cursor.fetchall()
        for result in results:
            # Present a placeholder instead of an empty description.
            if not result.get("description"):
                result["description"] = "暂无描述"
            # Normalize create_date to "YYYY-MM-DD HH:MM:SS"; strings that
            # do not already match that format are blanked out.
            if result.get("create_date"):
                if isinstance(result["create_date"], datetime):
                    result["create_date"] = result["create_date"].strftime("%Y-%m-%d %H:%M:%S")
                elif isinstance(result["create_date"], str):
                    try:
                        datetime.strptime(result["create_date"], "%Y-%m-%d %H:%M:%S")
                    except ValueError:
                        result["create_date"] = ""
        # Total honours the same name filter but ignores pagination. Build
        # its params explicitly instead of slicing the main params list.
        count_query = "SELECT COUNT(*) as total FROM knowledgebase"
        count_params = []
        if name:
            count_query += " WHERE name LIKE %s"
            count_params.append(f"%{name}%")
        cursor.execute(count_query, count_params)
        total = cursor.fetchone()["total"]
        return {"list": results, "total": total}
    finally:
        # Release DB resources even when a query raises (the original
        # leaked the connection on any exception).
        cursor.close()
        conn.close()
@classmethod
def get_knowledgebase_detail(cls, kb_id):
    """Fetch one knowledgebase row by id; returns None when it does not exist."""
    conn = cls._get_db_connection()
    cursor = conn.cursor(dictionary=True)
    cursor.execute(
        """
        SELECT
            k.id,
            k.name,
            k.description,
            k.create_date,
            k.update_date,
            k.doc_num,
            k.avatar
        FROM knowledgebase k
        WHERE k.id = %s
        """,
        (kb_id,),
    )
    row = cursor.fetchone()
    if row:
        # Substitute a placeholder when the description is empty.
        if not row.get("description"):
            row["description"] = "暂无描述"
        # Normalize create_date to a "YYYY-MM-DD HH:MM:SS" string.
        created = row.get("create_date")
        if created:
            if isinstance(created, datetime):
                row["create_date"] = created.strftime("%Y-%m-%d %H:%M:%S")
            elif isinstance(created, str):
                try:
                    datetime.strptime(created, "%Y-%m-%d %H:%M:%S")
                except ValueError:
                    row["create_date"] = ""
    cursor.close()
    conn.close()
    return row
@classmethod
def _check_name_exists(cls, name):
    """Return True when a knowledgebase with exactly this name already exists."""
    conn = cls._get_db_connection()
    cursor = conn.cursor()
    cursor.execute(
        """
        SELECT COUNT(*) as count
        FROM knowledgebase
        WHERE name = %s
        """,
        (name,),
    )
    row = cursor.fetchone()
    cursor.close()
    conn.close()
    return row[0] > 0
@classmethod
def create_knowledgebase(cls, **data):
    """Create a knowledgebase row and return its detail dict.

    Required key: "name". Optional keys: "creator_id" (used as both
    tenant_id and created_by), "language", "description", "permission".

    Raises:
        Exception: on duplicate name or any DB failure (message is wrapped).
    """
    conn = None
    cursor = None
    try:
        # Reject duplicate names before touching anything else.
        exists = cls._check_name_exists(data["name"])
        if exists:
            raise Exception("知识库名称已存在")
        conn = cls._get_db_connection()
        cursor = conn.cursor(dictionary=True)
        # The creator doubles as tenant_id and created_by.
        tenant_id = data.get("creator_id")
        created_by = data.get("creator_id")
        if not tenant_id:
            # No creator supplied: fall back to the earliest-registered
            # user, then to the literal "system" sentinel.
            print("未提供 creator_id尝试获取最早用户 ID")
            try:
                query_earliest_user = """
                    SELECT id FROM user
                    WHERE create_time = (SELECT MIN(create_time) FROM user)
                    LIMIT 1
                """
                cursor.execute(query_earliest_user)
                earliest_user = cursor.fetchone()
                if earliest_user:
                    tenant_id = earliest_user["id"]
                    created_by = earliest_user["id"]
                    print(f"使用创建时间最早的用户ID作为tenant_id和created_by: {tenant_id}")
                else:
                    tenant_id = "system"
                    created_by = "system"
                    print(f"未找到用户, 使用默认值作为tenant_id和created_by: {tenant_id}")
            except Exception as e:
                print(f"获取用户ID失败: {str(e)},使用默认值")
                tenant_id = "system"
                created_by = "system"
        else:
            print(f"使用传入的 creator_id 作为 tenant_id 和 created_by: {tenant_id}")
        # Resolve the embedding model id from the most recent embedding
        # entry in tenant_llm, falling back to "bge-m3".
        dynamic_embd_id = None
        default_embd_id = "bge-m3"  # fallback default
        try:
            query_embedding_model = """
                SELECT llm_name
                FROM tenant_llm
                WHERE model_type = 'embedding'
                ORDER BY create_time DESC
                LIMIT 1
            """
            cursor.execute(query_embedding_model)
            embedding_model = cursor.fetchone()
            if embedding_model and embedding_model.get("llm_name"):
                dynamic_embd_id = embedding_model["llm_name"]
                # Special-case the SiliconFlow platform's model name.
                if dynamic_embd_id == "netease-youdao/bce-embedding-base_v1":
                    dynamic_embd_id = "BAAI/bge-m3"
                print(f"动态获取到的 embedding 模型 ID: {dynamic_embd_id}")
            else:
                dynamic_embd_id = default_embd_id
                print(f"未在 tenant_llm 表中找到 embedding 模型, 使用默认值: {dynamic_embd_id}")
        except Exception as e:
            dynamic_embd_id = default_embd_id
            print(f"查询 embedding 模型失败: {str(e)},使用默认值: {dynamic_embd_id}")
            traceback.print_exc()  # log the full traceback for debugging
        current_time = datetime.now()
        create_date = current_time.strftime("%Y-%m-%d %H:%M:%S")
        create_time = int(current_time.timestamp() * 1000)  # millisecond timestamp
        update_date = create_date
        update_time = create_time
        query = """
            INSERT INTO knowledgebase (
                id, create_time, create_date, update_time, update_date,
                avatar, tenant_id, name, language, description,
                embd_id, permission, created_by, doc_num, token_num,
                chunk_num, similarity_threshold, vector_similarity_weight, parser_id, parser_config,
                pagerank, status
            ) VALUES (
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s
            )
        """
        # Default parser configuration for new knowledgebases.
        default_parser_config = json.dumps(
            {
                "layout_recognize": "MinerU",
                "chunk_token_num": 512,
                "delimiter": "\n!?;。;!?",
                "auto_keywords": 0,
                "auto_questions": 0,
                "html4excel": False,
                "raptor": {"use_raptor": False},
                "graphrag": {"use_graphrag": False},
            }
        )
        kb_id = generate_uuid()
        cursor.execute(
            query,
            (
                kb_id,  # id
                create_time,  # create_time
                create_date,  # create_date
                update_time,  # update_time
                update_date,  # update_date
                None,  # avatar
                tenant_id,  # tenant_id
                data["name"],  # name
                data.get("language", "Chinese"),  # language
                data.get("description", ""),  # description
                dynamic_embd_id,  # embd_id
                data.get("permission", "me"),  # permission
                created_by,  # created_by - resolved above
                0,  # doc_num
                0,  # token_num
                0,  # chunk_num
                0.7,  # similarity_threshold
                0.3,  # vector_similarity_weight
                "naive",  # parser_id
                default_parser_config,  # parser_config
                0,  # pagerank
                "1",  # status
            ),
        )
        conn.commit()
        # Return the freshly created row (opens its own connection).
        return cls.get_knowledgebase_detail(kb_id)
    except Exception as e:
        print(f"创建知识库失败: {str(e)}")
        raise Exception(f"创建知识库失败: {str(e)}")
    finally:
        # Release DB resources on success and failure alike (the original
        # leaked the connection when the INSERT raised).
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@classmethod
def update_knowledgebase(cls, kb_id, **data):
    """Update name/description/permission/avatar of a knowledgebase.

    Returns the refreshed detail dict, or None when kb_id does not exist.
    Raises when a rename collides with an existing name or on DB errors.
    """
    conn = None
    cursor = None
    try:
        kb = cls.get_knowledgebase_detail(kb_id)
        if not kb:
            return None
        # A rename must not collide with another knowledgebase. Checked
        # before opening our own connection so a failure here leaks nothing.
        if data.get("name") and data["name"] != kb["name"]:
            exists = cls._check_name_exists(data["name"])
            if exists:
                raise Exception("知识库名称已存在")
        conn = cls._get_db_connection()
        cursor = conn.cursor()
        update_fields = []
        params = []
        if data.get("name"):
            update_fields.append("name = %s")
            params.append(data["name"])
        if "description" in data:
            update_fields.append("description = %s")
            params.append(data["description"])
        if "permission" in data:
            update_fields.append("permission = %s")
            params.append(data["permission"])
        if "avatar" in data and data["avatar"]:
            # The caller sends raw base64; store it as a data URL.
            avatar_base64 = data["avatar"]
            full_avatar_url = f"data:image/png;base64,{avatar_base64}"
            update_fields.append("avatar = %s")
            params.append(full_avatar_url)
        # update_date is always bumped, so update_fields is never empty
        # (the original had an unreachable early-return for that case,
        # which also leaked the connection).
        current_time = datetime.now()
        update_date = current_time.strftime("%Y-%m-%d %H:%M:%S")
        update_fields.append("update_date = %s")
        params.append(update_date)
        query = f"""
            UPDATE knowledgebase
            SET {", ".join(update_fields)}
            WHERE id = %s
        """
        params.append(kb_id)
        cursor.execute(query, params)
        conn.commit()
        return cls.get_knowledgebase_detail(kb_id)
    except Exception as e:
        print(f"更新知识库失败: {str(e)}")
        raise Exception(f"更新知识库失败: {str(e)}")
    finally:
        # Always release the connection (the original leaked it on raise).
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@classmethod
def delete_knowledgebase(cls, kb_id):
    """Delete one knowledgebase row; True on success, raises otherwise."""
    conn = None
    cursor = None
    try:
        conn = cls._get_db_connection()
        cursor = conn.cursor()
        # Verify existence first so the caller gets a clear error.
        cursor.execute("SELECT id FROM knowledgebase WHERE id = %s", (kb_id,))
        if not cursor.fetchone():
            raise Exception("知识库不存在")
        cursor.execute("DELETE FROM knowledgebase WHERE id = %s", (kb_id,))
        conn.commit()
        return True
    except Exception as e:
        print(f"删除知识库失败: {str(e)}")
        raise Exception(f"删除知识库失败: {str(e)}")
    finally:
        # The original leaked the connection when the existence check raised.
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@classmethod
def batch_delete_knowledgebase(cls, kb_ids):
    """Delete several knowledgebases by id; returns the number deleted.

    An empty id list is a no-op returning 0 (the original built an
    "IN ()" clause and crashed with a SQL syntax error). Raises when any
    requested id does not exist.
    """
    if not kb_ids:
        return 0
    conn = None
    cursor = None
    try:
        conn = cls._get_db_connection()
        cursor = conn.cursor()
        placeholders = ",".join(["%s"] * len(kb_ids))
        # All-or-nothing: verify every id exists before deleting anything.
        cursor.execute(f"SELECT id FROM knowledgebase WHERE id IN ({placeholders})", kb_ids)
        existing_ids = [row[0] for row in cursor.fetchall()]
        if len(existing_ids) != len(kb_ids):
            missing_ids = set(kb_ids) - set(existing_ids)
            raise Exception(f"以下知识库不存在: {', '.join(missing_ids)}")
        cursor.execute(f"DELETE FROM knowledgebase WHERE id IN ({placeholders})", kb_ids)
        conn.commit()
        return len(kb_ids)
    except Exception as e:
        print(f"批量删除知识库失败: {str(e)}")
        raise Exception(f"批量删除知识库失败: {str(e)}")
    finally:
        # Always release the connection (the original leaked it on raise).
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@classmethod
def get_knowledgebase_documents(cls, kb_id, page=1, size=10, name="", sort_by="create_time", sort_order="desc"):
    """Return a paginated document listing for one knowledgebase.

    Args:
        kb_id: knowledgebase id (must exist, else raises).
        page: 1-based page index.
        size: rows per page.
        name: optional substring filter on document name.
        sort_by: one of "name", "size", "create_time", "create_date".
        sort_order: "asc" or "desc"; anything else falls back to "desc".

    Returns:
        {"list": [...rows...], "total": <count matching the filters>}
    """
    conn = None
    cursor = None
    try:
        conn = cls._get_db_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT id FROM knowledgebase WHERE id = %s", (kb_id,))
        if not cursor.fetchone():
            raise Exception("知识库不存在")
        # Sort column and direction are interpolated into the SQL text, so
        # whitelist both (the direction was previously injectable).
        valid_sort_fields = ["name", "size", "create_time", "create_date"]
        if sort_by not in valid_sort_fields:
            sort_by = "create_time"
        direction = "ASC" if str(sort_order).lower() == "asc" else "DESC"
        sort_clause = f"ORDER BY d.{sort_by} {direction}"
        query = """
            SELECT
                d.id,
                d.name,
                d.chunk_num,
                d.create_date,
                d.status,
                d.run,
                d.progress,
                d.parser_id,
                d.parser_config,
                d.meta_fields
            FROM document d
            WHERE d.kb_id = %s
        """
        params = [kb_id]
        if name:
            query += " AND d.name LIKE %s"
            params.append(f"%{name}%")
        query += f" {sort_clause}"
        query += " LIMIT %s OFFSET %s"
        params.extend([size, (page - 1) * size])
        cursor.execute(query, params)
        results = cursor.fetchall()
        # Render create_date as a string for the API response.
        for result in results:
            if result.get("create_date"):
                result["create_date"] = result["create_date"].strftime("%Y-%m-%d %H:%M:%S")
        count_query = "SELECT COUNT(*) as total FROM document WHERE kb_id = %s"
        count_params = [kb_id]
        if name:
            count_query += " AND name LIKE %s"
            count_params.append(f"%{name}%")
        cursor.execute(count_query, count_params)
        total = cursor.fetchone()["total"]
        return {"list": results, "total": total}
    except Exception as e:
        print(f"获取知识库文档列表失败: {str(e)}")
        raise Exception(f"获取知识库文档列表失败: {str(e)}")
    finally:
        # Always release the connection (the original leaked it on raise).
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@classmethod
def add_documents_to_knowledgebase(cls, kb_id, file_ids, created_by=None):
    """Link uploaded files to a knowledgebase as new documents.

    For each file id not already attached to the kb, inserts a document
    row plus a file2document mapping, then bumps the kb's doc_num.

    Args:
        kb_id: target knowledgebase id (must exist).
        file_ids: ids of rows in the `file` table (all must exist).
        created_by: owner of the new documents; when None, defaults to the
            earliest-registered user (or "system" if no users exist).

    Returns:
        {"added_count": <number of newly linked documents>}
    """
    conn = None
    cursor = None
    try:
        print(f"[DEBUG] 开始添加文档,参数: kb_id={kb_id}, file_ids={file_ids}")
        # Resolve a default creator when none was supplied.
        if created_by is None:
            conn = cls._get_db_connection()
            cursor = conn.cursor(dictionary=True)
            query_earliest_user = """
                SELECT id FROM user
                WHERE create_time = (SELECT MIN(create_time) FROM user)
                LIMIT 1
            """
            cursor.execute(query_earliest_user)
            earliest_user = cursor.fetchone()
            if earliest_user:
                created_by = earliest_user["id"]
                print(f"使用创建时间最早的用户ID: {created_by}")
            else:
                created_by = "system"
                print("未找到用户, 使用默认用户ID: system")
            cursor.close()
            conn.close()
            cursor = None
            conn = None
        kb = cls.get_knowledgebase_detail(kb_id)
        print(f"[DEBUG] 知识库检查结果: {kb}")
        if not kb:
            print(f"[ERROR] 知识库不存在: {kb_id}")
            raise Exception("知识库不存在")
        conn = cls._get_db_connection()
        cursor = conn.cursor()
        # Fetch metadata for all requested files in one query.
        file_query = """
            SELECT id, name, location, size, type
            FROM file
            WHERE id IN (%s)
        """ % ",".join(["%s"] * len(file_ids))
        print(f"[DEBUG] 执行文件查询SQL: {file_query}")
        print(f"[DEBUG] 查询参数: {file_ids}")
        try:
            cursor.execute(file_query, file_ids)
            files = cursor.fetchall()
            print(f"[DEBUG] 查询到的文件数据: {files}")
        except Exception as e:
            print(f"[ERROR] 文件查询失败: {str(e)}")
            raise
        if len(files) != len(file_ids):
            print(f"部分文件不存在: 期望={len(file_ids)}, 实际={len(files)}")
            raise Exception("部分文件不存在")
        added_count = 0
        for file in files:
            file_id, file_name, file_location, file_size, file_type = file
            print(f"处理文件: id={file_id}, name={file_name}")
            # Skip files already linked to this knowledgebase.
            check_query = """
                SELECT COUNT(*)
                FROM document d
                JOIN file2document f2d ON d.id = f2d.document_id
                WHERE d.kb_id = %s AND f2d.file_id = %s
            """
            cursor.execute(check_query, (kb_id, file_id))
            exists = cursor.fetchone()[0] > 0
            if exists:
                continue
            doc_id = generate_uuid()
            current_datetime = datetime.now()
            create_time = int(current_datetime.timestamp() * 1000)  # ms timestamp
            current_date = current_datetime.strftime("%Y-%m-%d %H:%M:%S")
            # Defaults for a freshly linked, not-yet-parsed document.
            default_parser_id = "naive"
            default_parser_config = json.dumps(
                {
                    "layout_recognize": "MinerU",
                    "chunk_token_num": 512,
                    "delimiter": "\n!?;。;!?",
                    "auto_keywords": 0,
                    "auto_questions": 0,
                    "html4excel": False,
                    "raptor": {"use_raptor": False},
                    "graphrag": {"use_graphrag": False},
                }
            )
            default_source_type = "local"
            doc_query = """
                INSERT INTO document (
                    id, create_time, create_date, update_time, update_date,
                    thumbnail, kb_id, parser_id, parser_config, source_type,
                    type, created_by, name, location, size,
                    token_num, chunk_num, progress, progress_msg, process_begin_at,
                    process_duation, meta_fields, run, status
                ) VALUES (
                    %s, %s, %s, %s, %s,
                    %s, %s, %s, %s, %s,
                    %s, %s, %s, %s, %s,
                    %s, %s, %s, %s, %s,
                    %s, %s, %s, %s
                )
            """
            doc_params = [
                doc_id,
                create_time,
                current_date,
                create_time,
                current_date,  # id + timestamps
                None,
                kb_id,
                default_parser_id,
                default_parser_config,
                default_source_type,  # thumbnail .. source_type
                file_type,
                created_by,
                file_name,
                file_location,
                file_size,  # type .. size
                0,
                0,
                0.0,
                None,
                None,  # token_num .. process_begin_at
                0.0,
                None,
                "0",
                "1",  # process_duation .. status
            ]
            cursor.execute(doc_query, doc_params)
            # Map the source file to the new document.
            f2d_id = generate_uuid()
            f2d_query = """
                INSERT INTO file2document (
                    id, create_time, create_date, update_time, update_date,
                    file_id, document_id
                ) VALUES (
                    %s, %s, %s, %s, %s,
                    %s, %s
                )
            """
            f2d_params = [f2d_id, create_time, current_date, create_time, current_date, file_id, doc_id]
            cursor.execute(f2d_query, f2d_params)
            added_count += 1
        if added_count > 0:
            # Best-effort doc_num bump; a failure here must not lose the rows.
            try:
                update_query = """
                    UPDATE knowledgebase
                    SET doc_num = doc_num + %s,
                        update_date = %s
                    WHERE id = %s
                """
                cursor.execute(update_query, (added_count, current_date, kb_id))
            except Exception as e:
                print(f"[WARNING] 更新知识库文档数量失败,但文档已添加: {str(e)}")
            # Commit outside the guard above: the original committed only
            # when the doc_num update succeeded, so a failure there silently
            # discarded the freshly inserted document/mapping rows.
            conn.commit()
        return {"added_count": added_count}
    except Exception as e:
        print(f"[ERROR] 添加文档失败: {str(e)}")
        print(f"[ERROR] 错误类型: {type(e)}")
        # `traceback` is imported at module level; the original shadowed it
        # with a redundant local import here.
        print(f"[ERROR] 堆栈信息: {traceback.format_exc()}")
        raise Exception(f"添加文档到知识库失败: {str(e)}")
    finally:
        # Always release DB resources; an unexpected failure mid-loop rolls
        # back the uncommitted inserts when the connection closes.
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@classmethod
def delete_document(cls, doc_id):
    """Delete a document row, its file mapping, and its Elasticsearch chunks.

    Returns:
        False when the document does not exist; True on success.
    Raises:
        Exception: on DB failure (ES cleanup failures are logged only).
    """
    conn = None
    cursor = None
    try:
        conn = cls._get_db_connection()
        cursor = conn.cursor(dictionary=True)
        # Resolve the owning kb and its creator (the creator id names the
        # ES index holding the document's chunks).
        check_query = """
            SELECT
                d.kb_id,
                kb.created_by AS tenant_id
            FROM document d
            JOIN knowledgebase kb ON d.kb_id = kb.id
            WHERE d.id = %s
        """
        cursor.execute(check_query, (doc_id,))
        doc_data = cursor.fetchone()
        if not doc_data:
            print(f"[INFO] 文档 {doc_id} 在数据库中未找到。")
            return False
        kb_id = doc_data["kb_id"]
        cursor.execute("DELETE FROM file2document WHERE document_id = %s", (doc_id,))
        cursor.execute("DELETE FROM document WHERE id = %s", (doc_id,))
        # Decrement the kb's document counter, clamped at zero so a stale
        # count can never go negative.
        update_query = """
            UPDATE knowledgebase
            SET doc_num = GREATEST(doc_num - 1, 0),
                update_date = %s
            WHERE id = %s
        """
        current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        cursor.execute(update_query, (current_date, kb_id))
        conn.commit()
        # DB work done; release the connection before the (possibly slow)
        # ES round-trip and clear the handles so finally does not re-close.
        cursor.close()
        conn.close()
        cursor = None
        conn = None
        # Best-effort removal of the document's chunks from Elasticsearch.
        es_client = get_es_client()
        tenant_id_for_cleanup = doc_data["tenant_id"]
        if es_client and tenant_id_for_cleanup:
            es_index_name = f"ragflow_{tenant_id_for_cleanup}"
            try:
                if es_client.indices.exists(index=es_index_name):
                    query_body = {"query": {"term": {"doc_id": doc_id}}}
                    resp = es_client.delete_by_query(
                        index=es_index_name,
                        body=query_body,
                        refresh=True,  # make the deletion visible immediately
                        ignore_unavailable=True,  # index may vanish meanwhile
                    )
                    deleted_count = resp.get("deleted", 0)
                    print(f"[ES-SUCCESS] 从索引 {es_index_name} 中删除 {deleted_count} 个与 doc_id {doc_id} 相关的块。")
                else:
                    print(f"[ES-INFO] 索引 {es_index_name} 不存在,跳过 ES 清理 for doc_id {doc_id}")
            except Exception as es_err:
                print(f"[ES-ERROR] 清理 ES 块 for doc_id {doc_id} (index {es_index_name}) 失败: {str(es_err)}")
        return True
    except Exception as e:
        print(f"[ERROR] 删除文档失败: {str(e)}")
        raise Exception(f"删除文档失败: {str(e)}")
    finally:
        # Safety net: the original leaked the connection on any exception.
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@classmethod
def parse_document(cls, doc_id):
    """Synchronously parse one document.

    Gathers the document, file and kb-owner records, marks the document as
    parsing, then delegates to perform_parse. On any failure the document
    is flagged failed and an error dict is returned instead of raising.

    Returns:
        perform_parse's result dict, or {"success": False, "error": ...}.
    """
    conn = None
    cursor = None
    try:
        conn = cls._get_db_connection()
        cursor = conn.cursor(dictionary=True)
        doc_query = """
            SELECT d.id, d.name, d.location, d.type, d.kb_id, d.parser_id, d.parser_config, d.created_by
            FROM document d
            WHERE d.id = %s
        """
        cursor.execute(doc_query, (doc_id,))
        doc_info = cursor.fetchone()
        if not doc_info:
            raise Exception("文档不存在")
        # The linked file row supplies parent_id, used as the bucket name.
        f2d_query = "SELECT file_id FROM file2document WHERE document_id = %s"
        cursor.execute(f2d_query, (doc_id,))
        f2d_result = cursor.fetchone()
        if not f2d_result:
            raise Exception("无法找到文件到文档的映射关系")
        file_id = f2d_result["file_id"]
        file_query = "SELECT parent_id FROM file WHERE id = %s"
        cursor.execute(file_query, (file_id,))
        file_info = cursor.fetchone()
        if not file_info:
            raise Exception("无法找到文件记录")
        # The kb creator identifies the tenant. doc_info already carries
        # kb_id, so the original's extra lookup of it was removed.
        kb_query = "SELECT created_by FROM knowledgebase WHERE id = %s"
        cursor.execute(kb_query, (doc_info["kb_id"],))
        kb_info = cursor.fetchone()
        # Release the connection before the long-running parse; clear both
        # handles so the finally block does not double-close them (the
        # original nulled conn but left cursor set).
        cursor.close()
        conn.close()
        cursor = None
        conn = None
        # Flag the document as "parsing started".
        _update_document_progress(doc_id, status="2", run="1", progress=0.0, message="开始解析")
        embedding_config = cls.get_system_embedding_config()
        parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info)
        return parse_result
    except Exception as e:
        print(f"文档解析启动或执行过程中出错 (Doc ID: {doc_id}): {str(e)}")
        # Best-effort: mark the document as failed.
        try:
            _update_document_progress(doc_id, status="1", run="0", message=f"解析失败: {str(e)}")
        except Exception as update_err:
            print(f"更新文档失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}")
        # Deliberately swallow the exception; callers inspect the dict.
        return {"success": False, "error": f"文档解析失败: {str(e)}"}
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@classmethod
def async_parse_document(cls, doc_id):
    """Submit parse_document to a background daemon thread and return at once."""
    try:
        worker = threading.Thread(target=cls.parse_document, args=(doc_id,))
        worker.daemon = True  # do not block interpreter shutdown
        worker.start()
        # Report immediate acceptance; progress is polled separately.
        return {
            "task_id": doc_id,  # the doc id doubles as the task identifier
            "status": "processing",
            "message": "文档解析任务已提交到后台处理",
        }
    except Exception as e:
        print(f"启动异步解析任务失败 (Doc ID: {doc_id}): {str(e)}")
        # Best-effort: record the startup failure on the document row.
        try:
            _update_document_progress(doc_id, status="1", run="0", message=f"启动解析失败: {str(e)}")
        except Exception as update_err:
            print(f"更新文档启动失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}")
        raise Exception(f"启动异步解析任务失败: {str(e)}")
@classmethod
def get_document_parse_progress(cls, doc_id):
    """Report parse progress for one document; returns an error dict on failure."""
    conn = None
    cursor = None
    try:
        conn = cls._get_db_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute(
            """
            SELECT progress, progress_msg, status, run
            FROM document
            WHERE id = %s
            """,
            (doc_id,),
        )
        row = cursor.fetchone()
        if not row:
            return {"error": "文档不存在"}
        # progress may arrive as Decimal/str/None; coerce it to float.
        progress_value = 0.0
        if row.get("progress") is not None:
            try:
                progress_value = float(row["progress"])
            except (ValueError, TypeError):
                progress_value = 0.0
        return {
            "progress": progress_value,
            "message": row.get("progress_msg", ""),
            "status": row.get("status", "0"),
            "running": row.get("run", "0"),
        }
    except Exception as e:
        print(f"获取文档进度失败 (Doc ID: {doc_id}): {str(e)}")
        return {"error": f"获取进度失败: {str(e)}"}
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
# --- Earliest-user lookup ---
@classmethod
def _get_earliest_user_tenant_id(cls):
    """Return the id of the earliest-created user (used as tenant_id), or None."""
    conn = None
    cursor = None
    try:
        conn = cls._get_db_connection()
        cursor = conn.cursor()
        cursor.execute("SELECT id FROM user ORDER BY create_time ASC LIMIT 1")
        row = cursor.fetchone()
        if row is None:
            print("警告: 数据库中没有用户!")
            return None
        return row[0]  # the user id
    except Exception as e:
        print(f"查询最早用户时出错: {e}")
        traceback.print_exc()
        return None
    finally:
        if cursor:
            cursor.close()
        if conn and conn.is_connected():
            conn.close()
# --- Embedding connectivity test ---
@classmethod
def _test_embedding_connection(cls, base_url, model_name, api_key):
    """Probe an OpenAI-compatible embedding endpoint with a tiny request.

    Always returns an (ok: bool, message: str) tuple. The original fell
    off the end (implicitly returning None) on non-200 responses and on
    200 responses with an unexpected payload, which crashed callers that
    unpack the pair with a TypeError.
    """
    print(f"开始测试连接: base_url={base_url}, model_name={model_name}")
    try:
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        payload = {"input": ["Test connection"], "model": model_name}
        if not base_url.startswith(("http://", "https://")):
            base_url = "http://" + base_url
        if not base_url.endswith("/"):
            base_url += "/"
        # --- Endpoint URL normalization ---
        endpoint_segment = "embeddings"
        full_endpoint_path = "v1/embeddings"
        # Drop the trailing slash to simplify the suffix checks below.
        normalized_base_url = base_url.rstrip("/")
        if normalized_base_url.endswith("/v1"):
            # base_url is already of the form http://host/v1
            current_test_url = normalized_base_url + "/" + endpoint_segment
        elif normalized_base_url.endswith("/embeddings"):
            # base_url already points at the embeddings endpoint
            # (e.g. the SiliconFlow API) — use it as-is.
            current_test_url = normalized_base_url
        else:
            # base_url is http://host or http://host/api
            current_test_url = normalized_base_url + "/" + full_endpoint_path
        # --- End URL normalization ---
        print(f"尝试请求 URL: {current_test_url}")
        response = requests.post(current_test_url, headers=headers, json=payload, timeout=15)
        print(f"请求 {current_test_url} 返回状态码: {response.status_code}")
        if response.status_code != 200:
            # Previously this path returned None.
            return False, f"连接失败: HTTP {response.status_code}"
        try:
            res_json = response.json()
        except Exception as json_e:
            print(f"解析 JSON 响应失败于 {current_test_url}: {json_e}")
            return False, "连接失败: 响应错误"
        # Accept either the OpenAI shape {"data":[{"embedding":[...]}]} or
        # a bare list-of-vectors response.
        if (
            "data" in res_json and isinstance(res_json["data"], list) and len(res_json["data"]) > 0 and "embedding" in res_json["data"][0] and len(res_json["data"][0]["embedding"]) > 0
        ) or (isinstance(res_json, list) and len(res_json) > 0 and isinstance(res_json[0], list) and len(res_json[0]) > 0):
            print(f"连接测试成功: {current_test_url}")
            return True, "连接成功"
        print(f"连接成功但响应格式不正确于 {current_test_url}")
        # Previously this path returned None.
        return False, "连接失败: 响应格式不正确"
    except Exception as e:
        print(f"连接测试发生未知错误: {str(e)}")
        traceback.print_exc()
        return False, f"测试时发生未知错误: {str(e)}"
# --- System embedding configuration (read) ---
@classmethod
def get_system_embedding_config(cls):
    """Return the system-level (earliest user's) embedding configuration.

    Returns a dict with "llm_name", "api_key" and "api_base"; all three
    are empty strings when no user or no embedding entry exists. Raises
    a wrapped Exception on DB errors.
    """
    conn = None
    cursor = None
    try:
        conn = cls._get_db_connection()
        cursor = conn.cursor(dictionary=True)  # dict rows for named access
        # Step 1: locate the earliest-created user.
        cursor.execute(
            """
            SELECT id FROM user
            ORDER BY create_time ASC
            LIMIT 1
            """
        )
        earliest_user = cursor.fetchone()
        if not earliest_user:
            return {"llm_name": "", "api_key": "", "api_base": ""}
        earliest_user_id = earliest_user["id"]
        # Step 2: that user's most recent embedding entry in tenant_llm.
        cursor.execute(
            """
            SELECT llm_name, api_key, api_base
            FROM tenant_llm
            WHERE tenant_id = %s AND model_type = 'embedding'
            ORDER BY create_time DESC # 如果一个用户可能有多个embedding配置取最早的
            LIMIT 1
            """,
            (earliest_user_id,),
        )
        config = cursor.fetchone()
        if not config:
            # The earliest user has no embedding configuration.
            return {"llm_name": "", "api_key": "", "api_base": ""}
        llm_name = config.get("llm_name", "")
        api_key = config.get("api_key", "")
        api_base = config.get("api_base", "")
        # Strip the "___suffix" some entries carry on the model name.
        if llm_name and "___" in llm_name:
            llm_name = llm_name.split("___")[0]
        # Special-case the SiliconFlow platform's model name.
        if llm_name == "netease-youdao/bce-embedding-base_v1":
            llm_name = "BAAI/bge-m3"
        # An empty api_base defaults to SiliconFlow's embeddings endpoint.
        if api_base == "":
            api_base = "https://api.siliconflow.cn/v1/embeddings"
        return {"llm_name": llm_name, "api_key": api_key, "api_base": api_base}
    except Exception as e:
        print(f"获取系统 Embedding 配置时出错: {e}")
        traceback.print_exc()
        # Propagate so the route layer decides how to respond.
        raise Exception(f"获取配置时数据库出错: {e}")
    finally:
        if cursor:
            cursor.close()
        if conn and conn.is_connected():
            conn.close()
# --- System embedding configuration (write) ---
@classmethod
def set_system_embedding_config(cls, llm_name, api_base, api_key):
    """Validate system-level embedding settings by probing the endpoint.

    Returns (ok, message) for the route layer; raises when no base user
    exists in the database.
    """
    tenant_id = cls._get_earliest_user_tenant_id()
    if not tenant_id:
        raise Exception("无法找到系统基础用户")
    print(f"开始设置系统 Embedding 配置: {llm_name}, {api_base}, {api_key}")
    # Probe the endpoint before accepting the configuration.
    is_connected, message = cls._test_embedding_connection(base_url=api_base, model_name=llm_name, api_key=api_key)
    if is_connected:
        return True, f"连接成功: {message}"
    # Surface the concrete failure reason to the caller.
    return False, f"连接测试失败: {message}"
# Sequential batch parse (core logic, runs on a background thread)
@classmethod
def _run_sequential_batch_parse(cls, kb_id):
    """Parse every pending document of a kb one-by-one.

    Progress is published into SEQUENTIAL_BATCH_TASKS[kb_id]; the entry
    must already have been seeded by start_sequential_batch_parse_async.
    """
    global SEQUENTIAL_BATCH_TASKS
    task_info = SEQUENTIAL_BATCH_TASKS.get(kb_id)
    if not task_info:
        print(f"[Seq Batch ERROR] Task info for KB {kb_id} not found at start.")
        return  # should not happen: the starter seeds the entry first
    conn = None
    cursor = None
    parsed_count = 0
    failed_count = 0
    total_count = 0
    try:
        conn = cls._get_db_connection()
        cursor = conn.cursor(dictionary=True)
        # Documents whose run state is not '3' are still pending.
        query = """
            SELECT id, name FROM document
            WHERE kb_id = %s AND run != '3'
        """
        cursor.execute(query, (kb_id,))
        documents_to_parse = cursor.fetchall()
        total_count = len(documents_to_parse)
        # Publish the initial task snapshot.
        task_info["total"] = total_count
        task_info["status"] = "running"
        task_info["message"] = f"共找到 {total_count} 个文档待解析。"
        task_info["start_time"] = time.time()
        start_time = time.time()
        SEQUENTIAL_BATCH_TASKS[kb_id] = task_info
        if not documents_to_parse:
            task_info["status"] = "completed"
            task_info["message"] = "没有需要解析的文档。"
            SEQUENTIAL_BATCH_TASKS[kb_id] = task_info
            print(f"[Seq Batch] KB {kb_id}: 没有需要解析的文档。")
            return
        print(f"[Seq Batch] KB {kb_id}: 开始顺序解析 {total_count} 个文档...")
        for i, doc in enumerate(documents_to_parse):
            doc_id = doc["id"]
            doc_name = doc["name"]
            task_info["current"] = i + 1
            task_info["message"] = f"正在解析: {doc_name} ({i + 1}/{total_count})"
            SEQUENTIAL_BATCH_TASKS[kb_id] = task_info
            print(f"[Seq Batch] KB {kb_id}: ({i + 1}/{total_count}) Parsing {doc_name} (ID: {doc_id})...")
            try:
                # parse_document manages the per-document run/status itself.
                result = cls.parse_document(doc_id)
                if result and result.get("success"):
                    parsed_count += 1
                    print(f"[Seq Batch] KB {kb_id}: Document {doc_id} parsed successfully.")
                else:
                    failed_count += 1
                    # parse_document reports failures under the "error" key,
                    # so check it first (the original read only "message"
                    # and always logged the fallback text).
                    error_msg = (result.get("error") or result.get("message", "未知错误")) if result else "未知错误"
                    print(f"[Seq Batch] KB {kb_id}: Document {doc_id} parsing failed: {error_msg}")
            except Exception as e:
                failed_count += 1
                print(f"[Seq Batch ERROR] KB {kb_id}: Error calling parse_document for {doc_id}: {str(e)}")
                traceback.print_exc()
                # Best-effort: mark the document itself as failed.
                try:
                    _update_document_progress(doc_id, status="1", run="0", progress=0.0, message=f"批量任务中解析失败: {str(e)[:255]}")
                except Exception as update_err:
                    print(f"[Service-ERROR] 更新文档 {doc_id} 失败状态时出错: {str(update_err)}")
        # Publish the final summary.
        end_time = time.time()
        duration = round(end_time - task_info.get("start_time", start_time), 2)
        final_message = f"批量顺序解析完成。总计 {total_count} 个,成功 {parsed_count} 个,失败 {failed_count} 个。耗时 {duration} 秒。"
        task_info["status"] = "completed"
        task_info["message"] = final_message
        task_info["current"] = total_count
        SEQUENTIAL_BATCH_TASKS[kb_id] = task_info
        print(f"[Seq Batch] KB {kb_id}: {final_message}")
    except Exception as e:
        # A failure outside the per-document loop aborts the whole task.
        error_message = f"批量顺序解析过程中发生严重错误: {str(e)}"
        print(f"[Seq Batch ERROR] KB {kb_id}: {error_message}")
        traceback.print_exc()
        task_info["status"] = "failed"
        task_info["message"] = error_message
        SEQUENTIAL_BATCH_TASKS[kb_id] = task_info
    finally:
        if cursor:
            cursor.close()
        if conn and conn.is_connected():
            conn.close()
# Launch sequential batch parse (async entry point)
@classmethod
def start_sequential_batch_parse_async(cls, kb_id):
    """Spawn a daemon thread that runs the sequential batch parse for a kb."""
    global SEQUENTIAL_BATCH_TASKS
    # Refuse to start a second run for the same knowledgebase.
    existing = SEQUENTIAL_BATCH_TASKS.get(kb_id)
    if existing is not None and existing.get("status") == "running":
        return {"success": False, "message": "该知识库的批量解析任务已在运行中。"}
    # Seed the task record before the worker thread starts.
    start_time = time.time()
    SEQUENTIAL_BATCH_TASKS[kb_id] = {"status": "starting", "total": 0, "current": 0, "message": "任务准备启动...", "start_time": start_time}
    try:
        worker = threading.Thread(target=cls._run_sequential_batch_parse, args=(kb_id,))
        worker.daemon = True
        worker.start()
        print(f"[Seq Batch] KB {kb_id}: 已启动后台顺序解析线程。")
        return {"success": True, "message": "顺序批量解析任务已启动。"}
    except Exception as e:
        error_message = f"启动顺序批量解析任务失败: {str(e)}"
        print(f"[Seq Batch ERROR] KB {kb_id}: {error_message}")
        traceback.print_exc()
        # Record the startup failure so pollers see a terminal state.
        SEQUENTIAL_BATCH_TASKS[kb_id] = {"status": "failed", "total": 0, "current": 0, "message": error_message, "start_time": start_time}
        return {"success": False, "message": error_message}
# Query sequential batch parse progress
@classmethod
def get_sequential_batch_parse_progress(cls, kb_id):
    """Return the recorded batch-parse task state for a kb, or a not_found marker."""
    global SEQUENTIAL_BATCH_TASKS
    info = SEQUENTIAL_BATCH_TASKS.get(kb_id)
    if info is None:
        return {"status": "not_found", "message": "未找到该知识库的批量解析任务记录。"}
    return info
# All-documents status for a kb (used to refresh the list view)
@classmethod
def get_knowledgebase_parse_progress(cls, kb_id):
    """Return parse progress and status for every document of a knowledgebase."""
    conn = None
    cursor = None
    try:
        conn = cls._get_db_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute(
            """
            SELECT id, name, progress, progress_msg, status, run
            FROM document
            WHERE kb_id = %s
            ORDER BY create_date DESC -- 或者其他排序方式
            """,
            (kb_id,),
        )
        documents_status = cursor.fetchall()
        for doc in documents_status:
            # Coerce progress (possibly Decimal/str/None) into a float.
            value = 0.0
            if doc.get("progress") is not None:
                try:
                    value = float(doc["progress"])
                except (ValueError, TypeError):
                    value = 0.0
            doc["progress"] = value
            # Backfill defaults for the remaining fields.
            doc["progress_msg"] = doc.get("progress_msg", "")
            doc["status"] = doc.get("status", "0")
            doc["run"] = doc.get("run", "0")
        return {"documents": documents_status}
    except Exception as e:
        print(f"获取知识库 {kb_id} 文档进度失败: {str(e)}")
        traceback.print_exc()
        return {"error": f"获取知识库文档进度失败: {str(e)}"}
    finally:
        if cursor:
            cursor.close()
        if conn and conn.is_connected():
            conn.close()