feat: 添加系统Embedding配置功能并优化文档解析 (#35)

在知识库模块中新增了获取和设置系统Embedding配置的API接口,支持动态配置Embedding模型的基础URL、模型名称和API Key。同时,优化了文档解析逻辑,使用系统配置的Embedding模型生成文本块的向量,并将图片与文本块关联存储。
This commit is contained in:
zstar 2025-04-18 22:34:25 +08:00 committed by GitHub
parent 61d924a4fa
commit 803cc7e656
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 686 additions and 73 deletions

View File

@ -161,11 +161,13 @@ def chat(dialog, messages, stream=True, **kwargs):
if p["key"] not in kwargs: if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace( prompt_config["system"] = prompt_config["system"].replace(
"{%s}" % p["key"], " ") "{%s}" % p["key"], " ")
if len(questions) > 1 and prompt_config.get("refine_multiturn"): # 不再使用多轮对话优化
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] # if len(questions) > 1 and prompt_config.get("refine_multiturn"):
else: # questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
questions = questions[-1:] # else:
# questions = questions[-1:]
questions = questions[-1:]
refine_question_ts = timer() refine_question_ts = timer()
@ -188,40 +190,50 @@ def chat(dialog, messages, stream=True, **kwargs):
tenant_ids = list(set([kb.tenant_id for kb in kbs])) tenant_ids = list(set([kb.tenant_id for kb in kbs]))
knowledges = [] knowledges = []
if prompt_config.get("reasoning", False):
reasoner = DeepResearcher(chat_mdl, # 不再使用推理
prompt_config, # if prompt_config.get("reasoning", False):
partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3)) # reasoner = DeepResearcher(chat_mdl,
# prompt_config,
# partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))
for think in reasoner.thinking(kbinfos, " ".join(questions)): # for think in reasoner.thinking(kbinfos, " ".join(questions)):
if isinstance(think, str): # if isinstance(think, str):
thought = think # thought = think
knowledges = [t for t in think.split("\n") if t] # knowledges = [t for t in think.split("\n") if t]
elif stream: # elif stream:
yield think # yield think
else: # else:
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, # kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold, # dialog.similarity_threshold,
dialog.vector_similarity_weight, # dialog.vector_similarity_weight,
doc_ids=attachments, # doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl, # top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs) # rank_feature=label_question(" ".join(questions), kbs)
) # )
if prompt_config.get("tavily_api_key"): # if prompt_config.get("tavily_api_key"):
tav = Tavily(prompt_config["tavily_api_key"]) # tav = Tavily(prompt_config["tavily_api_key"])
tav_res = tav.retrieve_chunks(" ".join(questions)) # tav_res = tav.retrieve_chunks(" ".join(questions))
kbinfos["chunks"].extend(tav_res["chunks"]) # kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) # kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"): # if prompt_config.get("use_kg"):
ck = settings.kg_retrievaler.retrieval(" ".join(questions), # ck = settings.kg_retrievaler.retrieval(" ".join(questions),
tenant_ids, # tenant_ids,
dialog.kb_ids, # dialog.kb_ids,
embd_mdl, # embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT)) # LLMBundle(dialog.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]: # if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck) # kbinfos["chunks"].insert(0, ck)
knowledges = kb_prompt(kbinfos, max_tokens) # knowledges = kb_prompt(kbinfos, max_tokens)
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs)
)
knowledges = kb_prompt(kbinfos, max_tokens)
logging.debug( logging.debug(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges))) "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
@ -255,6 +267,7 @@ def chat(dialog, messages, stream=True, **kwargs):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
refs = [] refs = []
image_markdowns = [] # 用于存储图片的 Markdown 字符串
ans = answer.split("</think>") ans = answer.split("</think>")
think = "" think = ""
if len(ans) == 2: if len(ans) == 2:
@ -262,6 +275,7 @@ def chat(dialog, messages, stream=True, **kwargs):
answer = ans[1] answer = ans[1]
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL) answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
cited_chunk_indices = set() # 用于存储被引用的 chunk 索引
if not re.search(r"##[0-9]+\$\$", answer): if not re.search(r"##[0-9]+\$\$", answer):
answer, idx = retriever.insert_citations(answer, answer, idx = retriever.insert_citations(answer,
[ck["content_ltks"] [ck["content_ltks"]
@ -271,12 +285,34 @@ def chat(dialog, messages, stream=True, **kwargs):
embd_mdl, embd_mdl,
tkweight=1 - dialog.vector_similarity_weight, tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight) vtweight=dialog.vector_similarity_weight)
cited_chunk_indices = idx # 获取 insert_citations 返回的索引
else: else:
idx = set([]) idx = set([])
for r in re.finditer(r"##([0-9]+)\$\$", answer): for r in re.finditer(r"##([0-9]+)\$\$", answer):
i = int(r.group(1)) i = int(r.group(1))
if i < len(kbinfos["chunks"]): if i < len(kbinfos["chunks"]):
idx.add(i) idx.add(i)
cited_chunk_indices = idx # 获取从 ##...$$ 标记中提取的索引
# 根据引用的 chunk 索引提取图像信息并生成 Markdown
cited_doc_ids = set()
processed_image_urls = set() # 避免重复添加同一张图片
print(f"DEBUG: cited_chunk_indices = {cited_chunk_indices}")
for i in cited_chunk_indices:
i_int = int(i)
if i_int < len(kbinfos["chunks"]):
chunk = kbinfos["chunks"][i_int]
cited_doc_ids.add(chunk["doc_id"])
print(f"DEBUG: chunk = {chunk}")
# 检查 chunk 是否有关联的 image_id (URL) 且未被处理过
print(f"DEBUG: chunk_id={chunk.get('chunk_id', i_int)}, image_id={chunk.get('image_id')}")
img_url = chunk.get("image_id")
if img_url and img_url not in processed_image_urls:
# 生成 Markdown 字符串alt text 可以简单设为 "image" 或 chunk ID
alt_text = f"image_chunk_{chunk.get('chunk_id', i_int)}"
image_markdowns.append(f"\n![{alt_text}]({img_url})")
processed_image_urls.add(img_url) # 标记为已处理
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [ recall_docs = [
@ -289,6 +325,10 @@ def chat(dialog, messages, stream=True, **kwargs):
for c in refs["chunks"]: for c in refs["chunks"]:
if c.get("vector"): if c.get("vector"):
del c["vector"] del c["vector"]
# 将图片的 Markdown 字符串追加到回答末尾
if image_markdowns:
answer += "".join(image_markdowns)
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'" answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"

View File

@ -1,3 +1,4 @@
import traceback
from flask import Blueprint, request from flask import Blueprint, request
from services.knowledgebases.service import KnowledgebaseService from services.knowledgebases.service import KnowledgebaseService
from utils import success_response, error_response from utils import success_response, error_response
@ -132,13 +133,12 @@ def add_documents_to_knowledgebase(kb_id):
) )
except Exception as service_error: except Exception as service_error:
print(f"[ERROR] 服务层错误详情: {str(service_error)}") print(f"[ERROR] 服务层错误详情: {str(service_error)}")
import traceback
traceback.print_exc() traceback.print_exc()
return error_response(str(service_error), code=500) return error_response(str(service_error), code=500)
except Exception as e: except Exception as e:
print(f"[ERROR] 路由层错误详情: {str(e)}") print(f"[ERROR] 路由层错误详情: {str(e)}")
import traceback
traceback.print_exc() traceback.print_exc()
return error_response(str(e), code=500) return error_response(str(e), code=500)
@ -193,5 +193,50 @@ def get_parse_progress(doc_id):
return error_response(result['error'], code=404) return error_response(result['error'], code=404)
return success_response(data=result) return success_response(data=result)
except Exception as e: except Exception as e:
current_app.logger.error(f"获取解析进度失败: {str(e)}") print(f"获取解析进度失败: {str(e)}")
return error_response("解析进行中,请稍后重试", code=202) return error_response("解析进行中,请稍后重试", code=202)
# 获取系统 Embedding 配置路由
@knowledgebase_bp.route('/system_embedding_config', methods=['GET'])
def get_system_embedding_config_route():
    """GET endpoint returning the system-level Embedding configuration.

    Delegates to ``KnowledgebaseService.get_system_embedding_config`` and
    wraps its result in a success response. Any exception is printed and
    reported to the client as a 500 response carrying the exception text.
    """
    try:
        embedding_settings = KnowledgebaseService.get_system_embedding_config()
        response = success_response(data=embedding_settings)
        return response
    except Exception as e:
        reason = str(e)
        print(f"获取系统 Embedding 配置失败: {reason}")
        return error_response(message=f"获取配置失败: {reason}", code=500)
# 设置系统 Embedding 配置路由
@knowledgebase_bp.route('/system_embedding_config', methods=['POST'])
def set_system_embedding_config_route():
    """POST endpoint that validates a system-level Embedding configuration.

    Expects a JSON object body with:

    * ``llm_name`` - model name, required, non-blank string
    * ``api_base`` - model base URL, required, non-blank string
    * ``api_key``  - optional string, may be empty

    Responses:

    * 400 when the body is missing, is not a JSON object, or a required
      field is blank or not a string;
    * a success response carrying the service layer's message when
      ``KnowledgebaseService.set_system_embedding_config`` reports success;
    * 400 carrying the service layer's message when it reports failure;
    * 500 for any unexpected exception.
    """
    try:
        # silent=True returns None instead of raising when the body is
        # missing or is not valid JSON, so malformed requests are answered
        # with 400 below rather than falling into the generic 500 handler.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not data:
            return error_response('请求数据不能为空', code=400)

        def _text_field(key):
            # JSON null / numbers / nested values are treated as "not
            # provided" instead of crashing on a missing .strip() method.
            value = data.get(key)
            return value.strip() if isinstance(value, str) else ''

        llm_name = _text_field('llm_name')
        api_base = _text_field('api_base')
        api_key = _text_field('api_key')  # allowed to be empty

        if not llm_name or not api_base:
            return error_response('模型名称和 API 地址不能为空', code=400)

        # The service layer performs the connection test.
        success, message = KnowledgebaseService.set_system_embedding_config(
            llm_name=llm_name,
            api_base=api_base,
            api_key=api_key
        )

        if success:
            return success_response(message=message)
        # Service-level failure (e.g. connection test failed): relay the
        # message to the front end; 400 marks the operation as failed.
        return error_response(message=message, code=400)
    except Exception as e:
        # Unexpected route-level or service-level error.
        print(f"设置系统 Embedding 配置失败: {str(e)}")
        return error_response(message=f"设置配置时发生内部错误: {str(e)}", code=500)

View File

@ -5,6 +5,8 @@ import json
import mysql.connector import mysql.connector
import time import time
import traceback import traceback
import re
import requests
from io import BytesIO from io import BytesIO
from datetime import datetime from datetime import datetime
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
@ -16,6 +18,7 @@ from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.read_api import read_local_office from magic_pdf.data.read_api import read_local_office
from utils import generate_uuid from utils import generate_uuid
# 自定义tokenizer和文本处理函数替代rag.nlp中的功能 # 自定义tokenizer和文本处理函数替代rag.nlp中的功能
def tokenize_text(text): def tokenize_text(text):
"""将文本分词替代rag_tokenizer功能""" """将文本分词替代rag_tokenizer功能"""
@ -173,7 +176,7 @@ def get_text_from_block(block):
block_text += content block_text += content
return ' '.join(block_text.split()) return ' '.join(block_text.split())
def perform_parse(doc_id, doc_info, file_info): def perform_parse(doc_id, doc_info, file_info, embedding_config):
""" """
执行文档解析的核心逻辑 执行文档解析的核心逻辑
@ -189,7 +192,27 @@ def perform_parse(doc_id, doc_info, file_info):
temp_image_dir = None temp_image_dir = None
start_time = time.time() start_time = time.time()
middle_json_content = None # 初始化 middle_json_content middle_json_content = None # 初始化 middle_json_content
image_info_list = [] # 图片信息列表
# 默认值处理
embedding_model_name = embedding_config.get("llm_name") if embedding_config and embedding_config.get("llm_name") else "bge-m3" # 默认模型
# 对模型名称进行处理
if embedding_model_name and '___' in embedding_model_name:
embedding_model_name = embedding_model_name.split('___')[0]
embedding_api_base = embedding_config.get("api_base") if embedding_config and embedding_config.get("api_base") else "http://localhost:8000" # 默认基础 URL
embedding_api_key = embedding_config.get("api_key") if embedding_config else None # 可能为 None 或空字符串
# 构建完整的 Embedding API URL
if embedding_api_base:
if not embedding_api_base.startswith(('http://', 'https://')):
embedding_api_base = 'http://' + embedding_api_base
# 标准端点是 /embeddings
embedding_url = embedding_api_base.rstrip('/') + "/embeddings"
else:
embedding_url = None # 如果没有配置 Base URL则无法请求
print(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")
try: try:
kb_id = doc_info['kb_id'] kb_id = doc_info['kb_id']
file_location = doc_info['location'] file_location = doc_info['location']
@ -330,8 +353,18 @@ def perform_parse(doc_id, doc_info, file_info):
es_client.indices.create( es_client.indices.create(
index=index_name, index=index_name,
body={ body={
"settings": {"number_of_replicas": 0}, # 单节点设为0 "settings": {"number_of_replicas": 0},
"mappings": { "properties": { "doc_id": {"type": "keyword"}, "kb_id": {"type": "keyword"}, "content_with_weight": {"type": "text"} } } # 简化字段 "mappings": {
"properties": {
"doc_id": {"type": "keyword"},
"kb_id": {"type": "keyword"},
"content_with_weight": {"type": "text"},
"q_1024_vec": {
"type": "dense_vector",
"dims": 1024
}
}
}
} }
) )
print(f"[Parser-INFO] 创建Elasticsearch索引: {index_name}") print(f"[Parser-INFO] 创建Elasticsearch索引: {index_name}")
@ -347,7 +380,43 @@ def perform_parse(doc_id, doc_info, file_info):
content = chunk_data["text"] content = chunk_data["text"]
if not content or not content.strip(): if not content or not content.strip():
continue continue
# 过滤 markdown 特殊符号
content = re.sub(r"[!#\\$/]", "", content)
q_1024_vec = [] # 初始化为空列表
# 获取embedding向量
try:
# embedding_resp = requests.post(
# "http://localhost:8000/v1/embeddings",
# json={
# "model": "bge-m3", # 你的embedding模型名
# "input": content
# },
# timeout=10
# )
headers = {"Content-Type": "application/json"}
if embedding_api_key:
headers["Authorization"] = f"Bearer {embedding_api_key}"
embedding_resp = requests.post(
embedding_url, # 使用动态构建的 URL
headers=headers, # 添加 headers (包含可能的 API Key)
json={
"model": embedding_model_name, # 使用动态获取或默认的模型名
"input": content
},
timeout=15 # 稍微增加超时时间
)
embedding_resp.raise_for_status()
embedding_data = embedding_resp.json()
q_1024_vec = embedding_data["data"][0]["embedding"]
print(f"[Parser-INFO] 获取embedding成功长度: {len(q_1024_vec)}")
except Exception as e:
print(f"[Parser-ERROR] 获取embedding失败: {e}")
q_1024_vec = []
chunk_id = generate_uuid() chunk_id = generate_uuid()
page_idx = 0 # 默认页面索引 page_idx = 0 # 默认页面索引
bbox = [0, 0, 0, 0] # 默认 bbox bbox = [0, 0, 0, 0] # 默认 bbox
@ -362,8 +431,7 @@ def perform_parse(doc_id, doc_info, file_info):
# 如果 block_info_list 耗尽,打印警告 # 如果 block_info_list 耗尽,打印警告
if processed_text_chunks == len(block_info_list) + 1: # 只在第一次耗尽时警告一次 if processed_text_chunks == len(block_info_list) + 1: # 只在第一次耗尽时警告一次
print(f"[Parser-WARNING] middle_data 提供的块信息少于 content_list 中的文本块数量。后续文本块将使用默认 page/bbox。") print(f"[Parser-WARNING] middle_data 提供的块信息少于 content_list 中的文本块数量。后续文本块将使用默认 page/bbox。")
try: try:
# 上传文本块到 MinIO # 上传文本块到 MinIO
minio_client.put_object( minio_client.put_object(
@ -382,7 +450,6 @@ def perform_parse(doc_id, doc_info, file_info):
x1, y1, x2, y2 = bbox x1, y1, x2, y2 = bbox
bbox_reordered = [x1, x2, y1, y2] bbox_reordered = [x1, x2, y1, y2]
es_doc = { es_doc = {
"doc_id": doc_id, "doc_id": doc_id,
"kb_id": kb_id, "kb_id": kb_id,
@ -390,19 +457,19 @@ def perform_parse(doc_id, doc_info, file_info):
"title_tks": doc_info['name'], "title_tks": doc_info['name'],
"title_sm_tks": doc_info['name'], "title_sm_tks": doc_info['name'],
"content_with_weight": content, "content_with_weight": content,
"content_ltks": content_tokens, "content_ltks": " ".join(content_tokens), # 字符串类型
"content_sm_ltks": content_tokens, "content_sm_ltks": " ".join(content_tokens), # 字符串类型
"page_num_int": [page_idx + 1], "page_num_int": [page_idx + 1],
"position_int": [[page_idx + 1] + bbox_reordered], # 格式: [[page, x1, x2, y1, y2]] "position_int": [[page_idx + 1] + bbox_reordered], # 格式: [[page, x1, x2, y1, y2]]
"top_int": [1], "top_int": [1],
"create_time": current_time_es, "create_time": current_time_es,
"create_timestamp_flt": current_timestamp_es, "create_timestamp_flt": current_timestamp_es,
"img_id": "", "img_id": "",
"q_1024_vec": [] # 向量字段留空 "q_1024_vec": q_1024_vec
} }
# 存储到Elasticsearch # 存储到Elasticsearch
es_client.index(index=index_name, document=es_doc) # 使用 document 参数 es_client.index(index=index_name, id=chunk_id, document=es_doc) # 使用 document 参数
chunk_count += 1 chunk_count += 1
chunk_ids_list.append(chunk_id) chunk_ids_list.append(chunk_id)
@ -428,27 +495,95 @@ def perform_parse(doc_id, doc_info, file_info):
content_type = f"image/{img_ext[1:].lower()}" content_type = f"image/{img_ext[1:].lower()}"
if content_type == "image/jpg": content_type = "image/jpeg" if content_type == "image/jpg": content_type = "image/jpeg"
# try: try:
# # 上传图片到MinIO (桶为kb_id) # 上传图片到MinIO (桶为kb_id)
# minio_client.fput_object( minio_client.fput_object(
# bucket_name=output_bucket, bucket_name=output_bucket,
# object_name=img_key, object_name=img_key,
# file_path=img_path_abs, file_path=img_path_abs,
# content_type=content_type content_type=content_type
# ) )
# print(f"成功上传图片: {img_key}")
# # 注意设置公共访问权限可能需要额外配置MinIO服务器和存储桶策略
# except Exception as e: # 设置图片的公共访问权限
# print(f"[Parser-ERROR] 上传图片 {img_path_abs} 失败: {e}") policy = {
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {"AWS": "*"},
"Action": ["s3:GetObject"],
"Resource": [f"arn:aws:s3:::{kb_id}/{img_key}"]
}
]
}
minio_client.set_bucket_policy(kb_id, json.dumps(policy))
print(f"成功上传图片: {img_key}")
minio_endpoint = MINIO_CONFIG["endpoint"]
use_ssl = MINIO_CONFIG.get("secure", False)
protocol = "https" if use_ssl else "http"
img_url = f"{protocol}://{minio_endpoint}/{output_bucket}/{img_key}"
# 记录图片信息包括URL和位置信息
image_info = {
"url": img_url,
"position": processed_text_chunks # 使用当前处理的文本块数作为位置参考
}
image_info_list.append(image_info)
print(f"图片访问链接: {img_url}")
except Exception as e:
print(f"[Parser-ERROR] 上传图片 {img_path_abs} 失败: {e}")
# 打印匹配总结信息 # 打印匹配总结信息
print(f"[Parser-INFO] 共处理 {processed_text_chunks} 个文本块。") print(f"[Parser-INFO] 共处理 {processed_text_chunks} 个文本块。")
if middle_block_idx < len(block_info_list): if middle_block_idx < len(block_info_list):
print(f"[Parser-WARNING] middle_data 中还有 {len(block_info_list) - middle_block_idx} 个提取的块信息未被使用。") print(f"[Parser-WARNING] middle_data 中还有 {len(block_info_list) - middle_block_idx} 个提取的块信息未被使用。")
# 4. 更新文本块的图像信息
if image_info_list and chunk_ids_list:
conn = None
cursor = None
try:
conn = _get_db_connection()
cursor = conn.cursor()
# 为每个文本块找到最近的图片
for i, chunk_id in enumerate(chunk_ids_list):
# 找到与当前文本块最近的图片
nearest_image = None
for img_info in image_info_list:
# 计算文本块与图片的"距离"
distance = abs(i - img_info["position"]) # 使用位置差作为距离度量
# 如果文本块与图片的距离间隔小于10个块,则认为块与图片是相关的
if distance < 10:
nearest_image = img_info
# 如果找到了最近的图片则更新文本块的img_id
if nearest_image:
# 更新ES中的文档
direct_update = {
"doc": {
"img_id": nearest_image["url"]
}
}
es_client.update(index=index_name, id=chunk_id, body=direct_update, refresh=True)
index_name = f"ragflow_{tenant_id}"
print(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {nearest_image['url']}")
except Exception as e:
print(f"[Parser-ERROR] 更新文本块图片关联失败: {e}")
finally:
if cursor:
cursor.close()
if conn:
conn.close()
# 5. 更新最终状态
# 4. 更新最终状态
process_duration = time.time() - start_time process_duration = time.time() - start_time
_update_document_progress(doc_id, progress=1.0, message="解析完成", status='1', run='3', chunk_count=chunk_count, process_duration=process_duration) _update_document_progress(doc_id, progress=1.0, message="解析完成", status='1', run='3', chunk_count=chunk_count, process_duration=process_duration)
_update_kb_chunk_count(kb_id, chunk_count) # 更新知识库总块数 _update_kb_chunk_count(kb_id, chunk_count) # 更新知识库总块数

View File

@ -1,12 +1,15 @@
import mysql.connector import mysql.connector
import json import json
import threading import threading
import requests
import traceback
from datetime import datetime from datetime import datetime
from utils import generate_uuid from utils import generate_uuid
from database import DB_CONFIG from database import DB_CONFIG
# 解析相关模块 # 解析相关模块
from .document_parser import perform_parse, _update_document_progress from .document_parser import perform_parse, _update_document_progress
class KnowledgebaseService: class KnowledgebaseService:
@classmethod @classmethod
@ -704,7 +707,8 @@ class KnowledgebaseService:
_update_document_progress(doc_id, status='2', run='1', progress=0.0, message='开始解析') _update_document_progress(doc_id, status='2', run='1', progress=0.0, message='开始解析')
# 3. 调用后台解析函数 # 3. 调用后台解析函数
parse_result = perform_parse(doc_id, doc_info, file_info) embedding_config = cls.get_system_embedding_config()
parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config)
# 4. 返回解析结果 # 4. 返回解析结果
return parse_result return parse_result
@ -791,4 +795,200 @@ class KnowledgebaseService:
if cursor: if cursor:
cursor.close() cursor.close()
if conn: if conn:
conn.close() conn.close()
# --- 获取最早用户 ID ---
@classmethod
def _get_earliest_user_tenant_id(cls):
    """Return the id of the user with the oldest ``create_time``.

    The id is used by callers as the tenant id. Returns ``None`` when the
    ``user`` table is empty or when the lookup raises; in the latter case
    the error and its traceback are printed.
    """
    db_conn = None
    db_cursor = None
    try:
        db_conn = cls._get_db_connection()
        db_cursor = db_conn.cursor()
        db_cursor.execute("SELECT id FROM user ORDER BY create_time ASC LIMIT 1")
        first_row = db_cursor.fetchone()
        if not first_row:
            print("警告: 数据库中没有用户!")
            return None
        # Single selected column: the user id.
        return first_row[0]
    except Exception as e:
        print(f"查询最早用户时出错: {e}")
        traceback.print_exc()
        return None
    finally:
        if db_cursor:
            db_cursor.close()
        if db_conn and db_conn.is_connected():
            db_conn.close()
# --- 测试 Embedding 连接 ---
@classmethod
def _test_embedding_connection(cls, base_url, model_name, api_key):
    """Send one small request to an Embedding endpoint and report the outcome.

    POSTs ``{"input": ["Test connection"], "model": model_name}`` to
    ``<base_url>/embeddings`` with a 15 second timeout. ``http://`` is
    prepended when ``base_url`` has no scheme, and a ``Bearer`` authorization
    header is sent only when ``api_key`` is non-empty.

    A response counts as valid when it is HTTP 200 and its JSON body is
    either ``{"data": [{"embedding": <non-empty>}, ...]}`` or a non-empty
    list whose first element is a non-empty list.

    Args:
        base_url: Base URL of the model service, with or without a scheme.
        model_name: Model name sent in the request payload.
        api_key: Optional API key; falsy values send no Authorization header.

    Returns:
        tuple[bool, str]: ``(True, "连接成功")`` on success, otherwise
        ``(False, reason)`` where ``reason`` names the failure (timeout,
        unreachable service, non-200 status, or malformed response).
    """
    print(f"开始测试连接: base_url={base_url}, model_name={model_name}")

    def _has_vector(body):
        # Type-checked at every level so an unexpected payload is reported
        # as "bad format" instead of raising.
        if isinstance(body, dict):
            items = body.get("data")
            if isinstance(items, list) and items and isinstance(items[0], dict):
                return bool(items[0].get("embedding"))
            return False
        if isinstance(body, list) and body and isinstance(body[0], list):
            return bool(body[0])
        return False

    try:
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        payload = {"input": ["Test connection"], "model": model_name}

        if not base_url.startswith(('http://', 'https://')):
            base_url = 'http://' + base_url
        # Same joining rule the document parser uses for its embedding URL.
        current_test_url = base_url.rstrip('/') + "/embeddings"
        print(f"尝试请求 URL: {current_test_url}")

        try:
            response = requests.post(current_test_url, headers=headers, json=payload, timeout=15)
        except requests.exceptions.Timeout:
            print(f"请求 {current_test_url} 超时")
            return False, "连接失败: 请求超时"
        except requests.exceptions.RequestException as req_e:
            print(f"请求 {current_test_url} 失败: {req_e}")
            return False, f"连接失败: 无法访问 {current_test_url}"

        print(f"请求 {current_test_url} 返回状态码: {response.status_code}")
        if response.status_code != 200:
            return False, f"连接失败: 服务返回状态码 {response.status_code}"

        try:
            # requests' JSON decode error is a ValueError subclass.
            res_json = response.json()
        except ValueError as json_e:
            print(f"解析 JSON 响应失败于 {current_test_url}: {json_e}")
            return False, "连接失败: 响应错误"

        if _has_vector(res_json):
            print(f"连接测试成功: {current_test_url}")
            return True, "连接成功"

        print(f"连接成功但响应格式不正确于 {current_test_url}")
        return False, "连接失败: 响应错误"

    except Exception as e:
        print(f"连接测试发生未知错误: {str(e)}")
        traceback.print_exc()
        return False, f"测试时发生未知错误: {str(e)}"
# --- 获取系统 Embedding 配置 ---
@classmethod
def get_system_embedding_config(cls):
    """Return the Embedding model configuration of the system user.

    The "system user" is the earliest-created user, looked up through
    ``_get_earliest_user_tenant_id``. One ``tenant_llm`` row of that tenant
    with ``model_type = 'embedding'`` is read, so that a chat or other
    non-embedding model row is never reported as the Embedding config.

    Returns:
        dict: ``{"llm_name", "api_key", "api_base"}``. Every value is a
        string; missing rows and NULL columns both yield ``""``. Anything
        from the first ``___`` onwards is removed from ``llm_name``.

    Raises:
        Exception: when no user exists, or when the database lookup fails
            (the original error is chained as ``__cause__``).
    """
    tenant_id = cls._get_earliest_user_tenant_id()
    if not tenant_id:
        raise Exception("无法找到系统基础用户")

    conn = None
    cursor = None
    try:
        conn = cls._get_db_connection()
        # Dictionary cursor so columns can be read by name.
        cursor = conn.cursor(dictionary=True)
        query = """
            SELECT llm_name, api_key, api_base
            FROM tenant_llm
            WHERE tenant_id = %s AND model_type = 'embedding'
            LIMIT 1
        """
        cursor.execute(query, (tenant_id,))
        config = cursor.fetchone() or {}

        # `or ""` also covers SQL NULL, which arrives as None even though
        # the key is present in the row dict.
        llm_name = config.get("llm_name") or ""
        api_key = config.get("api_key") or ""
        api_base = config.get("api_base") or ""

        # Keep only the part of the stored name before the first '___'.
        if '___' in llm_name:
            llm_name = llm_name.split('___')[0]

        return {
            "llm_name": llm_name,
            "api_key": api_key,
            "api_base": api_base
        }
    except Exception as e:
        print(f"获取系统 Embedding 配置时出错: {e}")
        traceback.print_exc()
        raise Exception(f"获取配置时数据库出错: {e}") from e
    finally:
        if cursor:
            cursor.close()
        if conn and conn.is_connected():
            conn.close()
# --- 设置系统 Embedding 配置 ---
@classmethod
def set_system_embedding_config(cls, llm_name, api_base, api_key):
    """Run a connection test for a system-level Embedding configuration.

    Looks up the system (earliest-created) user, then probes the given
    endpoint through ``_test_embedding_connection``. Nothing is written to
    the database by this method.

    Returns:
        tuple[bool, str]: the test outcome and a message prefixed with
        either the success or the failure label.

    Raises:
        Exception: when no system user can be found.
    """
    if not cls._get_earliest_user_tenant_id():
        raise Exception("无法找到系统基础用户")

    probe_ok, probe_detail = cls._test_embedding_connection(
        api_key=api_key,
        model_name=llm_name,
        base_url=api_base,
    )

    if probe_ok:
        return True, f"连接成功: {probe_detail}"
    # Hand the concrete failure reason back to the caller (route layer).
    return False, f"连接测试失败: {probe_detail}"
# 测试通过,保存或更新配置到数据库(先不保存,以防冲突)
# conn = None
# cursor = None
# try:
# conn = cls._get_db_connection()
# cursor = conn.cursor()
# # 检查 TenantLLM 记录是否存在
# check_query = """
# SELECT id FROM tenant_llm
# WHERE tenant_id = %s AND llm_name = %s
# """
# cursor.execute(check_query, (tenant_id, llm_name))
# existing_config = cursor.fetchone()
# now = datetime.now()
# if existing_config:
# # 更新记录
# update_sql = """
# UPDATE tenant_llm
# SET api_key = %s, api_base = %s, max_tokens = %s, update_time = %s, update_date = %s
# WHERE id = %s
# """
# update_params = (api_key, api_base, max_tokens, now, now.date(), existing_config[0])
# cursor.execute(update_sql, update_params)
# print(f"已更新 TenantLLM 记录 (ID: {existing_config[0]})")
# else:
# # 插入新记录
# insert_sql = """
# INSERT INTO tenant_llm (tenant_id, llm_factory, model_type, llm_name, api_key, api_base, max_tokens, create_time, create_date, update_time, update_date, used_tokens)
# VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
# """
# insert_params = (tenant_id, "VLLM", "embedding", llm_name, api_key, api_base, max_tokens, now, now.date(), now, now.date(), 0) # used_tokens 默认为 0
# cursor.execute(insert_sql, insert_params)
# print(f"已创建新的 TenantLLM 记录")
# conn.commit() # 提交事务
# return True, "配置已成功保存"
# except Exception as e:
# if conn:
# conn.rollback() # 出错时回滚
# print(f"保存系统 Embedding 配置时数据库出错: {e}")
# traceback.print_exc()
# # 返回 False 和错误信息给路由层
# return False, f"保存配置时数据库出错: {e}"
# finally:
# if cursor:
# cursor.close()
# if conn and conn.is_connected():
# conn.close()

View File

@ -77,3 +77,24 @@ export function addDocumentToKnowledgeBaseApi(data: {
data: { file_ids: data.file_ids } data: { file_ids: data.file_ids }
}) })
} }
// 获取系统 Embedding 配置
/** Fetch the system-level Embedding configuration (GET). */
export function getSystemEmbeddingConfigApi() {
  // TODO: confirm the API path prefix is correct.
  const endpoint = "/api/v1/knowledgebases/system_embedding_config"
  return request({ method: "get", url: endpoint })
}
// 设置系统 Embedding 配置
/**
 * Submit a system-level Embedding configuration (POST).
 *
 * @param data Model name, base URL and optional API key, sent as the request body.
 */
