refactor: 优化 Embedding URL 拼接逻辑,以兼容vllm和ollama等不同框架 (#50)

- 在 document_parser.py 和 service.py 中优化 Embedding URL 拼接逻辑,支持不同形式的 base_url
- 在 axios.ts 中将 400 错误消息从 "账号密码不正确" 更新为 "请求错误"
This commit is contained in:
zstar 2025-04-24 23:29:47 +08:00 committed by GitHub
parent 45bd222176
commit 51f4381a65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 65 additions and 23 deletions

View File

@ -203,13 +203,24 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
embedding_api_key = embedding_config.get("api_key") if embedding_config else None # 可能为 None 或空字符串
# 构建完整的 Embedding API URL
embedding_url = None # 默认为 None
if embedding_api_base:
# 确保 embedding_api_base 包含协议头 (http:// 或 https://)
if not embedding_api_base.startswith(('http://', 'https://')):
embedding_api_base = 'http://' + embedding_api_base
# 标准端点是 /embeddings
embedding_url = embedding_api_base.rstrip('/') + "/embeddings"
else:
embedding_url = None # 如果没有配置 Base URL,则无法请求
# --- URL 拼接优化 (处理 /v1) ---
endpoint_segment = "embeddings"
full_endpoint_path = "v1/embeddings"
# 移除末尾斜杠以方便判断
normalized_base_url = embedding_api_base.rstrip('/')
if normalized_base_url.endswith('/v1'):
# 如果 base_url 已经是 http://host/v1 形式
embedding_url = normalized_base_url + '/' + endpoint_segment
else:
# 如果 base_url 是 http://host 或 http://host/api 等其他形式
embedding_url = normalized_base_url + '/' + full_endpoint_path
print(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")

View File

@ -873,8 +873,20 @@ class KnowledgebaseService:
if not base_url.endswith('/'):
base_url += '/'
endpoint = "embeddings"
current_test_url = base_url + endpoint
# --- URL 拼接优化 ---
endpoint_segment = "embeddings"
full_endpoint_path = "v1/embeddings"
# 移除末尾斜杠以方便判断
normalized_base_url = base_url.rstrip('/')
if normalized_base_url.endswith('/v1'):
# 如果 base_url 已经是 http://host/v1 形式
current_test_url = normalized_base_url + '/' + endpoint_segment
else:
# 如果 base_url 是 http://host 或 http://host/api 形式
current_test_url = normalized_base_url + '/' + full_endpoint_path
# --- 结束 URL 拼接优化 ---
print(f"尝试请求 URL: {current_test_url}")
try:
response = requests.post(current_test_url, headers=headers, json=payload, timeout=15)
@ -903,31 +915,47 @@ class KnowledgebaseService:
@classmethod
def get_system_embedding_config(cls):
"""获取系统级(最早用户)的 Embedding 配置"""
tenant_id = cls._get_earliest_user_tenant_id()
if not tenant_id:
raise Exception("无法找到系统基础用户") # 在服务层抛出异常
conn = None
cursor = None
# TODO: 修改查询逻辑
try:
conn = cls._get_db_connection()
cursor = conn.cursor(dictionary=True) # 使用字典游标方便访问列名
query = """
SELECT llm_name, api_key, api_base
FROM tenant_llm
WHERE tenant_id = %s
AND model type = 'embedding'
# 1. 找到最早创建的用户ID
query_earliest_user = """
SELECT id FROM user
ORDER BY create_time ASC
LIMIT 1
"""
cursor.execute(query, (tenant_id,))
cursor.execute(query_earliest_user)
earliest_user = cursor.fetchone()
if not earliest_user:
# 如果没有用户,返回空配置
return {
"llm_name": "",
"api_key": "",
"api_base": ""
}
earliest_user_id = earliest_user['id']
# 2. 根据最早用户ID查询 tenant_llm 表中 model_type 为 embedding 的配置
query_embedding_config = """
SELECT llm_name, api_key, api_base
FROM tenant_llm
WHERE tenant_id = %s AND model_type = 'embedding'
ORDER BY create_time DESC # 如果一个用户有多个 embedding 配置,取最新创建的一条
LIMIT 1
"""
cursor.execute(query_embedding_config, (earliest_user_id,))
config = cursor.fetchone()
if config:
llm_name = config.get("llm_name", "")
api_key = config.get("api_key", "")
api_base = config.get("api_base", "")
# 对模型名称进行处理
# 对模型名称进行处理 (可选,根据需要保留或移除)
if llm_name and '___' in llm_name:
llm_name = llm_name.split('___')[0]
# 如果有配置,返回
@ -937,7 +965,7 @@ class KnowledgebaseService:
"api_base": api_base
}
else:
# 如果没有配置,返回空
# 如果最早的用户没有 embedding 配置,返回空
return {
"llm_name": "",
"api_key": "",
@ -946,13 +974,14 @@ class KnowledgebaseService:
except Exception as e:
print(f"获取系统 Embedding 配置时出错: {e}")
traceback.print_exc()
raise Exception(f"获取配置时数据库出错: {e}") # 重新抛出异常
# 保持原有的异常处理逻辑,向上抛出,让调用者处理
raise Exception(f"获取配置时数据库出错: {e}")
finally:
if cursor:
cursor.close()
if conn and conn.is_connected():
conn.close()
# --- 设置系统 Embedding 配置 ---
@classmethod
def set_system_embedding_config(cls, llm_name, api_base, api_key):

View File

@ -56,7 +56,7 @@ function createInstance() {
const message = get(error, "response.data.message")
switch (status) {
case 400:
error.message = "账号密码不正确"
error.message = "请求错误"
break
case 401:
// Token 过期时

View File

@ -8,6 +8,7 @@ export {}
/* prettier-ignore */
declare module 'vue' {
export interface GlobalComponents {
ElAlert: typeof import('element-plus/es')['ElAlert']
ElAside: typeof import('element-plus/es')['ElAside']
ElAvatar: typeof import('element-plus/es')['ElAvatar']
ElBacktop: typeof import('element-plus/es')['ElBacktop']
@ -49,6 +50,7 @@ declare module 'vue' {
ElTabs: typeof import('element-plus/es')['ElTabs']
ElTag: typeof import('element-plus/es')['ElTag']
ElTooltip: typeof import('element-plus/es')['ElTooltip']
ElUpload: typeof import('element-plus/es')['ElUpload']
RouterLink: typeof import('vue-router')['RouterLink']
RouterView: typeof import('vue-router')['RouterView']
}