diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index d8500a1..59d4b4f 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -161,11 +161,13 @@ def chat(dialog, messages, stream=True, **kwargs): if p["key"] not in kwargs: prompt_config["system"] = prompt_config["system"].replace( "{%s}" % p["key"], " ") - - if len(questions) > 1 and prompt_config.get("refine_multiturn"): - questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] - else: - questions = questions[-1:] + + # 不再使用多轮对话优化 + # if len(questions) > 1 and prompt_config.get("refine_multiturn"): + # questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] + # else: + # questions = questions[-1:] + questions = questions[-1:] refine_question_ts = timer() @@ -188,40 +190,50 @@ def chat(dialog, messages, stream=True, **kwargs): tenant_ids = list(set([kb.tenant_id for kb in kbs])) knowledges = [] - if prompt_config.get("reasoning", False): - reasoner = DeepResearcher(chat_mdl, - prompt_config, - partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3)) + + # 不再使用推理 + # if prompt_config.get("reasoning", False): + # reasoner = DeepResearcher(chat_mdl, + # prompt_config, + # partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3)) - for think in reasoner.thinking(kbinfos, " ".join(questions)): - if isinstance(think, str): - thought = think - knowledges = [t for t in think.split("\n") if t] - elif stream: - yield think - else: - kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, - dialog.similarity_threshold, - dialog.vector_similarity_weight, - doc_ids=attachments, - top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl, - rank_feature=label_question(" ".join(questions), kbs) - ) - if prompt_config.get("tavily_api_key"): - tav = Tavily(prompt_config["tavily_api_key"]) - tav_res = tav.retrieve_chunks(" ".join(questions)) - kbinfos["chunks"].extend(tav_res["chunks"]) - kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) - if prompt_config.get("use_kg"): - ck = settings.kg_retrievaler.retrieval(" ".join(questions), - tenant_ids, - dialog.kb_ids, - embd_mdl, - LLMBundle(dialog.tenant_id, LLMType.CHAT)) - if ck["content_with_weight"]: - kbinfos["chunks"].insert(0, ck) + # for think in reasoner.thinking(kbinfos, " ".join(questions)): + # if isinstance(think, str): + # thought = think + # knowledges = [t for t in think.split("\n") if t] + # elif stream: + # yield think + # else: + # kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, + # dialog.similarity_threshold, + # dialog.vector_similarity_weight, + # doc_ids=attachments, + # top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl, + # rank_feature=label_question(" ".join(questions), kbs) + # ) + # if prompt_config.get("tavily_api_key"): + # tav = Tavily(prompt_config["tavily_api_key"]) + # tav_res = tav.retrieve_chunks(" ".join(questions)) + # kbinfos["chunks"].extend(tav_res["chunks"]) + # kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) + # if prompt_config.get("use_kg"): + # ck = settings.kg_retrievaler.retrieval(" ".join(questions), + # tenant_ids, + # dialog.kb_ids, + # embd_mdl, + # LLMBundle(dialog.tenant_id, LLMType.CHAT)) + # if ck["content_with_weight"]: + # kbinfos["chunks"].insert(0, ck) - knowledges = kb_prompt(kbinfos, max_tokens) + # knowledges = kb_prompt(kbinfos, max_tokens) + kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, + dialog.similarity_threshold, + dialog.vector_similarity_weight, + doc_ids=attachments, + top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl, + rank_feature=label_question(" ".join(questions), kbs) + ) + knowledges = kb_prompt(kbinfos, max_tokens) logging.debug( "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) @@ -255,6 +267,7 @@ def chat(dialog, messages, stream=True, **kwargs): nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions refs = [] + image_markdowns = [] # 用于存储图片的 Markdown 字符串 ans = answer.split("") think = "" if len(ans) == 2: @@ -262,6 +275,7 @@ def chat(dialog, messages, stream=True, **kwargs): answer = ans[1] if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL) + cited_chunk_indices = set() # 用于存储被引用的 chunk 索引 if not re.search(r"##[0-9]+\$\$", answer): answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] @@ -271,12 +285,34 @@ def chat(dialog, messages, stream=True, **kwargs): embd_mdl, tkweight=1 - dialog.vector_similarity_weight, vtweight=dialog.vector_similarity_weight) + cited_chunk_indices = idx # 获取 insert_citations 返回的索引 + else: idx = set([]) for r in re.finditer(r"##([0-9]+)\$\$", answer): i = int(r.group(1)) if i < len(kbinfos["chunks"]): idx.add(i) + cited_chunk_indices = idx # 获取从 ##...$$ 标记中提取的索引 + + # 根据引用的 chunk 索引提取图像信息并生成 Markdown + cited_doc_ids = set() + processed_image_urls = set() # 避免重复添加同一张图片 + print(f"DEBUG: cited_chunk_indices = {cited_chunk_indices}") + for i in cited_chunk_indices: + i_int = int(i) + if i_int < len(kbinfos["chunks"]): + chunk = kbinfos["chunks"][i_int] + cited_doc_ids.add(chunk["doc_id"]) + print(f"DEBUG: chunk = {chunk}") + # 检查 chunk 是否有关联的 image_id (URL) 且未被处理过 + print(f"DEBUG: chunk_id={chunk.get('chunk_id', i_int)}, image_id={chunk.get('image_id')}") + img_url = chunk.get("image_id") + if img_url and img_url not in processed_image_urls: + # 生成 Markdown 字符串,alt text 可以简单设为 "image" 或 chunk ID + alt_text = f"image_chunk_{chunk.get('chunk_id', i_int)}" + image_markdowns.append(f"\n![{alt_text}]({img_url})") + processed_image_urls.add(img_url) # 标记为已处理 idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [ @@ -289,6 +325,10 @@ def chat(dialog, messages, stream=True, **kwargs): for c in refs["chunks"]: if c.get("vector"): del c["vector"] + + # 将图片的 Markdown 字符串追加到回答末尾 + if image_markdowns: + answer += "".join(image_markdowns) if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'" diff --git a/management/server/routes/knowledgebases/routes.py b/management/server/routes/knowledgebases/routes.py index 85a033a..91fca0e 100644 --- a/management/server/routes/knowledgebases/routes.py +++ b/management/server/routes/knowledgebases/routes.py @@ -1,3 +1,4 @@ +import traceback from flask import Blueprint, request from services.knowledgebases.service import KnowledgebaseService from utils import success_response, error_response @@ -132,13 +133,12 @@ def add_documents_to_knowledgebase(kb_id): ) except Exception as service_error: print(f"[ERROR] 服务层错误详情: {str(service_error)}") - import traceback + traceback.print_exc() return error_response(str(service_error), code=500) except Exception as e: print(f"[ERROR] 路由层错误详情: {str(e)}") - import traceback traceback.print_exc() return error_response(str(e), code=500) @@ -193,5 +193,50 @@ def get_parse_progress(doc_id): return error_response(result['error'], code=404) return success_response(data=result) except Exception as e: - current_app.logger.error(f"获取解析进度失败: {str(e)}") - return error_response("解析进行中,请稍后重试", code=202) \ No newline at end of file + print(f"获取解析进度失败: {str(e)}") + return error_response("解析进行中,请稍后重试", code=202) + +# 获取系统 Embedding 配置路由 +@knowledgebase_bp.route('/system_embedding_config', methods=['GET']) +def get_system_embedding_config_route(): + """获取系统级 Embedding 配置的API端点""" + try: + config_data = KnowledgebaseService.get_system_embedding_config() + return success_response(data=config_data) + except Exception as e: + print(f"获取系统 Embedding 配置失败: {str(e)}") + return error_response(message=f"获取配置失败: {str(e)}", code=500) # 返回通用错误信息 + +# 设置系统 Embedding 配置路由 +@knowledgebase_bp.route('/system_embedding_config', methods=['POST']) +def set_system_embedding_config_route(): + """设置系统级 Embedding 配置的API端点""" + try: + data = request.json + if not data: + return error_response('请求数据不能为空', code=400) + + llm_name = data.get('llm_name', '').strip() + api_base = data.get('api_base', '').strip() + api_key = data.get('api_key', '').strip() # 允许空 + + if not llm_name or not api_base: + return error_response('模型名称和 API 地址不能为空', code=400) + + # 调用服务层进行处理(包括连接测试和数据库操作) + success, message = KnowledgebaseService.set_system_embedding_config( + llm_name=llm_name, + api_base=api_base, + api_key=api_key + ) + + if success: + return success_response(message=message) + else: + # 如果服务层返回失败(例如连接测试失败或数据库错误),将消息返回给前端 + return error_response(message=message, code=400) # 使用 400 表示操作失败 + + except Exception as e: + # 捕获路由层或未预料的服务层异常 + print(f"设置系统 Embedding 配置失败: {str(e)}") + return error_response(message=f"设置配置时发生内部错误: {str(e)}", code=500) \ No newline at end of file diff --git a/management/server/services/knowledgebases/document_parser.py b/management/server/services/knowledgebases/document_parser.py index 3f1695c..f245ba0 100644 --- a/management/server/services/knowledgebases/document_parser.py +++ b/management/server/services/knowledgebases/document_parser.py @@ -5,6 +5,8 @@ import json import mysql.connector import time import traceback +import re +import requests from io import BytesIO from datetime import datetime from elasticsearch import Elasticsearch @@ -16,6 +18,7 @@ from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.data.read_api import read_local_office from utils import generate_uuid + # 自定义tokenizer和文本处理函数,替代rag.nlp中的功能 def tokenize_text(text): """将文本分词,替代rag_tokenizer功能""" @@ -173,7 +176,7 @@ def get_text_from_block(block): block_text += content return ' '.join(block_text.split()) -def perform_parse(doc_id, doc_info, file_info): +def perform_parse(doc_id, doc_info, file_info, embedding_config): """ 执行文档解析的核心逻辑 @@ -189,7 +192,27 @@ def perform_parse(doc_id, doc_info, file_info): temp_image_dir = None start_time = time.time() middle_json_content = None # 初始化 middle_json_content + image_info_list = [] # 图片信息列表 + + # 默认值处理 + embedding_model_name = embedding_config.get("llm_name") if embedding_config and embedding_config.get("llm_name") else "bge-m3" # 默认模型 + # 对模型名称进行处理 + if embedding_model_name and '___' in embedding_model_name: + embedding_model_name = embedding_model_name.split('___')[0] + embedding_api_base = embedding_config.get("api_base") if embedding_config and embedding_config.get("api_base") else "http://localhost:8000" # 默认基础 URL + embedding_api_key = embedding_config.get("api_key") if embedding_config else None # 可能为 None 或空字符串 + + # 构建完整的 Embedding API URL + if embedding_api_base: + if not embedding_api_base.startswith(('http://', 'https://')): + embedding_api_base = 'http://' + embedding_api_base + # 标准端点是 /embeddings + embedding_url = embedding_api_base.rstrip('/') + "/embeddings" + else: + embedding_url = None # 如果没有配置 Base URL,则无法请求 + print(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}") + try: kb_id = doc_info['kb_id'] file_location = doc_info['location'] @@ -330,8 +353,18 @@ def perform_parse(doc_id, doc_info, file_info): es_client.indices.create( index=index_name, body={ - "settings": {"number_of_replicas": 0}, # 单节点设为0 - "mappings": { "properties": { "doc_id": {"type": "keyword"}, "kb_id": {"type": "keyword"}, "content_with_weight": {"type": "text"} } } # 简化字段 + "settings": {"number_of_replicas": 0}, + "mappings": { + "properties": { + "doc_id": {"type": "keyword"}, + "kb_id": {"type": "keyword"}, + "content_with_weight": {"type": "text"}, + "q_1024_vec": { + "type": "dense_vector", + "dims": 1024 + } + } + } } ) print(f"[Parser-INFO] 创建Elasticsearch索引: {index_name}") @@ -347,7 +380,43 @@ def perform_parse(doc_id, doc_info, file_info): content = chunk_data["text"] if not content or not content.strip(): continue + + # 过滤 markdown 特殊符号 + content = re.sub(r"[!#\\$/]", "", content) + q_1024_vec = [] # 初始化为空列表 + + # 获取embedding向量 + try: + # embedding_resp = requests.post( + # "http://localhost:8000/v1/embeddings", + # json={ + # "model": "bge-m3", # 你的embedding模型名 + # "input": content + # }, + # timeout=10 + # ) + headers = {"Content-Type": "application/json"} + if embedding_api_key: + headers["Authorization"] = f"Bearer {embedding_api_key}" + embedding_resp = requests.post( + embedding_url, # 使用动态构建的 URL + headers=headers, # 添加 headers (包含可能的 API Key) + json={ + "model": embedding_model_name, # 使用动态获取或默认的模型名 + "input": content + }, + timeout=15 # 稍微增加超时时间 + ) + + embedding_resp.raise_for_status() + embedding_data = embedding_resp.json() + q_1024_vec = embedding_data["data"][0]["embedding"] + print(f"[Parser-INFO] 获取embedding成功,长度: {len(q_1024_vec)}") + except Exception as e: + print(f"[Parser-ERROR] 获取embedding失败: {e}") + q_1024_vec = [] + chunk_id = generate_uuid() page_idx = 0 # 默认页面索引 bbox = [0, 0, 0, 0] # 默认 bbox @@ -362,8 +431,7 @@ def perform_parse(doc_id, doc_info, file_info): # 如果 block_info_list 耗尽,打印警告 if processed_text_chunks == len(block_info_list) + 1: # 只在第一次耗尽时警告一次 print(f"[Parser-WARNING] middle_data 提供的块信息少于 content_list 中的文本块数量。后续文本块将使用默认 page/bbox。") - - + try: # 上传文本块到 MinIO minio_client.put_object( @@ -382,7 +450,6 @@ def perform_parse(doc_id, doc_info, file_info): x1, y1, x2, y2 = bbox bbox_reordered = [x1, x2, y1, y2] - es_doc = { "doc_id": doc_id, "kb_id": kb_id, @@ -390,19 +457,19 @@ def perform_parse(doc_id, doc_info, file_info): "title_tks": doc_info['name'], "title_sm_tks": doc_info['name'], "content_with_weight": content, - "content_ltks": content_tokens, - "content_sm_ltks": content_tokens, + "content_ltks": " ".join(content_tokens), # 字符串类型 + "content_sm_ltks": " ".join(content_tokens), # 字符串类型 "page_num_int": [page_idx + 1], "position_int": [[page_idx + 1] + bbox_reordered], # 格式: [[page, x1, x2, y1, y2]] "top_int": [1], "create_time": current_time_es, "create_timestamp_flt": current_timestamp_es, "img_id": "", - "q_1024_vec": [] # 向量字段留空 + "q_1024_vec": q_1024_vec } # 存储到Elasticsearch - es_client.index(index=index_name, document=es_doc) # 使用 document 参数 + es_client.index(index=index_name, id=chunk_id, document=es_doc) # 使用 document 参数 chunk_count += 1 chunk_ids_list.append(chunk_id) @@ -428,27 +495,95 @@ def perform_parse(doc_id, doc_info, file_info): content_type = f"image/{img_ext[1:].lower()}" if content_type == "image/jpg": content_type = "image/jpeg" - # try: - # # 上传图片到MinIO (桶为kb_id) - # minio_client.fput_object( - # bucket_name=output_bucket, - # object_name=img_key, - # file_path=img_path_abs, - # content_type=content_type - # ) - # print(f"成功上传图片: {img_key}") - # # 注意:设置公共访问权限可能需要额外配置MinIO服务器和存储桶策略 + try: + # 上传图片到MinIO (桶为kb_id) + minio_client.fput_object( + bucket_name=output_bucket, + object_name=img_key, + file_path=img_path_abs, + content_type=content_type + ) - # except Exception as e: - # print(f"[Parser-ERROR] 上传图片 {img_path_abs} 失败: {e}") + # 设置图片的公共访问权限 + policy = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"AWS": "*"}, + "Action": ["s3:GetObject"], + "Resource": [f"arn:aws:s3:::{kb_id}/{img_key}"] + } + ] + } + minio_client.set_bucket_policy(kb_id, json.dumps(policy)) + print(f"成功上传图片: {img_key}") + minio_endpoint = MINIO_CONFIG["endpoint"] + use_ssl = MINIO_CONFIG.get("secure", False) + protocol = "https" if use_ssl else "http" + img_url = f"{protocol}://{minio_endpoint}/{output_bucket}/{img_key}" + + # 记录图片信息,包括URL和位置信息 + image_info = { + "url": img_url, + "position": processed_text_chunks # 使用当前处理的文本块数作为位置参考 + } + image_info_list.append(image_info) + + print(f"图片访问链接: {img_url}") + + except Exception as e: + print(f"[Parser-ERROR] 上传图片 {img_path_abs} 失败: {e}") + # 打印匹配总结信息 print(f"[Parser-INFO] 共处理 {processed_text_chunks} 个文本块。") if middle_block_idx < len(block_info_list): print(f"[Parser-WARNING] middle_data 中还有 {len(block_info_list) - middle_block_idx} 个提取的块信息未被使用。") + + # 4. 更新文本块的图像信息 + if image_info_list and chunk_ids_list: + conn = None + cursor = None + try: + conn = _get_db_connection() + cursor = conn.cursor() + + # 为每个文本块找到最近的图片 + for i, chunk_id in enumerate(chunk_ids_list): + # 找到与当前文本块最近的图片 + nearest_image = None + + for img_info in image_info_list: + # 计算文本块与图片的"距离" + distance = abs(i - img_info["position"]) # 使用位置差作为距离度量 + # 如果文本块与图片的距离间隔小于10个块,则认为块与图片是相关的 + if distance < 10: + nearest_image = img_info + + # 如果找到了最近的图片,则更新文本块的img_id + if nearest_image: + # 更新ES中的文档 + direct_update = { + "doc": { + "img_id": nearest_image["url"] + } + } + es_client.update(index=index_name, id=chunk_id, body=direct_update, refresh=True) + + index_name = f"ragflow_{tenant_id}" + + print(f"[Parser-INFO] 更新文本块 {chunk_id} 的图片关联: {nearest_image['url']}") + + except Exception as e: + print(f"[Parser-ERROR] 更新文本块图片关联失败: {e}") + finally: + if cursor: + cursor.close() + if conn: + conn.close() - - # 4. 更新最终状态 + # 5. 更新最终状态 process_duration = time.time() - start_time _update_document_progress(doc_id, progress=1.0, message="解析完成", status='1', run='3', chunk_count=chunk_count, process_duration=process_duration) _update_kb_chunk_count(kb_id, chunk_count) # 更新知识库总块数 diff --git a/management/server/services/knowledgebases/service.py b/management/server/services/knowledgebases/service.py index ae2be4b..daff470 100644 --- a/management/server/services/knowledgebases/service.py +++ b/management/server/services/knowledgebases/service.py @@ -1,12 +1,15 @@ import mysql.connector import json import threading +import requests +import traceback from datetime import datetime from utils import generate_uuid from database import DB_CONFIG # 解析相关模块 from .document_parser import perform_parse, _update_document_progress + class KnowledgebaseService: @classmethod @@ -704,7 +707,8 @@ class KnowledgebaseService: _update_document_progress(doc_id, status='2', run='1', progress=0.0, message='开始解析') # 3. 调用后台解析函数 - parse_result = perform_parse(doc_id, doc_info, file_info) + embedding_config = cls.get_system_embedding_config() + parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config) # 4. 返回解析结果 return parse_result @@ -791,4 +795,200 @@ class KnowledgebaseService: if cursor: cursor.close() if conn: - conn.close() \ No newline at end of file + conn.close() + + # --- 获取最早用户 ID --- + @classmethod + def _get_earliest_user_tenant_id(cls): + """获取创建时间最早的用户的 ID (作为 tenant_id)""" + conn = None + cursor = None + try: + conn = cls._get_db_connection() + cursor = conn.cursor() + query = "SELECT id FROM user ORDER BY create_time ASC LIMIT 1" + cursor.execute(query) + result = cursor.fetchone() + if result: + return result[0] # 返回用户 ID + else: + print("警告: 数据库中没有用户!") + return None + except Exception as e: + print(f"查询最早用户时出错: {e}") + traceback.print_exc() + return None + finally: + if cursor: + cursor.close() + if conn and conn.is_connected(): + conn.close() + + # --- 测试 Embedding 连接 --- + @classmethod + def _test_embedding_connection(cls, base_url, model_name, api_key): + """ + 测试与自定义 Embedding 模型的连接 (使用 requests)。 + """ + print(f"开始测试连接: base_url={base_url}, model_name={model_name}") + try: + headers = {"Content-Type": "application/json"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + payload = {"input": ["Test connection"], "model": model_name} + + if not base_url.startswith(('http://', 'https://')): + base_url = 'http://' + base_url + if not base_url.endswith('/'): + base_url += '/' + + endpoint = "embeddings" + current_test_url = base_url + endpoint + print(f"尝试请求 URL: {current_test_url}") + try: + response = requests.post(current_test_url, headers=headers, json=payload, timeout=15) + print(f"请求 {current_test_url} 返回状态码: {response.status_code}") + + if response.status_code == 200: + res_json = response.json() + if ("data" in res_json and isinstance(res_json["data"], list) and len(res_json["data"]) > 0 and "embedding" in res_json["data"][0] and len(res_json["data"][0]["embedding"]) > 0) or \ + (isinstance(res_json, list) and len(res_json) > 0 and isinstance(res_json[0], list) and len(res_json[0]) > 0): + print(f"连接测试成功: {current_test_url}") + return True, "连接成功" + else: + print(f"连接成功但响应格式不正确于 {current_test_url}") + + except Exception as json_e: + print(f"解析 JSON 响应失败于 {current_test_url}: {json_e}") + + return False, "连接失败: 响应错误" + + except Exception as e: + print(f"连接测试发生未知错误: {str(e)}") + traceback.print_exc() + return False, f"测试时发生未知错误: {str(e)}" + + # --- 获取系统 Embedding 配置 --- + @classmethod + def get_system_embedding_config(cls): + """获取系统级(最早用户)的 Embedding 配置""" + tenant_id = cls._get_earliest_user_tenant_id() + if not tenant_id: + raise Exception("无法找到系统基础用户") # 在服务层抛出异常 + + conn = None + cursor = None + try: + conn = cls._get_db_connection() + cursor = conn.cursor(dictionary=True) # 使用字典游标方便访问列名 + query = """ + SELECT llm_name, api_key, api_base + FROM tenant_llm + WHERE tenant_id = %s + LIMIT 1 + """ + cursor.execute(query, (tenant_id,)) + config = cursor.fetchone() + + if config: + llm_name = config.get("llm_name", "") + api_key = config.get("api_key", "") + api_base = config.get("api_base", "") + # 对模型名称进行处理 + if llm_name and '___' in llm_name: + llm_name = llm_name.split('___')[0] + # 如果有配置,返回 + return { + "llm_name": llm_name, + "api_key": api_key, + "api_base": api_base + } + else: + # 如果没有配置,返回空 + return { + "llm_name": "", + "api_key": "", + "api_base": "" + } + except Exception as e: + print(f"获取系统 Embedding 配置时出错: {e}") + traceback.print_exc() + raise Exception(f"获取配置时数据库出错: {e}") # 重新抛出异常 + finally: + if cursor: + cursor.close() + if conn and conn.is_connected(): + conn.close() + + # --- 设置系统 Embedding 配置 --- + @classmethod + def set_system_embedding_config(cls, llm_name, api_base, api_key): + """设置系统级(最早用户)的 Embedding 配置""" + tenant_id = cls._get_earliest_user_tenant_id() + if not tenant_id: + raise Exception("无法找到系统基础用户") + + # 执行连接测试 + is_connected, message = cls._test_embedding_connection( + base_url=api_base, + model_name=llm_name, + api_key=api_key + ) + + if not is_connected: + # 返回具体的测试失败原因给调用者(路由层)处理 + return False, f"连接测试失败: {message}" + + return True, f"连接成功: {message}" + # 测试通过,保存或更新配置到数据库(先不保存,以防冲突) + # conn = None + # cursor = None + # try: + # conn = cls._get_db_connection() + # cursor = conn.cursor() + + # # 检查 TenantLLM 记录是否存在 + # check_query = """ + # SELECT id FROM tenant_llm + # WHERE tenant_id = %s AND llm_name = %s + # """ + # cursor.execute(check_query, (tenant_id, llm_name)) + # existing_config = cursor.fetchone() + + # now = datetime.now() + # if existing_config: + # # 更新记录 + # update_sql = """ + # UPDATE tenant_llm + # SET api_key = %s, api_base = %s, max_tokens = %s, update_time = %s, update_date = %s + # WHERE id = %s + # """ + # update_params = (api_key, api_base, max_tokens, now, now.date(), existing_config[0]) + # cursor.execute(update_sql, update_params) + # print(f"已更新 TenantLLM 记录 (ID: {existing_config[0]})") + # else: + # # 插入新记录 + # insert_sql = """ + # INSERT INTO tenant_llm (tenant_id, llm_factory, model_type, llm_name, api_key, api_base, max_tokens, create_time, create_date, update_time, update_date, used_tokens) + # VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + # """ + # insert_params = (tenant_id, "VLLM", "embedding", llm_name, api_key, api_base, max_tokens, now, now.date(), now, now.date(), 0) # used_tokens 默认为 0 + # cursor.execute(insert_sql, insert_params) + # print(f"已创建新的 TenantLLM 记录") + + # conn.commit() # 提交事务 + # return True, "配置已成功保存" + + # except Exception as e: + # if conn: + # conn.rollback() # 出错时回滚 + # print(f"保存系统 Embedding 配置时数据库出错: {e}") + # traceback.print_exc() + # # 返回 False 和错误信息给路由层 + # return False, f"保存配置时数据库出错: {e}" + # finally: + # if cursor: + # cursor.close() + # if conn and conn.is_connected(): + # conn.close() \ No newline at end of file diff --git a/management/web/src/common/apis/kbs/knowledgebase.ts b/management/web/src/common/apis/kbs/knowledgebase.ts index e5b9c5a..d10fd9d 100644 --- a/management/web/src/common/apis/kbs/knowledgebase.ts +++ b/management/web/src/common/apis/kbs/knowledgebase.ts @@ -77,3 +77,24 @@ export function addDocumentToKnowledgeBaseApi(data: { data: { file_ids: data.file_ids } }) } + +// 获取系统 Embedding 配置 +export function getSystemEmbeddingConfigApi() { + return request({ + url: "/api/v1/knowledgebases/system_embedding_config", // 确认 API 路径前缀是否正确 + method: "get" + }) +} + +// 设置系统 Embedding 配置 +export function setSystemEmbeddingConfigApi(data: { + llm_name: string + api_base: string + api_key?: string +}) { + return request({ + url: "/api/v1/knowledgebases/system_embedding_config", // 确认 API 路径前缀是否正确 + method: "post", + data + }) +} diff --git a/management/web/src/pages/knowledgebase/index.vue b/management/web/src/pages/knowledgebase/index.vue index 03c4f5c..b691e0a 100644 --- a/management/web/src/pages/knowledgebase/index.vue +++ b/management/web/src/pages/knowledgebase/index.vue @@ -11,13 +11,15 @@ import { batchDeleteKnowledgeBaseApi, createKnowledgeBaseApi, deleteKnowledgeBaseApi, - getKnowledgeBaseListApi + getKnowledgeBaseListApi, + getSystemEmbeddingConfigApi, + setSystemEmbeddingConfigApi } from "@@/apis/kbs/knowledgebase" import { usePagination } from "@@/composables/usePagination" -import { CaretRight, Delete, Plus, Refresh, Search, View } from "@element-plus/icons-vue" +import { CaretRight, Delete, Plus, Refresh, Search, Setting, View } from "@element-plus/icons-vue" import axios from "axios" import { ElMessage, ElMessageBox } from "element-plus" -import { onActivated, onBeforeUnmount, onMounted, reactive, ref, watch } from "vue" +import { nextTick, onActivated, onBeforeUnmount, onDeactivated, onMounted, reactive, ref, watch } from "vue" import "element-plus/dist/index.css" import "element-plus/theme-chalk/el-message-box.css" import "element-plus/theme-chalk/el-message.css" @@ -615,6 +617,110 @@ onMounted(() => { onActivated(() => { getTableData() }) + +// 系统 Embedding 配置逻辑 +const configModalVisible = ref(false) +const configFormRef = ref() // 表单引用 +const configFormLoading = ref(false) // 表单加载状态 +const configSubmitLoading = ref(false) // 提交按钮加载状态 + +const configForm = reactive({ + llm_name: "", + api_base: "", + api_key: "" +}) + +// 简单的 URL 校验规则 +function validateUrl(rule: any, value: any, callback: any) { + if (!value) { + return callback(new Error("请输入模型 API 地址")) + } + // 允许 http, https 开头,允许 IP 地址和域名,允许端口 + // 修正:允许路径,但不允许查询参数或片段 + const urlPattern = /^(https?:\/\/)?([a-zA-Z0-9.-]+|\[[a-fA-F0-9:]+\])(:\d+)?(\/[^?#]*)?$/ + if (!urlPattern.test(value)) { + callback(new Error("请输入有效的 Base URL (例如 http://host:port 或 https://domain/path)")) + } else { + callback() + } +} + +const configFormRules = reactive({ + llm_name: [{ required: true, message: "请输入模型名称", trigger: "blur" }], + api_base: [{ required: true, validator: validateUrl, trigger: "blur" }] + // api_key 不是必填项 +}) + +// 显示配置模态框 +async function showConfigModal() { + configModalVisible.value = true + configFormLoading.value = true + // 重置表单可能需要在 nextTick 中执行,确保 DOM 更新完毕 + await nextTick() + configFormRef.value?.resetFields() // 清空上次的输入和校验状态 + + try { + // 确认 API 函数名称是否正确,并添加类型断言 + const res = await getSystemEmbeddingConfigApi() as ApiResponse<{ llm_name?: string, api_base?: string, api_key?: string }> + if (res.code === 0 && res.data) { + configForm.llm_name = res.data.llm_name || "" + configForm.api_base = res.data.api_base || "" + // 注意:API Key 通常不应在 GET 请求中返回,如果后端不返回,这里会是空字符串 + configForm.api_key = res.data.api_key || "" + } else if (res.code !== 0) { + ElMessage.error(res.message || "获取配置失败") + } else { + // code === 0 但 data 为空,说明没有配置 + console.log("当前未配置编码模型。") + } + } catch (error: any) { + ElMessage.error(error.message || "获取配置请求失败") + console.error("获取配置失败:", error) + } finally { + configFormLoading.value = false + } +} + +// 处理模态框关闭 +function handleModalClose() { + // 可以在这里再次重置表单,以防用户未保存直接关闭 + configFormRef.value?.resetFields() +} + +// 处理配置提交 +async function handleConfigSubmit() { + if (!configFormRef.value) return + // 使用 .then() .catch() 处理 validate 的 Promise + configFormRef.value.validate().then(async () => { + // 验证通过 + configSubmitLoading.value = true + try { + const payload = { + llm_name: configForm.llm_name.trim(), + api_base: configForm.api_base.trim(), + api_key: configForm.api_key + } + // 确认 API 函数名称是否正确,并添加类型断言 + const res = await setSystemEmbeddingConfigApi(payload) as ApiResponse // 使用类型断言并指定泛型参数为any + if (res.code === 0) { + ElMessage.success("连接验证成功!") + configModalVisible.value = false + } else { + // 后端应在 res.message 中返回错误信息,包括连接测试失败的原因 + ElMessage.error(res.message || "连接验证失败") + } + } catch (error: any) { + ElMessage.error(error.message || "连接验证请求失败") + console.error("连接验证失败:", error) + } finally { + configSubmitLoading.value = false + } + }).catch((errorFields) => { + // 验证失败 + console.log("表单验证失败!", errorFields) + // 这里不需要返回 false,validate 的 Promise reject 就表示失败了 + }) +} + + + + + + +
+ 与模型服务中部署的名称一致 +
+
+ + +
+ 模型的 Base URL +
+
+ + +
+ 如果模型服务需要认证,请提供 +
+
+ +
+ 此配置将作为知识库解析时默认的 Embedding 模型。 +
+
+
+ +
{ .toolbar-wrapper { display: flex; - justify-content: space-between; + justify-content: space-between; // 确保左右两边对齐 + align-items: center; // 垂直居中对齐 margin-bottom: 20px; } @@ -1011,4 +1174,11 @@ onActivated(() => { text-align: center; } } + +.form-tip { + color: #909399; + font-size: 12px; + line-height: 1.5; + margin-top: 4px; +} diff --git a/vllm/model_test.py b/vllm/model_test.py index 780e8a6..6f6b45a 100644 --- a/vllm/model_test.py +++ b/vllm/model_test.py @@ -12,10 +12,12 @@ def test_embedding(model, text): ) # 打印嵌入响应内容 - print(f"Embedding response: {response}") + # print(f"Embedding response: {response}") + result = response.data[0].embedding + if response and response.data: - print(f"Embedding: {response.data[0].embedding}") + print(len(result)) else: print("Failed to get embedding.")