export function setSystemEmbeddingConfigApi(data: {
  llm_name: string
  api_base: string
  api_key?: string
}) {
  // TODO: confirm the API path prefix is correct.
  const endpoint = "/api/v1/knowledgebases/system_embedding_config"
  return request({ method: "post", url: endpoint, data })
}

View File

@ -11,13 +11,15 @@ import {
batchDeleteKnowledgeBaseApi, batchDeleteKnowledgeBaseApi,
createKnowledgeBaseApi, createKnowledgeBaseApi,
deleteKnowledgeBaseApi, deleteKnowledgeBaseApi,
getKnowledgeBaseListApi getKnowledgeBaseListApi,
getSystemEmbeddingConfigApi,
setSystemEmbeddingConfigApi
} from "@@/apis/kbs/knowledgebase" } from "@@/apis/kbs/knowledgebase"
import { usePagination } from "@@/composables/usePagination" import { usePagination } from "@@/composables/usePagination"
import { CaretRight, Delete, Plus, Refresh, Search, View } from "@element-plus/icons-vue" import { CaretRight, Delete, Plus, Refresh, Search, Setting, View } from "@element-plus/icons-vue"
import axios from "axios" import axios from "axios"
import { ElMessage, ElMessageBox } from "element-plus" import { ElMessage, ElMessageBox } from "element-plus"
import { onActivated, onBeforeUnmount, onMounted, reactive, ref, watch } from "vue" import { nextTick, onActivated, onBeforeUnmount, onDeactivated, onMounted, reactive, ref, watch } from "vue"
import "element-plus/dist/index.css" import "element-plus/dist/index.css"
import "element-plus/theme-chalk/el-message-box.css" import "element-plus/theme-chalk/el-message-box.css"
import "element-plus/theme-chalk/el-message.css" import "element-plus/theme-chalk/el-message.css"
@ -615,6 +617,110 @@ onMounted(() => {
onActivated(() => { onActivated(() => {
getTableData() getTableData()
}) })
// Embedding
const configModalVisible = ref(false)
const configFormRef = ref<FormInstance>() //
const configFormLoading = ref(false) //
const configSubmitLoading = ref(false) //
const configForm = reactive({
llm_name: "",
api_base: "",
api_key: ""
})
// URL
/**
 * Form validator for the model base URL field.
 *
 * Calls `callback` with an Error when the value is empty or does not match
 * the accepted shape, and with no argument when it is valid.
 */
function validateUrl(rule: any, value: any, callback: any) {
  // Optional http/https scheme, then a host made of letters, digits, dots
  // and hyphens or a bracketed hex/colon literal, then an optional port and
  // an optional path containing neither "?" nor "#".
  const acceptedShape = /^(https?:\/\/)?([a-zA-Z0-9.-]+|\[[a-fA-F0-9:]+\])(:\d+)?(\/[^?#]*)?$/

  if (!value) {
    return callback(new Error("请输入模型 API 地址"))
  }
  if (acceptedShape.test(value)) {
    callback()
    return
  }
  callback(new Error("请输入有效的 Base URL (例如 http://host:port 或 https://domain/path)"))
}
const configFormRules = reactive({
llm_name: [{ required: true, message: "请输入模型名称", trigger: "blur" }],
api_base: [{ required: true, validator: validateUrl, trigger: "blur" }]
// api_key
})
//
/**
 * Open the Embedding configuration dialog and populate its form from the
 * backend. Shows an error toast when the request fails or returns a
 * non-zero code; the loading flag is always cleared at the end.
 */
async function showConfigModal() {
configModalVisible.value = true
configFormLoading.value = true
// Wait for the next DOM update before touching the form ref
// (presumably so the dialog's form is mounted - confirm).
await nextTick()
configFormRef.value?.resetFields() // clear previous values and validation state
try {
// Load the current configuration from the API.
const res = await getSystemEmbeddingConfigApi() as ApiResponse<{ llm_name?: string, api_base?: string, api_key?: string }>
if (res.code === 0 && res.data) {
configForm.llm_name = res.data.llm_name || ""
configForm.api_base = res.data.api_base || ""
// The API key returned by the GET endpoint is copied into the form as-is.
configForm.api_key = res.data.api_key || ""
} else if (res.code !== 0) {
ElMessage.error(res.message || "获取配置失败")
} else {
// code === 0 but no data: nothing has been configured yet.
console.log("当前未配置编码模型。")
}
} catch (error: any) {
ElMessage.error(error.message || "获取配置请求失败")
console.error("获取配置失败:", error)
} finally {
configFormLoading.value = false
}
}
//
/** Dialog close handler: resets the configuration form if it is available. */
function handleModalClose() {
  const form = configFormRef.value
  if (form) {
    form.resetFields()
  }
}
//
/**
 * Validate the configuration form and, when it passes, send it to the
 * backend. Shows a success toast and closes the dialog on code 0, otherwise
 * shows an error toast. The submit-loading flag is cleared in every case.
 */
async function handleConfigSubmit() {
if (!configFormRef.value) return
// Handle the Promise returned by validate() with .then() / .catch().
configFormRef.value.validate().then(async () => {
// Validation passed: submit the form.
configSubmitLoading.value = true
try {
const payload = {
llm_name: configForm.llm_name.trim(),
api_base: configForm.api_base.trim(),
api_key: configForm.api_key
}
// Call the set-configuration API.
const res = await setSystemEmbeddingConfigApi(payload) as ApiResponse<any> // response data is typed as any
if (res.code === 0) {
ElMessage.success("连接验证成功!")
configModalVisible.value = false
} else {
// Prefer the backend's message (res.message) when one is provided.
ElMessage.error(res.message || "连接验证失败")
}
} catch (error: any) {
ElMessage.error(error.message || "连接验证请求失败")
console.error("连接验证失败:", error)
} finally {
configSubmitLoading.value = false
}
}).catch((errorFields) => {
// Validation failed: only log the failing fields.
console.log("表单验证失败!", errorFields)
// Reached because the validate() Promise rejected; nothing else to do.
})
}
</script> </script>
<template> <template>
@ -654,6 +760,12 @@ onActivated(() => {
批量删除 批量删除
</el-button> </el-button>
</div> </div>
<div>
<el-button type="primary" :icon="Setting" @click="showConfigModal">
编码模型配置
</el-button>
</div>
</div> </div>
<div class="table-wrapper"> <div class="table-wrapper">
@ -922,6 +1034,56 @@ onActivated(() => {
</span> </span>
</template> </template>
</el-dialog> </el-dialog>
<!-- 系统 Embedding 配置模态框 -->
<el-dialog
v-model="configModalVisible"
title="编码模型配置"
width="500px"
:close-on-click-modal="false"
@close="handleModalClose"
append-to-body
>
<el-form
ref="configFormRef"
:model="configForm"
:rules="configFormRules"
label-width="120px"
v-loading="configFormLoading"
>
<el-form-item label="模型名称" prop="llm_name">
<el-input v-model="configForm.llm_name" placeholder="请先在前台进行配置" disabled />
<div class="form-tip">
与模型服务中部署的名称一致
</div>
</el-form-item>
<el-form-item label="模型 API 地址" prop="api_base">
<el-input v-model="configForm.api_base" placeholder="请先在前台进行配置" disabled />
<div class="form-tip">
模型的 Base URL
</div>
</el-form-item>
<el-form-item label="API Key (可选)" prop="api_key">
<el-input v-model="configForm.api_key" type="password" show-password placeholder="请先在前台进行配置" disabled />
<div class="form-tip">
如果模型服务需要认证请提供
</div>
</el-form-item>
<el-form-item>
<div style="color: #909399; font-size: 12px; line-height: 1.5;">
此配置将作为知识库解析时默认的 Embedding 模型
</div>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click="configModalVisible = false">取消</el-button>
<el-button type="primary" @click="handleConfigSubmit" :loading="configSubmitLoading">
测试连接
</el-button>
</span>
</template>
</el-dialog>
</div> </div>
<DocumentParseProgress <DocumentParseProgress
:document-id="currentDocId" :document-id="currentDocId"
@ -952,7 +1114,8 @@ onActivated(() => {
.toolbar-wrapper { .toolbar-wrapper {
display: flex; display: flex;
justify-content: space-between; justify-content: space-between; //
align-items: center; //
margin-bottom: 20px; margin-bottom: 20px;
} }
@ -1011,4 +1174,11 @@ onActivated(() => {
text-align: center; text-align: center;
} }
} }
.form-tip {
color: #909399;
font-size: 12px;
line-height: 1.5;
margin-top: 4px;
}
</style> </style>

View File

@ -12,10 +12,12 @@ def test_embedding(model, text):
) )
# 打印嵌入响应内容 # 打印嵌入响应内容
print(f"Embedding response: {response}") # print(f"Embedding response: {response}")
result = response.data[0].embedding
if response and response.data: if response and response.data:
print(f"Embedding: {response.data[0].embedding}") print(len(result))
else: else:
print("Failed to get embedding.") print("Failed to get embedding.")