feat: 添加系统Embedding配置功能并优化文档解析 (#35)

在知识库模块中新增了获取和设置系统Embedding配置的API接口,支持动态配置Embedding模型的基础URL、模型名称和API Key。同时,优化了文档解析逻辑,使用系统配置的Embedding模型生成文本块的向量,并将图片与文本块关联存储。
This commit is contained in:
zstar 2025-04-18 22:34:25 +08:00 committed by GitHub
parent 61d924a4fa
commit 803cc7e656
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 686 additions and 73 deletions

View File

@ -162,10 +162,12 @@ def chat(dialog, messages, stream=True, **kwargs):
prompt_config["system"] = prompt_config["system"].replace(
"{%s}" % p["key"], " ")
if len(questions) > 1 and prompt_config.get("refine_multiturn"):
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
else:
questions = questions[-1:]
# 不再使用多轮对话优化
# if len(questions) > 1 and prompt_config.get("refine_multiturn"):
# questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
# else:
# questions = questions[-1:]
questions = questions[-1:]
refine_question_ts = timer()
@ -188,40 +190,50 @@ def chat(dialog, messages, stream=True, **kwargs):
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
knowledges = []
if prompt_config.get("reasoning", False):
reasoner = DeepResearcher(chat_mdl,
prompt_config,
partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))
for think in reasoner.thinking(kbinfos, " ".join(questions)):
if isinstance(think, str):
thought = think
knowledges = [t for t in think.split("\n") if t]
elif stream:
yield think
else:
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs)
)
if prompt_config.get("tavily_api_key"):
tav = Tavily(prompt_config["tavily_api_key"])
tav_res = tav.retrieve_chunks(" ".join(questions))
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
ck = settings.kg_retrievaler.retrieval(" ".join(questions),
tenant_ids,
dialog.kb_ids,
embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
# 不再使用推理
# if prompt_config.get("reasoning", False):
# reasoner = DeepResearcher(chat_mdl,
# prompt_config,
# partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))
knowledges = kb_prompt(kbinfos, max_tokens)
# for think in reasoner.thinking(kbinfos, " ".join(questions)):
# if isinstance(think, str):
# thought = think
# knowledges = [t for t in think.split("\n") if t]
# elif stream:
# yield think
# else:
# kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
# dialog.similarity_threshold,
# dialog.vector_similarity_weight,
# doc_ids=attachments,
# top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
# rank_feature=label_question(" ".join(questions), kbs)
# )
# if prompt_config.get("tavily_api_key"):
# tav = Tavily(prompt_config["tavily_api_key"])
# tav_res = tav.retrieve_chunks(" ".join(questions))
# kbinfos["chunks"].extend(tav_res["chunks"])
# kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
# if prompt_config.get("use_kg"):
# ck = settings.kg_retrievaler.retrieval(" ".join(questions),
# tenant_ids,
# dialog.kb_ids,
# embd_mdl,
# LLMBundle(dialog.tenant_id, LLMType.CHAT))
# if ck["content_with_weight"]:
# kbinfos["chunks"].insert(0, ck)
# knowledges = kb_prompt(kbinfos, max_tokens)
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs)
)
knowledges = kb_prompt(kbinfos, max_tokens)
logging.debug(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
@ -255,6 +267,7 @@ def chat(dialog, messages, stream=True, **kwargs):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
refs = []
image_markdowns = [] # 用于存储图片的 Markdown 字符串
ans = answer.split("</think>")
think = ""
if len(ans) == 2:
@ -262,6 +275,7 @@ def chat(dialog, messages, stream=True, **kwargs):
answer = ans[1]
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
cited_chunk_indices = set() # 用于存储被引用的 chunk 索引
if not re.search(r"##[0-9]+\$\$", answer):
answer, idx = retriever.insert_citations(answer,
[ck["content_ltks"]
@ -271,12 +285,34 @@ def chat(dialog, messages, stream=True, **kwargs):
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
cited_chunk_indices = idx # 获取 insert_citations 返回的索引
else:
idx = set([])
for r in re.finditer(r"##([0-9]+)\$\$", answer):
i = int(r.group(1))
if i < len(kbinfos["chunks"]):
idx.add(i)
cited_chunk_indices = idx # 获取从 ##...$$ 标记中提取的索引
# 根据引用的 chunk 索引提取图像信息并生成 Markdown
cited_doc_ids = set()
processed_image_urls = set() # 避免重复添加同一张图片
print(f"DEBUG: cited_chunk_indices = {cited_chunk_indices}")
for i in cited_chunk_indices:
i_int = int(i)
if i_int < len(kbinfos["chunks"]):
chunk = kbinfos["chunks"][i_int]
cited_doc_ids.add(chunk["doc_id"])
print(f"DEBUG: chunk = {chunk}")
# 检查 chunk 是否有关联的 image_id (URL) 且未被处理过
print(f"DEBUG: chunk_id={chunk.get('chunk_id', i_int)}, image_id={chunk.get('image_id')}")
img_url = chunk.get("image_id")
if img_url and img_url not in processed_image_urls:
# 生成 Markdown 字符串alt text 可以简单设为 "image" 或 chunk ID
alt_text = f"image_chunk_{chunk.get('chunk_id', i_int)}"
image_markdowns.append(f"\n![{alt_text}]({img_url})")
processed_image_urls.add(img_url) # 标记为已处理
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
@ -290,6 +326,10 @@ def chat(dialog, messages, stream=True, **kwargs):
if c.get("vector"):
del c["vector"]
# 将图片的 Markdown 字符串追加到回答末尾
if image_markdowns:
answer += "".join(image_markdowns)
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
finish_chat_ts = timer()

View File

@ -1,3 +1,4 @@
import traceback
from flask import Blueprint, request
from services.knowledgebases.service import KnowledgebaseService
from utils import success_response, error_response
@ -132,13 +133,12 @@ def add_documents_to_knowledgebase(kb_id):
)
except Exception as service_error:
print(f"[ERROR] 服务层错误详情: {str(service_error)}")
import traceback
traceback.print_exc()
return error_response(str(service_error), code=500)
except Exception as e:
print(f"[ERROR] 路由层错误详情: {str(e)}")
import traceback
traceback.print_exc()
return error_response(str(e), code=500)
@ -193,5 +193,50 @@ def get_parse_progress(doc_id):
return error_response(result['error'], code=404)
return success_response(data=result)
except Exception as e:
current_app.logger.error(f"获取解析进度失败: {str(e)}")
print(f"获取解析进度失败: {str(e)}")
return error_response("解析进行中,请稍后重试", code=202)
# 获取系统 Embedding 配置路由
@knowledgebase_bp.route('/system_embedding_config', methods=['GET'])
def get_system_embedding_config_route():
"""获取系统级 Embedding 配置的API端点"""
try:
config_data = KnowledgebaseService.get_system_embedding_config()
return success_response(data=config_data)
except Exception as e:
print(f"获取系统 Embedding 配置失败: {str(e)}")
return error_response(message=f"获取配置失败: {str(e)}", code=500) # 返回通用错误信息
# 设置系统 Embedding 配置路由
@knowledgebase_bp.route('/system_embedding_config', methods=['POST'])
def set_system_embedding_config_route():
"""设置系统级 Embedding 配置的API端点"""
try:
data = request.json
if not data:
return error_response('请求数据不能为空', code=400)
llm_name = data.get('llm_name', '').strip()
api_base = data.get('api_base', '').strip()
api_key = data.get('api_key', '').strip() # 允许空
if not llm_name or not api_base:
return error_response('模型名称和 API 地址不能为空', code=400)
# 调用服务层进行处理(包括连接测试和数据库操作)
success, message = KnowledgebaseService.set_system_embedding_config(
llm_name=llm_name,
api_base=api_base,
api_key=api_key
)
if success:
return success_response(message=message)
else:
# 如果服务层返回失败(例如连接测试失败或数据库错误),将消息返回给前端
return error_response(message=message, code=400) # 使用 400 表示操作失败
except Exception as e:
# 捕获路由层或未预料的服务层异常
print(f"设置系统 Embedding 配置失败: {str(e)}")
return error_response(message=f"设置配置时发生内部错误: {str(e)}", code=500)

View File

@ -5,6 +5,8 @@ import json
import mysql.connector
import time
import traceback
import re
import requests
from io import BytesIO
from datetime import datetime
from elasticsearch import Elasticsearch
@ -16,6 +18,7 @@ from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.read_api import read_local_office
from utils import generate_uuid
# 自定义tokenizer和文本处理函数替代rag.nlp中的功能
def tokenize_text(text):
"""将文本分词替代rag_tokenizer功能"""
@ -173,7 +176,7 @@ def get_text_from_block(block):
block_text += content
return ' '.join(block_text.split())
def perform_parse(doc_id, doc_info, file_info):
def perform_parse(doc_id, doc_info, file_info, embedding_config):
"""
执行文档解析的核心逻辑
@ -189,6 +192,26 @@ def perform_parse(doc_id, doc_info, file_info):
temp_image_dir = None
start_time = time.time()
middle_json_content = None # 初始化 middle_json_content
image_info_list = [] # 图片信息列表
# 默认值处理
embedding_model_name = embedding_config.get("llm_name") if embedding_config and embedding_config.get("llm_name") else "bge-m3" # 默认模型
# 对模型名称进行处理
if embedding_model_name and '___' in embedding_model_name:
embedding_model_name = embedding_model_name.split('___')[0]
embedding_api_base = embedding_config.get("api_base") if embedding_config and embedding_config.get("api_base") else "http://localhost:8000" # 默认基础 URL
embedding_api_key = embedding_config.get("api_key") if embedding_config else None # 可能为 None 或空字符串
# 构建完整的 Embedding API URL
if embedding_api_base:
if not embedding_api_base.startswith(('http://', 'https://')):
embedding_api_base = 'http://' + embedding_api_base
# 标准端点是 /embeddings
embedding_url = embedding_api_base.rstrip('/') + "/embeddings"
else:
embedding_url = None # 如果没有配置 Base URL则无法请求
print(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")
try:
kb_id = doc_info['kb_id']
@ -330,8 +353,18 @@ def perform_parse(doc_id, doc_info, file_info):
es_client.indices.create(
index=index_name,
body={
"settings": {"number_of_replicas": 0}, # 单节点设为0
"mappings": { "properties": { "doc_id": {"type": "keyword"}, "kb_id": {"type": "keyword"}, "content_with_weight": {"type": "text"} } } # 简化字段
"settings": {"number_of_replicas": 0},
"mappings": {
"properties": {
"doc_id": {"type": "keyword"},
"kb_id": {"type": "keyword"},
"content_with_weight": {"type": "text"},
"q_1024_vec": {
"type": "dense_vector",
"dims": 1024
}
}
}
}
)
print(f"[Parser-INFO] 创建Elasticsearch索引: {index_name}")
@ -348,6 +381,42 @@ def perform_parse(doc_id, doc_info, file_info):
if not content or not content.strip():
continue
# 过滤 markdown 特殊符号
content = re.sub(r"[!#\\$/]", "", content)
q_1024_vec = [] # 初始化为空列表
# 获取embedding向量
try:
# embedding_resp = requests.post(
# "http://localhost:8000/v1/embeddings",
# json={
# "model": "bge-m3", # 你的embedding模型名
# "input": content
# },
# timeout=10
# )
headers = {"Content-Type": "application/json"}
if embedding_api_key:
headers["Authorization"] = f"Bearer {embedding_api_key}"
embedding_resp = requests.post(
embedding_url, # 使用动态构建的 URL
headers=headers, # 添加 headers (包含可能的 API Key)
json={
"model": embedding_model_name, # 使用动态获取或默认的模型名
"input": content
},
timeout=15 # 稍微增加超时时间
)
embedding_resp.raise_for_status()
embedding_data = embedding_resp.json()
q_1024_vec = embedding_data["data"][0]["embedding"]
print(f"[Parser-INFO] 获取embedding成功长度: {len(q_1024_vec)}")
except Exception as e:
print(f"[Parser-ERROR] 获取embedding失败: {e}")
q_1024_vec = []
chunk_id = generate_uuid()
page_idx = 0 # 默认页面索引
bbox = [0, 0, 0, 0] # 默认 bbox
@ -363,7 +432,6 @@ def perform_parse(doc_id, doc_info, file_info):
if processed_text_chunks == len(block_info_list) + 1: # 只在第一次耗尽时警告一次
print(f"[Parser-WARNING] middle_data 提供的块信息少于 content_list 中的文本块数量。后续文本块将使用默认 page/bbox。")
try:
# 上传文本块到 MinIO
minio_client.put_object(
@ -382,7 +450,6 @@ def perform_parse(doc_id, doc_info, file_info):
x1, y1, x2, y2 = bbox
bbox_reordered = [x1, x2, y1, y2]
es_doc = {
"doc_id": doc_id,
"kb_id": kb_id,
@ -390,19 +457,19 @@ def perform_parse(doc_id, doc_info, file_info):
"title_tks": doc_info['name'],
"title_sm_tks": doc_info['name'],
"content_with_weight": content,
"content_ltks": content_tokens,
"content_sm_ltks": content_tokens,
"content_ltks": " ".join(content_tokens), # 字符串类型
"content_sm_ltks": " ".join(content_tokens), # 字符串类型
"page_num_int": [page_idx + 1],
"position_int": [[page_idx + 1] + bbox_reordered], # 格式: [[page, x1, x2, y1, y2]]
"top_int": [1],
"create_time": current_time_es,
"create_timestamp_flt": current_timestamp_es,
"img_id": "",
"q_1024_vec": [] # 向量字段留空
"q_1024_vec": q_1024_vec
}
# 存储到Elasticsearch
es_client.index(index=index_name, document=es_doc) # 使用 document 参数
es_client.index(index=index_name, id=chunk_id, document=es_doc) # 使用 document 参数
chunk_count += 1
chunk_ids_list.append(chunk_id)
@ -428,27 +495,95 @@ def perform_parse(doc_id, doc_info, file_info):
content_type = f"image/{img_ext[1:].lower()}"
if content_type == "image/jpg": content_type = "image/jpeg"
# try:
# # 上传图片到MinIO (桶为kb_id)
# minio_client.fput_object(
# bucket_name=output_bucket,
# object_name=img_key,
# file_path=img_path_abs,
# content_type=content_type
# )
# print(f"成功上传图片: {img_key}")
# # 注意设置公共访问权限可能需要额外配置MinIO服务器和存储桶策略
try:
# 上传图片到MinIO (桶为kb_id)
minio_client.fput_object(
bucket_name=output_bucket,
object_name=img_key,
file_path=img_path_abs,
content_type=content_type
)
# except Exception as e:
# print(f"[Parser-ERROR] 上传图片 {img_path_abs} 失败: {e}")
# 设置图片的公共访问权限
policy = {
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {"AWS": "*"},
"Action": ["s3:GetObject"],
"Resource": [f"arn:aws:s3:::{kb_id}/{img_key}"]
}
]
}
minio_client.set_bucket_policy(kb_id, json.dumps(policy))
print(f"成功上传图片: {img_key}")
minio_endpoint = MINIO_CONFIG["endpoint"]
use_ssl = MINIO_CONFIG.get("secure", False)
protocol = "https" if use_ssl else "http"
img_url = f"{protocol}://{minio_endpoint}/{output_bucket}/{img_key}"
# 记录图片信息包括URL和位置信息
image_info = {
"url": img_url,
"position": processed_text_chunks # 使用当前处理的文本块数作为位置参考
}
image_info_list.append(image_info)
print(f"图片访问链接: {img_url}")
except Exception as e:
print(f"[Parser-ERROR] 上传图片 {img_path_abs} 失败: {e}")
# 打印匹配总结信息
print(f"[Parser-INFO] 共处理 {processed_text_chunks} 个文本块。")
if middle_block_idx < len(block_info_list):
print(f"[Parser-WARNING] middle_data 中还有 {len(block_info_list) - middle_block_idx} 个提取的块信息未被使用。")
# 4. 更新文本块的图像信息
if image_info_list and chunk_ids_list:
conn = None
cursor = None
try:
conn = _get_db_connection()
cursor = conn.cursor()
# 4. 更新最终状态
# 为每个文本块找到最近的图片
for i, chunk_id in enumerate(chunk_ids_list):
# 找到与当前文本块最近的图片
nearest_image = None
for img_info in image_info_list:
# 计算文本块与图片的"距离"
distance = abs(i - img_info["position"]) # 使用位置差作为距离度量
# 如果文本块与图片的距离间隔小于10个块,则认为块与图片是相关的
if distance < 10:
nearest_image = img_info
# 如果找到了最近的图片则更新文本块的img_id
if nearest_image:
# 更新ES中的文档
direct_update = {
"doc": {
"img_id": nearest_image["url"]
}
}
es_client.update(index=index_name, id=chunk_id, body=direct_update, refresh=True)
index_name = f"ragflow_{tenant_id}"
print(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {nearest_image['url']}")
except Exception as e:
print(f"[Parser-ERROR] 更新文本块图片关联失败: {e}")
finally:
if cursor:
cursor.close()
if conn:
conn.close()
# 5. 更新最终状态
process_duration = time.time() - start_time
_update_document_progress(doc_id, progress=1.0, message="解析完成", status='1', run='3', chunk_count=chunk_count, process_duration=process_duration)
_update_kb_chunk_count(kb_id, chunk_count) # 更新知识库总块数

View File

@ -1,12 +1,15 @@
import mysql.connector
import json
import threading
import requests
import traceback
from datetime import datetime
from utils import generate_uuid
from database import DB_CONFIG
# 解析相关模块
from .document_parser import perform_parse, _update_document_progress
class KnowledgebaseService:
@classmethod
@ -704,7 +707,8 @@ class KnowledgebaseService:
_update_document_progress(doc_id, status='2', run='1', progress=0.0, message='开始解析')
# 3. 调用后台解析函数
parse_result = perform_parse(doc_id, doc_info, file_info)
embedding_config = cls.get_system_embedding_config()
parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config)
# 4. 返回解析结果
return parse_result
@ -792,3 +796,199 @@ class KnowledgebaseService:
cursor.close()
if conn:
conn.close()
# --- 获取最早用户 ID ---
@classmethod
def _get_earliest_user_tenant_id(cls):
"""获取创建时间最早的用户的 ID (作为 tenant_id)"""
conn = None
cursor = None
try:
conn = cls._get_db_connection()
cursor = conn.cursor()
query = "SELECT id FROM user ORDER BY create_time ASC LIMIT 1"
cursor.execute(query)
result = cursor.fetchone()
if result:
return result[0] # 返回用户 ID
else:
print("警告: 数据库中没有用户!")
return None
except Exception as e:
print(f"查询最早用户时出错: {e}")
traceback.print_exc()
return None
finally:
if cursor:
cursor.close()
if conn and conn.is_connected():
conn.close()
# --- 测试 Embedding 连接 ---
@classmethod
def _test_embedding_connection(cls, base_url, model_name, api_key):
"""
测试与自定义 Embedding 模型的连接 (使用 requests)
"""
print(f"开始测试连接: base_url={base_url}, model_name={model_name}")
try:
headers = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
payload = {"input": ["Test connection"], "model": model_name}
if not base_url.startswith(('http://', 'https://')):
base_url = 'http://' + base_url
if not base_url.endswith('/'):
base_url += '/'
endpoint = "embeddings"
current_test_url = base_url + endpoint
print(f"尝试请求 URL: {current_test_url}")
try:
response = requests.post(current_test_url, headers=headers, json=payload, timeout=15)
print(f"请求 {current_test_url} 返回状态码: {response.status_code}")
if response.status_code == 200:
res_json = response.json()
if ("data" in res_json and isinstance(res_json["data"], list) and len(res_json["data"]) > 0 and "embedding" in res_json["data"][0] and len(res_json["data"][0]["embedding"]) > 0) or \
(isinstance(res_json, list) and len(res_json) > 0 and isinstance(res_json[0], list) and len(res_json[0]) > 0):
print(f"连接测试成功: {current_test_url}")
return True, "连接成功"
else:
print(f"连接成功但响应格式不正确于 {current_test_url}")
except Exception as json_e:
print(f"解析 JSON 响应失败于 {current_test_url}: {json_e}")
return False, "连接失败: 响应错误"
except Exception as e:
print(f"连接测试发生未知错误: {str(e)}")
traceback.print_exc()
return False, f"测试时发生未知错误: {str(e)}"
# --- 获取系统 Embedding 配置 ---
@classmethod
def get_system_embedding_config(cls):
"""获取系统级(最早用户)的 Embedding 配置"""
tenant_id = cls._get_earliest_user_tenant_id()
if not tenant_id:
raise Exception("无法找到系统基础用户") # 在服务层抛出异常
conn = None
cursor = None
try:
conn = cls._get_db_connection()
cursor = conn.cursor(dictionary=True) # 使用字典游标方便访问列名
query = """
SELECT llm_name, api_key, api_base
FROM tenant_llm
WHERE tenant_id = %s
LIMIT 1
"""
cursor.execute(query, (tenant_id,))
config = cursor.fetchone()
if config:
llm_name = config.get("llm_name", "")
api_key = config.get("api_key", "")
api_base = config.get("api_base", "")
# 对模型名称进行处理
if llm_name and '___' in llm_name:
llm_name = llm_name.split('___')[0]
# 如果有配置,返回
return {
"llm_name": llm_name,
"api_key": api_key,
"api_base": api_base
}
else:
# 如果没有配置,返回空
return {
"llm_name": "",
"api_key": "",
"api_base": ""
}
except Exception as e:
print(f"获取系统 Embedding 配置时出错: {e}")
traceback.print_exc()
raise Exception(f"获取配置时数据库出错: {e}") # 重新抛出异常
finally:
if cursor:
cursor.close()
if conn and conn.is_connected():
conn.close()
# --- 设置系统 Embedding 配置 ---
@classmethod
def set_system_embedding_config(cls, llm_name, api_base, api_key):
"""设置系统级(最早用户)的 Embedding 配置"""
tenant_id = cls._get_earliest_user_tenant_id()
if not tenant_id:
raise Exception("无法找到系统基础用户")
# 执行连接测试
is_connected, message = cls._test_embedding_connection(
base_url=api_base,
model_name=llm_name,
api_key=api_key
)
if not is_connected:
# 返回具体的测试失败原因给调用者(路由层)处理
return False, f"连接测试失败: {message}"
return True, f"连接成功: {message}"
# 测试通过,保存或更新配置到数据库(先不保存,以防冲突)
# conn = None
# cursor = None
# try:
# conn = cls._get_db_connection()
# cursor = conn.cursor()
# # 检查 TenantLLM 记录是否存在
# check_query = """
# SELECT id FROM tenant_llm
# WHERE tenant_id = %s AND llm_name = %s
# """
# cursor.execute(check_query, (tenant_id, llm_name))
# existing_config = cursor.fetchone()
# now = datetime.now()
# if existing_config:
# # 更新记录
# update_sql = """
# UPDATE tenant_llm
# SET api_key = %s, api_base = %s, max_tokens = %s, update_time = %s, update_date = %s
# WHERE id = %s
# """
# update_params = (api_key, api_base, max_tokens, now, now.date(), existing_config[0])
# cursor.execute(update_sql, update_params)
# print(f"已更新 TenantLLM 记录 (ID: {existing_config[0]})")
# else:
# # 插入新记录
# insert_sql = """
# INSERT INTO tenant_llm (tenant_id, llm_factory, model_type, llm_name, api_key, api_base, max_tokens, create_time, create_date, update_time, update_date, used_tokens)
# VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
# """
# insert_params = (tenant_id, "VLLM", "embedding", llm_name, api_key, api_base, max_tokens, now, now.date(), now, now.date(), 0) # used_tokens 默认为 0
# cursor.execute(insert_sql, insert_params)
# print(f"已创建新的 TenantLLM 记录")
# conn.commit() # 提交事务
# return True, "配置已成功保存"
# except Exception as e:
# if conn:
# conn.rollback() # 出错时回滚
# print(f"保存系统 Embedding 配置时数据库出错: {e}")
# traceback.print_exc()
# # 返回 False 和错误信息给路由层
# return False, f"保存配置时数据库出错: {e}"
# finally:
# if cursor:
# cursor.close()
# if conn and conn.is_connected():
# conn.close()

View File

@ -77,3 +77,24 @@ export function addDocumentToKnowledgeBaseApi(data: {
data: { file_ids: data.file_ids }
})
}
// 获取系统 Embedding 配置
export function getSystemEmbeddingConfigApi() {
return request({
url: "/api/v1/knowledgebases/system_embedding_config", // 确认 API 路径前缀是否正确
method: "get"
})
}
// 设置系统 Embedding 配置
export function setSystemEmbeddingConfigApi(data: {
llm_name: string
api_base: string
api_key?: string
}) {
return request({
url: "/api/v1/knowledgebases/system_embedding_config", // 确认 API 路径前缀是否正确
method: "post",
data
})
}

View File

@ -11,13 +11,15 @@ import {
batchDeleteKnowledgeBaseApi,
createKnowledgeBaseApi,
deleteKnowledgeBaseApi,
getKnowledgeBaseListApi
getKnowledgeBaseListApi,
getSystemEmbeddingConfigApi,
setSystemEmbeddingConfigApi
} from "@@/apis/kbs/knowledgebase"
import { usePagination } from "@@/composables/usePagination"
import { CaretRight, Delete, Plus, Refresh, Search, View } from "@element-plus/icons-vue"
import { CaretRight, Delete, Plus, Refresh, Search, Setting, View } from "@element-plus/icons-vue"
import axios from "axios"
import { ElMessage, ElMessageBox } from "element-plus"
import { onActivated, onBeforeUnmount, onMounted, reactive, ref, watch } from "vue"
import { nextTick, onActivated, onBeforeUnmount, onDeactivated, onMounted, reactive, ref, watch } from "vue"
import "element-plus/dist/index.css"
import "element-plus/theme-chalk/el-message-box.css"
import "element-plus/theme-chalk/el-message.css"
@ -615,6 +617,110 @@ onMounted(() => {
onActivated(() => {
getTableData()
})
// Embedding
const configModalVisible = ref(false)
const configFormRef = ref<FormInstance>() //
const configFormLoading = ref(false) //
const configSubmitLoading = ref(false) //
const configForm = reactive({
llm_name: "",
api_base: "",
api_key: ""
})
// URL
function validateUrl(rule: any, value: any, callback: any) {
if (!value) {
return callback(new Error("请输入模型 API 地址"))
}
// http, https IP
//
const urlPattern = /^(https?:\/\/)?([a-zA-Z0-9.-]+|\[[a-fA-F0-9:]+\])(:\d+)?(\/[^?#]*)?$/
if (!urlPattern.test(value)) {
callback(new Error("请输入有效的 Base URL (例如 http://host:port 或 https://domain/path)"))
} else {
callback()
}
}
const configFormRules = reactive({
llm_name: [{ required: true, message: "请输入模型名称", trigger: "blur" }],
api_base: [{ required: true, validator: validateUrl, trigger: "blur" }]
// api_key
})
//
async function showConfigModal() {
configModalVisible.value = true
configFormLoading.value = true
// nextTick DOM
await nextTick()
configFormRef.value?.resetFields() //
try {
// API
const res = await getSystemEmbeddingConfigApi() as ApiResponse<{ llm_name?: string, api_base?: string, api_key?: string }>
if (res.code === 0 && res.data) {
configForm.llm_name = res.data.llm_name || ""
configForm.api_base = res.data.api_base || ""
// API Key GET
configForm.api_key = res.data.api_key || ""
} else if (res.code !== 0) {
ElMessage.error(res.message || "获取配置失败")
} else {
// code === 0 data
console.log("当前未配置编码模型。")
}
} catch (error: any) {
ElMessage.error(error.message || "获取配置请求失败")
console.error("获取配置失败:", error)
} finally {
configFormLoading.value = false
}
}
//
function handleModalClose() {
//
configFormRef.value?.resetFields()
}
//
async function handleConfigSubmit() {
if (!configFormRef.value) return
// 使 .then() .catch() validate Promise
configFormRef.value.validate().then(async () => {
//
configSubmitLoading.value = true
try {
const payload = {
llm_name: configForm.llm_name.trim(),
api_base: configForm.api_base.trim(),
api_key: configForm.api_key
}
// API
const res = await setSystemEmbeddingConfigApi(payload) as ApiResponse<any> // 使any
if (res.code === 0) {
ElMessage.success("连接验证成功!")
configModalVisible.value = false
} else {
// res.message
ElMessage.error(res.message || "连接验证失败")
}
} catch (error: any) {
ElMessage.error(error.message || "连接验证请求失败")
console.error("连接验证失败:", error)
} finally {
configSubmitLoading.value = false
}
}).catch((errorFields) => {
//
console.log("表单验证失败!", errorFields)
// falsevalidate Promise reject
})
}
</script>
<template>
@ -654,6 +760,12 @@ onActivated(() => {
批量删除
</el-button>
</div>
<div>
<el-button type="primary" :icon="Setting" @click="showConfigModal">
编码模型配置
</el-button>
</div>
</div>
<div class="table-wrapper">
@ -922,6 +1034,56 @@ onActivated(() => {
</span>
</template>
</el-dialog>
<!-- 系统 Embedding 配置模态框 -->
<el-dialog
v-model="configModalVisible"
title="编码模型配置"
width="500px"
:close-on-click-modal="false"
@close="handleModalClose"
append-to-body
>
<el-form
ref="configFormRef"
:model="configForm"
:rules="configFormRules"
label-width="120px"
v-loading="configFormLoading"
>
<el-form-item label="模型名称" prop="llm_name">
<el-input v-model="configForm.llm_name" placeholder="请先在前台进行配置" disabled />
<div class="form-tip">
与模型服务中部署的名称一致
</div>
</el-form-item>
<el-form-item label="模型 API 地址" prop="api_base">
<el-input v-model="configForm.api_base" placeholder="请先在前台进行配置" disabled />
<div class="form-tip">
模型的 Base URL
</div>
</el-form-item>
<el-form-item label="API Key (可选)" prop="api_key">
<el-input v-model="configForm.api_key" type="password" show-password placeholder="请先在前台进行配置" disabled />
<div class="form-tip">
如果模型服务需要认证请提供
</div>
</el-form-item>
<el-form-item>
<div style="color: #909399; font-size: 12px; line-height: 1.5;">
此配置将作为知识库解析时默认的 Embedding 模型
</div>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click="configModalVisible = false">取消</el-button>
<el-button type="primary" @click="handleConfigSubmit" :loading="configSubmitLoading">
测试连接
</el-button>
</span>
</template>
</el-dialog>
</div>
<DocumentParseProgress
:document-id="currentDocId"
@ -952,7 +1114,8 @@ onActivated(() => {
.toolbar-wrapper {
display: flex;
justify-content: space-between;
justify-content: space-between; //
align-items: center; //
margin-bottom: 20px;
}
@ -1011,4 +1174,11 @@ onActivated(() => {
text-align: center;
}
}
.form-tip {
color: #909399;
font-size: 12px;
line-height: 1.5;
margin-top: 4px;
}
</style>

View File

@ -12,10 +12,12 @@ def test_embedding(model, text):
)
# 打印嵌入响应内容
print(f"Embedding response: {response}")
# print(f"Embedding response: {response}")
result = response.data[0].embedding
if response and response.data:
print(f"Embedding: {response.data[0].embedding}")
print(len(result))
else:
print("Failed to get embedding.")