refactor: 优化 Embedding URL 拼接逻辑,以兼容vllm和ollama等不同框架 (#50)

- 在 document_parser.py 和 service.py 中优化 Embedding URL 拼接逻辑,支持不同形式的 base_url
- 在 axios.ts 中将 400 错误消息从 "账号密码不正确" 更新为 "请求错误"
This commit is contained in:
zstar 2025-04-24 23:29:47 +08:00 committed by GitHub
parent 45bd222176
commit 51f4381a65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 65 additions and 23 deletions

View File

@ -203,13 +203,24 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
embedding_api_key = embedding_config.get("api_key") if embedding_config else None # 可能为 None 或空字符串 embedding_api_key = embedding_config.get("api_key") if embedding_config else None # 可能为 None 或空字符串
# 构建完整的 Embedding API URL # 构建完整的 Embedding API URL
embedding_url = None # 默认为 None
if embedding_api_base: if embedding_api_base:
# 确保 embedding_api_base 包含协议头 (http:// 或 https://)
if not embedding_api_base.startswith(('http://', 'https://')): if not embedding_api_base.startswith(('http://', 'https://')):
embedding_api_base = 'http://' + embedding_api_base embedding_api_base = 'http://' + embedding_api_base
# 标准端点是 /embeddings
embedding_url = embedding_api_base.rstrip('/') + "/embeddings" # --- URL 拼接优化 (处理 /v1) ---
else: endpoint_segment = "embeddings"
embedding_url = None # 如果没有配置 Base URL则无法请求 full_endpoint_path = "v1/embeddings"
# 移除末尾斜杠以方便判断
normalized_base_url = embedding_api_base.rstrip('/')
if normalized_base_url.endswith('/v1'):
# 如果 base_url 已经是 http://host/v1 形式
embedding_url = normalized_base_url + '/' + endpoint_segment
else:
# 如果 base_url 是 http://host 或 http://host/api 等其他形式
embedding_url = normalized_base_url + '/' + full_endpoint_path
print(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}") print(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")

View File

