From 04439e66401356cd42b2f9b7ba069e8b4f9d2ba2 Mon Sep 17 00:00:00 2001 From: zstar <65890619+zstar1003@users.noreply.github.com> Date: Mon, 2 Jun 2025 01:48:11 +0800 Subject: [PATCH] =?UTF-8?q?fix(knowledgebase):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E7=A7=81=E4=BA=BA=E7=9F=A5=E8=AF=86=E5=BA=93=E8=AE=BF=E9=97=AE?= =?UTF-8?q?=E6=9D=83=E9=99=90=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修改文档解析逻辑,使用知识库创建者作为 tenant_id --- api/apps/chunk_app.py | 36 ++------- .../knowledgebases/document_parser.py | 5 +- .../services/knowledgebases/rag_tokenizer.py | 10 --- .../server/services/knowledgebases/service.py | 80 +++++-------------- rag/nlp/search.py | 34 ++++++-- 5 files changed, 57 insertions(+), 108 deletions(-) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 2680058..85fea03 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -14,16 +14,12 @@ # limitations under the License. # import datetime -import json - +import xxhash +import re from flask import request from flask_login import login_required, current_user - from rag.app.qa import rmPrefix, beAdoc -from rag.app.tag import label_question from rag.nlp import search, rag_tokenizer - -# from rag.prompts import keyword_extraction, cross_languages from rag.settings import PAGERANK_FLD from rag.utils import rmSpace from api.db import LLMType, ParserType @@ -34,8 +30,6 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va from api.db.services.document_service import DocumentService from api import settings from api.utils.api_utils import get_json_result -import xxhash -import re @manager.route("/list", methods=["POST"]) # noqa: F821 @@ -248,28 +242,6 @@ def create(): return server_error_response(e) -""" -{ - "similarity_threshold": 0.2, - "vector_similarity_weight": 0.30000000000000004, - "question": "香港", - "doc_ids": [], - "kb_id": "4b071030bc8e43f1bfb8b7831f320d2f", - "page": 1, - "size": 10 -}, -{ - "similarity_threshold": 0.2, - "vector_similarity_weight": 0.30000000000000004, - "question": "显著优势", - "doc_ids": [], - "kb_id": "1848bc54384611f0b33e4e66786d0323", - "page": 1, - "size": 10 -} -""" - - @manager.route("/retrieval_test", methods=["POST"]) # noqa: F821 @login_required @validate_request("kb_id", "question") @@ -285,13 +257,14 @@ def retrieval_test(): doc_ids = req.get("doc_ids", []) similarity_threshold = float(req.get("similarity_threshold", 0.0)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) - top = int(req.get("top_k", 1024)) + top = int(req.get("top_k", 1024)) # 此参数前端请求不会携带,默认即1024 # langs = req.get("cross_languages", []) # 获取跨语言设定 tenant_ids = [] try: # 查询当前用户所属的租户 tenants = UserTenantService.query(user_id=current_user.id) + # 验证知识库权限 for kb_id in kb_ids: for tenant in tenants: @@ -300,6 +273,7 @@ def retrieval_test(): break else: return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR) + # 获取知识库信息 e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) if not e: diff --git a/management/server/services/knowledgebases/document_parser.py b/management/server/services/knowledgebases/document_parser.py index 28dfffa..48e3ee8 100644 --- a/management/server/services/knowledgebases/document_parser.py +++ b/management/server/services/knowledgebases/document_parser.py @@ -236,7 +236,7 @@ def process_table_content(content_list): return new_content_list -def perform_parse(doc_id, doc_info, file_info, 
embedding_config): +def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info): """ 执行文档解析的核心逻辑 @@ -244,6 +244,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config): doc_id (str): 文档ID. doc_info (dict): 包含文档信息的字典 (name, location, type, kb_id, parser_config, created_by). file_info (dict): 包含文件信息的字典 (parent_id/bucket_name). + kb_info (dict): 包含知识库信息的字典 (created_by). Returns: dict: 包含解析结果的字典 (success, chunk_count). @@ -305,7 +306,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config): _, file_extension = os.path.splitext(file_location) file_type = doc_info["type"].lower() bucket_name = file_info["parent_id"] # 文件存储的桶是 parent_id - tenant_id = doc_info["created_by"] # 文档创建者作为 tenant_id + tenant_id = kb_info["created_by"] # 知识库创建者作为 tenant_id # 进度更新回调 (直接调用内部更新函数) def update_progress(prog=None, msg=None): diff --git a/management/server/services/knowledgebases/rag_tokenizer.py b/management/server/services/knowledgebases/rag_tokenizer.py index 2bff5c3..7d253f7 100644 --- a/management/server/services/knowledgebases/rag_tokenizer.py +++ b/management/server/services/knowledgebases/rag_tokenizer.py @@ -10,16 +10,6 @@ from nltk import word_tokenize from nltk.stem import PorterStemmer, WordNetLemmatizer -# def get_project_base_directory(): -# return os.path.abspath( -# os.path.join( -# os.path.dirname(os.path.realpath(__file__)), -# os.pardir, -# os.pardir, -# ) -# ) - - class RagTokenizer: def key_(self, line): return str(line.lower().encode("utf-8"))[2:-1] diff --git a/management/server/services/knowledgebases/service.py b/management/server/services/knowledgebases/service.py index ee16a59..b303b4b 100644 --- a/management/server/services/knowledgebases/service.py +++ b/management/server/services/knowledgebases/service.py @@ -741,11 +741,11 @@ class KnowledgebaseService: @classmethod def parse_document(cls, doc_id): - """解析文档(调用解析逻辑)""" + """解析文档""" conn = None cursor = None try: - # 1. 获取文档和文件信息 + # 获取文档和文件信息 conn = cls._get_db_connection() cursor = conn.cursor(dictionary=True) @@ -765,28 +765,40 @@ class KnowledgebaseService: f2d_query = "SELECT file_id FROM file2document WHERE document_id = %s" cursor.execute(f2d_query, (doc_id,)) f2d_result = cursor.fetchone() + if not f2d_result: raise Exception("无法找到文件到文档的映射关系") - file_id = f2d_result["file_id"] + file_id = f2d_result["file_id"] file_query = "SELECT parent_id FROM file WHERE id = %s" cursor.execute(file_query, (file_id,)) file_info = cursor.fetchone() + if not file_info: raise Exception("无法找到文件记录") + # 获取知识库创建人信息 + # 根据doc_id查询document这张表,得到kb_id + kb_id_query = "SELECT kb_id FROM document WHERE id = %s" + cursor.execute(kb_id_query, (doc_id,)) + kb_id = cursor.fetchone() + # 根据kb_id查询knowledgebase这张表,得到created_by + kb_query = "SELECT created_by FROM knowledgebase WHERE id = %s" + cursor.execute(kb_query, (kb_id["kb_id"],)) + kb_info = cursor.fetchone() + cursor.close() conn.close() conn = None # 确保连接已关闭 - # 2. 更新文档状态为处理中 (使用 parser 模块的函数) + # 更新文档状态为处理中 (使用 parser 模块的函数) _update_document_progress(doc_id, status="2", run="1", progress=0.0, message="开始解析") - # 3. 调用后台解析函数 + # 调用后台解析函数 embedding_config = cls.get_system_embedding_config() - parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config) + parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info) - # 4. 
返回解析结果 + # 返回解析结果 return parse_result except Exception as e: @@ -1045,62 +1057,11 @@ class KnowledgebaseService: return False, f"连接测试失败: {message}" return True, f"连接成功: {message}" - # 测试通过,保存或更新配置到数据库(先不保存,以防冲突) - # conn = None - # cursor = None - # try: - # conn = cls._get_db_connection() - # cursor = conn.cursor() - - # # 检查 TenantLLM 记录是否存在 - # check_query = """ - # SELECT id FROM tenant_llm - # WHERE tenant_id = %s AND llm_name = %s - # """ - # cursor.execute(check_query, (tenant_id, llm_name)) - # existing_config = cursor.fetchone() - - # now = datetime.now() - # if existing_config: - # # 更新记录 - # update_sql = """ - # UPDATE tenant_llm - # SET api_key = %s, api_base = %s, max_tokens = %s, update_time = %s, update_date = %s - # WHERE id = %s - # """ - # update_params = (api_key, api_base, max_tokens, now, now.date(), existing_config[0]) - # cursor.execute(update_sql, update_params) - # print(f"已更新 TenantLLM 记录 (ID: {existing_config[0]})") - # else: - # # 插入新记录 - # insert_sql = """ - # INSERT INTO tenant_llm (tenant_id, llm_factory, model_type, llm_name, api_key, api_base, max_tokens, create_time, create_date, update_time, update_date, used_tokens) - # VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - # """ - # insert_params = (tenant_id, "VLLM", "embedding", llm_name, api_key, api_base, max_tokens, now, now.date(), now, now.date(), 0) # used_tokens 默认为 0 - # cursor.execute(insert_sql, insert_params) - # print(f"已创建新的 TenantLLM 记录") - - # conn.commit() # 提交事务 - # return True, "配置已成功保存" - - # except Exception as e: - # if conn: - # conn.rollback() # 出错时回滚 - # print(f"保存系统 Embedding 配置时数据库出错: {e}") - # traceback.print_exc() - # # 返回 False 和错误信息给路由层 - # return False, f"保存配置时数据库出错: {e}" - # finally: - # if cursor: - # cursor.close() - # if conn and conn.is_connected(): - # conn.close() # 顺序批量解析 (核心逻辑,在后台线程运行) @classmethod def _run_sequential_batch_parse(cls, kb_id): - """【内部方法】顺序执行批量解析,并在 SEQUENTIAL_BATCH_TASKS 中更新状态""" + """顺序执行批量解析,并在 SEQUENTIAL_BATCH_TASKS 中更新状态""" global SEQUENTIAL_BATCH_TASKS task_info = SEQUENTIAL_BATCH_TASKS.get(kb_id) if not task_info: @@ -1118,7 +1079,6 @@ class KnowledgebaseService: cursor = conn.cursor(dictionary=True) # 查询需要解析的文档 - # 注意:这里的条件要和前端期望的一致 query = """ SELECT id, name FROM document WHERE kb_id = %s AND run != '3' diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 080bfeb..c66cbd6 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -162,7 +162,7 @@ class Dealer: q_vec = matchDense.embedding_data # 在返回字段中加入查询向量字段 src.append(f"q_{len(q_vec)}_vec") - # 创建融合表达式:设置向量匹配为95%,全文为5%(可以调整权重) + # 创建融合表达式:设置向量匹配为95%,全文为5% fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"}) # 构建混合查询表达式 matchExprs = [matchText, matchDense, fusionExpr] @@ -210,10 +210,6 @@ class Dealer: keywords = list(kwds) # 转为列表格式返回 highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight") # 获取高亮内容 aggs = self.dataStore.getAggregation(res, "docnm_kwd") # 执行基于文档名的聚合分析 - print(f"ids:{ids}") - print(f"keywords:{keywords}") - print(f"highlight:{highlight}") - print(f"aggs:{aggs}") return self.SearchResult(total=total, ids=ids, query_vector=q_vec, aggregation=aggs, highlight=highlight, field=self.dataStore.getFields(res, src), keywords=keywords) @staticmethod @@ -322,6 +318,33 @@ class Dealer: return np.array(rank_fea) * 10.0 + pageranks def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", rank_feature: dict | None = None): + """ + 对初步检索到的结果 (sres) 进行重排序。 + + 该方法结合了多种相似度/特征来计算每个结果的新排序分数: + 1. 
文本相似度 (Token Similarity): 基于查询关键词与文档内容的词元匹配。 + 2. 向量相似度 (Vector Similarity): 基于查询向量与文档向量的余弦相似度。 + 3. 排序特征分数 (Rank Feature Scores): 如文档的 PageRank 值或与查询相关的标签特征得分。 + + 最终的排序分数是这几种分数的加权组合(或直接相加)。 + + Args: + sres (SearchResult): 初步检索的结果对象,包含查询向量、文档ID、字段内容等。 + query (str): 原始用户查询字符串。 + tkweight (float): 文本相似度在混合相似度计算中的权重。 + vtweight (float): 向量相似度在混合相似度计算中的权重。 + cfield (str): 用于提取主要文本内容以进行词元匹配的字段名,默认为 "content_ltks"。 + rank_feature (dict | None): 用于计算排序特征分数的查询侧特征, + 例如 {PAGERANK_FLD: 10} 表示 PageRank 权重, + 或包含其他标签及其权重的字典。 + + Returns: + tuple[np.ndarray, np.ndarray, np.ndarray]: + - sim (np.ndarray): 每个文档的最终重排序分数 (混合相似度 + 排序特征分数)。 + - tksim (np.ndarray): 每个文档的纯文本相似度分数。 + - vtsim (np.ndarray): 每个文档的纯向量相似度分数。 + 如果初步检索结果为空 (sres.ids is empty),则返回三个空列表。 + """ _, keywords = self.qryr.question(query) vector_size = len(sres.query_vector) vector_column = f"q_{vector_size}_vec" @@ -446,6 +469,7 @@ class Dealer: # 执行搜索操作 sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature) + # 执行重排序操作 if rerank_mdl and sres.total > 0: sim, tsim, vsim = self.rerank_by_model(rerank_mdl, sres, question, 1 - vector_similarity_weight, vector_similarity_weight, rank_feature=rank_feature) else:
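
Note (illustrative, not part of the patch): the rerank() docstring added above describes the final score as a weighted mix of token similarity and vector similarity, with rank-feature scores (e.g. pagerank or tag features) added on top. A minimal sketch of that combination, assuming the weights behave exactly as the docstring states; this is not the project's actual implementation, just the described formula made concrete:

import numpy as np

def blend_rerank_scores(tksim, vtsim, rank_fea, tkweight=0.3, vtweight=0.7):
    # Hybrid similarity = tkweight * token_sim + vtweight * vector_sim,
    # then rank-feature scores are added, as the rerank() docstring describes.
    tksim = np.asarray(tksim, dtype=float)
    vtsim = np.asarray(vtsim, dtype=float)
    rank_fea = np.asarray(rank_fea, dtype=float)
    sim = tkweight * tksim + vtweight * vtsim + rank_fea
    return sim, tksim, vtsim

# Example with two candidate chunks (all numbers hypothetical):
sim, tksim, vtsim = blend_rerank_scores(
    tksim=[0.6, 0.2],      # token (keyword) similarity
    vtsim=[0.4, 0.9],      # vector (embedding) similarity
    rank_fea=[0.1, 0.0],   # e.g. pagerank contribution
)
print(sim)  # [0.3*0.6 + 0.7*0.4 + 0.1, 0.3*0.2 + 0.7*0.9 + 0.0] -> [0.56, 0.69]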
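
Note (illustrative, not part of the patch): why the tenant_id change in perform_parse() matters. retrieval_test collects the tenant that owns the knowledge base and the Dealer.retrieval path queries [index_name(tid) for tid in tenant_ids], so chunks parsed under the document uploader's tenant_id can end up in an index that retrieval for a private knowledge base never looks at. A minimal sketch of that relationship, assuming index_name() derives a per-tenant index name; the helper names and the index prefix below are assumptions for illustration only:

def index_name(tenant_id: str) -> str:
    # Assumption: the real helper maps a tenant id to its search index name;
    # the exact naming scheme used by the project may differ.
    return f"ragflow_{tenant_id}"

def parse_side_index(kb_info: dict, doc_info: dict, use_kb_creator: bool) -> str:
    # Before this patch: tenant_id = doc_info["created_by"] (document creator).
    # After this patch:  tenant_id = kb_info["created_by"] (knowledge base creator).
    tenant_id = kb_info["created_by"] if use_kb_creator else doc_info["created_by"]
    return index_name(tenant_id)

def retrieval_side_index(kb_tenant_id: str) -> str:
    # retrieval_test searches the indexes of the tenant that owns the knowledge base.
    return index_name(kb_tenant_id)

if __name__ == "__main__":
    kb_info = {"created_by": "kb_owner_tenant"}    # hypothetical ids
    doc_info = {"created_by": "uploader_tenant"}

    fixed = parse_side_index(kb_info, doc_info, use_kb_creator=True)
    old = parse_side_index(kb_info, doc_info, use_kb_creator=False)
    wanted = retrieval_side_index("kb_owner_tenant")

    assert fixed == wanted   # chunks land in the index retrieval actually queries
    assert old != wanted     # old behaviour: chunks indexed where the KB owner never searches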