@ -873,8 +873,20 @@ class KnowledgebaseService:
if not base_url.endswith('/'): if not base_url.endswith('/'):
base_url += '/' base_url += '/'
endpoint = "embeddings" # --- URL 拼接优化 ---
current_test_url = base_url + endpoint endpoint_segment = "embeddings"
full_endpoint_path = "v1/embeddings"
# 移除末尾斜杠以方便判断
normalized_base_url = base_url.rstrip('/')
if normalized_base_url.endswith('/v1'):
# 如果 base_url 已经是 http://host/v1 形式
current_test_url = normalized_base_url + '/' + endpoint_segment
else:
# 如果 base_url 是 http://host 或 http://host/api 形式
current_test_url = normalized_base_url + '/' + full_endpoint_path
# --- 结束 URL 拼接优化 ---
print(f"尝试请求 URL: {current_test_url}") print(f"尝试请求 URL: {current_test_url}")
try: try:
response = requests.post(current_test_url, headers=headers, json=payload, timeout=15) response = requests.post(current_test_url, headers=headers, json=payload, timeout=15)
@ -903,31 +915,47 @@ class KnowledgebaseService:
@classmethod @classmethod
def get_system_embedding_config(cls): def get_system_embedding_config(cls):
"""获取系统级(最早用户)的 Embedding 配置""" """获取系统级(最早用户)的 Embedding 配置"""
tenant_id = cls._get_earliest_user_tenant_id()
if not tenant_id:
raise Exception("无法找到系统基础用户") # 在服务层抛出异常
conn = None conn = None
cursor = None cursor = None
# TODO: 修改查询逻辑
try: try:
conn = cls._get_db_connection() conn = cls._get_db_connection()
cursor = conn.cursor(dictionary=True) # 使用字典游标方便访问列名 cursor = conn.cursor(dictionary=True) # 使用字典游标方便访问列名
query = """
SELECT llm_name, api_key, api_base # 1. 找到最早创建的用户ID
FROM tenant_llm query_earliest_user = """
WHERE tenant_id = %s SELECT id FROM user
AND model type = 'embedding' ORDER BY create_time ASC
LIMIT 1 LIMIT 1
""" """
cursor.execute(query, (tenant_id,)) cursor.execute(query_earliest_user)
earliest_user = cursor.fetchone()
if not earliest_user:
# 如果没有用户,返回空配置
return {
"llm_name": "",
"api_key": "",
"api_base": ""
}
earliest_user_id = earliest_user['id']
# 2. 根据最早用户ID查询 tenant_llm 表中 model_type 为 embedding 的配置
query_embedding_config = """
SELECT llm_name, api_key, api_base
FROM tenant_llm
WHERE tenant_id = %s AND model_type = 'embedding'
ORDER BY create_time DESC # 如果一个用户可能有多个embedding配置,取最新的
LIMIT 1
"""
cursor.execute(query_embedding_config, (earliest_user_id,))
config = cursor.fetchone() config = cursor.fetchone()
if config: if config:
llm_name = config.get("llm_name", "") llm_name = config.get("llm_name", "")
api_key = config.get("api_key", "") api_key = config.get("api_key", "")
api_base = config.get("api_base", "") api_base = config.get("api_base", "")
# 对模型名称进行处理 # 对模型名称进行处理 (可选,根据需要保留或移除)
if llm_name and '___' in llm_name: if llm_name and '___' in llm_name:
llm_name = llm_name.split('___')[0] llm_name = llm_name.split('___')[0]
# 如果有配置,返回 # 如果有配置,返回
@ -937,7 +965,7 @@ class KnowledgebaseService:
"api_base": api_base "api_base": api_base
} }
else: else:
# 如果没有配置,返回空 # 如果最早的用户没有 embedding 配置,返回空
return { return {
"llm_name": "", "llm_name": "",
"api_key": "", "api_key": "",
@ -946,13 +974,14 @@ class KnowledgebaseService:
except Exception as e: except Exception as e:
print(f"获取系统 Embedding 配置时出错: {e}") print(f"获取系统 Embedding 配置时出错: {e}")
traceback.print_exc() traceback.print_exc()
raise Exception(f"获取配置时数据库出错: {e}") # 重新抛出异常 # 保持原有的异常处理逻辑,向上抛出,让调用者处理
raise Exception(f"获取配置时数据库出错: {e}")
finally: finally:
if cursor: if cursor:
cursor.close() cursor.close()
if conn and conn.is_connected(): if conn and conn.is_connected():
conn.close() conn.close()
# --- 设置系统 Embedding 配置 --- # --- 设置系统 Embedding 配置 ---
@classmethod @classmethod
def set_system_embedding_config(cls, llm_name, api_base, api_key): def set_system_embedding_config(cls, llm_name, api_base, api_key):

View File

@ -56,7 +56,7 @@ function createInstance() {
const message = get(error, "response.data.message") const message = get(error, "response.data.message")
switch (status) { switch (status) {
case 400: case 400:
error.message = "账号密码不正确" error.message = "请求错误"
break break
case 401: case 401:
// Token 过期时 // Token 过期时

View File

@ -8,6 +8,7 @@ export {}
/* prettier-ignore */ /* prettier-ignore */
declare module 'vue' { declare module 'vue' {
export interface GlobalComponents { export interface GlobalComponents {
ElAlert: typeof import('element-plus/es')['ElAlert']
ElAside: typeof import('element-plus/es')['ElAside'] ElAside: typeof import('element-plus/es')['ElAside']
ElAvatar: typeof import('element-plus/es')['ElAvatar'] ElAvatar: typeof import('element-plus/es')['ElAvatar']
ElBacktop: typeof import('element-plus/es')['ElBacktop'] ElBacktop: typeof import('element-plus/es')['ElBacktop']
@ -49,6 +50,7 @@ declare module 'vue' {
ElTabs: typeof import('element-plus/es')['ElTabs'] ElTabs: typeof import('element-plus/es')['ElTabs']
ElTag: typeof import('element-plus/es')['ElTag'] ElTag: typeof import('element-plus/es')['ElTag']
ElTooltip: typeof import('element-plus/es')['ElTooltip'] ElTooltip: typeof import('element-plus/es')['ElTooltip']
ElUpload: typeof import('element-plus/es')['ElUpload']
RouterLink: typeof import('vue-router')['RouterLink'] RouterLink: typeof import('vue-router')['RouterLink']
RouterView: typeof import('vue-router')['RouterView'] RouterView: typeof import('vue-router')['RouterView']
} }