fix(knowledgebase): fix broken access permissions on private knowledge bases

- Change the document-parsing logic to use the knowledgebase creator as tenant_id
zstar 2025-06-02 01:48:11 +08:00
parent d0f17da9b4
commit 04439e6640
5 changed files with 57 additions and 108 deletions

View File

@@ -14,16 +14,12 @@
# limitations under the License.
#
import datetime
import json
import xxhash
import re
from flask import request
from flask_login import login_required, current_user
from rag.app.qa import rmPrefix, beAdoc
from rag.app.tag import label_question
from rag.nlp import search, rag_tokenizer
# from rag.prompts import keyword_extraction, cross_languages
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace
from api.db import LLMType, ParserType
@@ -34,8 +30,6 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from api.db.services.document_service import DocumentService
from api import settings
from api.utils.api_utils import get_json_result
import xxhash
import re
@manager.route("/list", methods=["POST"]) # noqa: F821
@@ -248,28 +242,6 @@ def create():
return server_error_response(e)
"""
{
"similarity_threshold": 0.2,
"vector_similarity_weight": 0.30000000000000004,
"question": "香港",
"doc_ids": [],
"kb_id": "4b071030bc8e43f1bfb8b7831f320d2f",
"page": 1,
"size": 10
},
{
"similarity_threshold": 0.2,
"vector_similarity_weight": 0.30000000000000004,
"question": "显著优势",
"doc_ids": [],
"kb_id": "1848bc54384611f0b33e4e66786d0323",
"page": 1,
"size": 10
}
"""
@manager.route("/retrieval_test", methods=["POST"]) # noqa: F821
@login_required
@validate_request("kb_id", "question")
@@ -285,13 +257,14 @@ def retrieval_test():
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.0))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))
top = int(req.get("top_k", 1024)) # the frontend never sends this parameter, so the default of 1024 applies
# langs = req.get("cross_languages", []) # read the cross-language setting
tenant_ids = []
try:
# look up the tenants the current user belongs to
tenants = UserTenantService.query(user_id=current_user.id)
# verify knowledgebase permissions
for kb_id in kb_ids:
for tenant in tenants:
@@ -300,6 +273,7 @@ def retrieval_test():
break
else:
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
# fetch the knowledgebase record
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e:
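For context, the authorization loop in this hunk leans on Python's for/else: the else branch fires only when the inner loop completes without break, i.e. when none of the current user's tenants owns the knowledgebase. A minimal sketch of the pattern, with a hypothetical owner_of lookup standing in for KnowledgebaseService.query:

def user_owns_all(kb_ids, tenant_ids, owner_of):
    # owner_of(kb_id) -> tenant_id of the kb's owner (hypothetical helper)
    for kb_id in kb_ids:
        for tenant_id in tenant_ids:
            if owner_of(kb_id) == tenant_id:
                break  # this kb is authorized; check the next one
        else:
            # inner loop never hit break: no tenant of this user owns the kb
            return False
    return True

owners = {"kb1": "t1", "kb2": "t9"}
print(user_owns_all(["kb1"], ["t1", "t2"], owners.get))          # True
print(user_owns_all(["kb1", "kb2"], ["t1", "t2"], owners.get))   # False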

View File

@@ -236,7 +236,7 @@ def process_table_content(content_list):
return new_content_list
def perform_parse(doc_id, doc_info, file_info, embedding_config):
def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
"""
Core logic for performing document parsing
@@ -244,6 +244,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
doc_id (str): Document ID.
doc_info (dict): Dict with document info (name, location, type, kb_id, parser_config, created_by).
file_info (dict): Dict with file info (parent_id/bucket_name).
kb_info (dict): Dict with knowledgebase info (created_by).
Returns:
dict: Dict with the parse result (success, chunk_count).
@@ -305,7 +306,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
_, file_extension = os.path.splitext(file_location)
file_type = doc_info["type"].lower()
bucket_name = file_info["parent_id"] # the storage bucket is the file's parent_id
tenant_id = doc_info["created_by"] # document creator as tenant_id
tenant_id = kb_info["created_by"] # knowledgebase creator as tenant_id
# progress-update callback (calls the internal update function directly)
def update_progress(prog=None, msg=None):
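The one-line change above is the heart of the fix: chunks are indexed per tenant_id, and retrieval searches the indexes belonging to the knowledgebase owner's tenants, so indexing under the uploader's ID could make documents invisible in a private knowledgebase. A before/after sketch with hypothetical values:

doc_info = {"created_by": "uploader_uid"}  # whoever uploaded the document
kb_info = {"created_by": "owner_uid"}      # creator of the knowledgebase

tenant_id = doc_info["created_by"]  # before: chunks land in the uploader's index
tenant_id = kb_info["created_by"]   # after: chunks land in the owner's index,
                                    # which is the one retrieval queries
print(tenant_id)  # owner_uid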

View File

@@ -10,16 +10,6 @@ from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
# def get_project_base_directory():
# return os.path.abspath(
# os.path.join(
# os.path.dirname(os.path.realpath(__file__)),
# os.pardir,
# os.pardir,
# )
# )
class RagTokenizer:
def key_(self, line):
return str(line.lower().encode("utf-8"))[2:-1]
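Aside: key_ builds an ASCII-safe dictionary key by taking the textual repr of the lowercased UTF-8 bytes and slicing off the leading b' and the trailing quote. A quick illustration:

line = "Hello 世界"
raw = str(line.lower().encode("utf-8"))  # "b'hello \\xe4\\xb8\\x96\\xe7\\x95\\x8c'"
key = raw[2:-1]                          # drop the b' prefix and trailing '
print(key)  # hello \xe4\xb8\x96\xe7\x95\x8c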

View File

@@ -741,11 +741,11 @@ class KnowledgebaseService:
@classmethod
def parse_document(cls, doc_id):
"""解析文档(调用解析逻辑)"""
"""解析文档"""
conn = None
cursor = None
try:
# 1. Fetch the document and file info
# Fetch the document and file info
conn = cls._get_db_connection()
cursor = conn.cursor(dictionary=True)
@@ -765,28 +765,40 @@ class KnowledgebaseService:
f2d_query = "SELECT file_id FROM file2document WHERE document_id = %s"
cursor.execute(f2d_query, (doc_id,))
f2d_result = cursor.fetchone()
if not f2d_result:
raise Exception("无法找到文件到文档的映射关系")
file_id = f2d_result["file_id"]
file_id = f2d_result["file_id"]
file_query = "SELECT parent_id FROM file WHERE id = %s"
cursor.execute(file_query, (file_id,))
file_info = cursor.fetchone()
if not file_info:
raise Exception("无法找到文件记录")
# 获取知识库创建人信息
# 根据doc_id查询document这张表得到kb_id
kb_id_query = "SELECT kb_id FROM document WHERE id = %s"
cursor.execute(kb_id_query, (doc_id,))
kb_id = cursor.fetchone()
# 根据kb_id查询knowledgebase这张表得到created_by
kb_query = "SELECT created_by FROM knowledgebase WHERE id = %s"
cursor.execute(kb_query, (kb_id["kb_id"],))
kb_info = cursor.fetchone()
cursor.close()
conn.close()
conn = None # make sure the connection is closed
# 2. Mark the document as processing (via the parser module's helper)
# Mark the document as processing (via the parser module's helper)
_update_document_progress(doc_id, status="2", run="1", progress=0.0, message="parsing started")
# 3. Call the background parsing function
# Call the background parsing function
embedding_config = cls.get_system_embedding_config()
parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config)
parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info)
# 4. Return the parse result
# Return the parse result
return parse_result
except Exception as e:
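Note that the two fetchone() results above (kb_id, then kb_info) are used without the None checks that f2d_result and file_info get; the lookup could also be collapsed into one guarded JOIN. A sketch, assuming the same tables and a dictionary cursor:

def get_kb_creator(cursor, doc_id):
    # one-query alternative to the document -> knowledgebase two-step lookup
    cursor.execute(
        "SELECT k.created_by FROM document d "
        "JOIN knowledgebase k ON k.id = d.kb_id "
        "WHERE d.id = %s",
        (doc_id,),
    )
    kb_info = cursor.fetchone()  # {"created_by": ...} or None
    if not kb_info:
        raise Exception("knowledgebase record not found for document")
    return kb_info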
@@ -1045,62 +1057,11 @@ class KnowledgebaseService:
return False, f"连接测试失败: {message}"
return True, f"连接成功: {message}"
# 测试通过,保存或更新配置到数据库(先不保存,以防冲突)
# conn = None
# cursor = None
# try:
# conn = cls._get_db_connection()
# cursor = conn.cursor()
# # check whether a TenantLLM record exists
# check_query = """
# SELECT id FROM tenant_llm
# WHERE tenant_id = %s AND llm_name = %s
# """
# cursor.execute(check_query, (tenant_id, llm_name))
# existing_config = cursor.fetchone()
# now = datetime.now()
# if existing_config:
# # update the record
# update_sql = """
# UPDATE tenant_llm
# SET api_key = %s, api_base = %s, max_tokens = %s, update_time = %s, update_date = %s
# WHERE id = %s
# """
# update_params = (api_key, api_base, max_tokens, now, now.date(), existing_config[0])
# cursor.execute(update_sql, update_params)
# print(f"已更新 TenantLLM 记录 (ID: {existing_config[0]})")
# else:
# # 插入新记录
# insert_sql = """
# INSERT INTO tenant_llm (tenant_id, llm_factory, model_type, llm_name, api_key, api_base, max_tokens, create_time, create_date, update_time, update_date, used_tokens)
# VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
# """
# insert_params = (tenant_id, "VLLM", "embedding", llm_name, api_key, api_base, max_tokens, now, now.date(), now, now.date(), 0) # used_tokens defaults to 0
# cursor.execute(insert_sql, insert_params)
# print(f"Created a new TenantLLM record")
# conn.commit() # commit the transaction
# return True, "Configuration saved successfully"
# except Exception as e:
# if conn:
# conn.rollback() # roll back on error
# print(f"Database error while saving the system embedding config: {e}")
# traceback.print_exc()
# # return False and the error message to the route layer
# return False, f"Database error while saving the config: {e}"
# finally:
# if cursor:
# cursor.close()
# if conn and conn.is_connected():
# conn.close()
# Sequential batch parsing (core logic, runs in a background thread)
@classmethod
def _run_sequential_batch_parse(cls, kb_id):
"""【内部方法】顺序执行批量解析,并在 SEQUENTIAL_BATCH_TASKS 中更新状态"""
"""顺序执行批量解析,并在 SEQUENTIAL_BATCH_TASKS 中更新状态"""
global SEQUENTIAL_BATCH_TASKS
task_info = SEQUENTIAL_BATCH_TASKS.get(kb_id)
if not task_info:
@@ -1118,7 +1079,6 @@ class KnowledgebaseService:
cursor = conn.cursor(dictionary=True)
# query the documents that need parsing
# note: this condition must match what the frontend expects
query = """
SELECT id, name FROM document
WHERE kb_id = %s AND run != '3'

View File

@@ -162,7 +162,7 @@ class Dealer:
q_vec = matchDense.embedding_data
# add the query-vector field to the returned fields
src.append(f"q_{len(q_vec)}_vec")
# build the fusion expression: 95% vector match, 5% full-text (weights adjustable)
# build the fusion expression: 95% vector match, 5% full-text
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
# assemble the hybrid query expressions
matchExprs = [matchText, matchDense, fusionExpr]
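Conceptually, a weighted_sum fusion scores each hit as a convex combination of its (normalized) full-text and vector scores; the exact normalization is up to the engine. A toy illustration of the 5/95 split, not the storage engine's actual implementation:

def fuse(text_score, vector_score, w_text=0.05, w_vec=0.95):
    return w_text * text_score + w_vec * vector_score

print(fuse(0.8, 0.6))  # 0.61 -- the vector score dominates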
@@ -210,10 +210,6 @@ class Dealer:
keywords = list(kwds) # convert to a list for the return value
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight") # fetch highlighted content
aggs = self.dataStore.getAggregation(res, "docnm_kwd") # aggregate by document name
print(f"ids:{ids}")
print(f"keywords:{keywords}")
print(f"highlight:{highlight}")
print(f"aggs:{aggs}")
return self.SearchResult(total=total, ids=ids, query_vector=q_vec, aggregation=aggs, highlight=highlight, field=self.dataStore.getFields(res, src), keywords=keywords)
@staticmethod
@@ -322,6 +318,33 @@ class Dealer:
return np.array(rank_fea) * 10.0 + pageranks
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", rank_feature: dict | None = None):
"""
Re-rank the preliminary retrieval results (sres).
This method combines several similarity/feature signals to compute a new ranking score for each result:
1. Token similarity: token-level matching between the query keywords and the document content.
2. Vector similarity: cosine similarity between the query vector and the document vectors.
3. Rank-feature scores: e.g. a document's PageRank value or tag-feature scores related to the query.
The final ranking score is a weighted combination (or direct sum) of these scores.
Args:
sres (SearchResult): Preliminary retrieval results, including the query vector, document IDs, field contents, etc.
query (str): The original user query string.
tkweight (float): Weight of token similarity in the hybrid similarity.
vtweight (float): Weight of vector similarity in the hybrid similarity.
cfield (str): Field used to extract the main text for token matching; defaults to "content_ltks".
rank_feature (dict | None): Query-side features for computing rank-feature scores,
e.g. {PAGERANK_FLD: 10} for a PageRank weight, or a dict of other tags and their weights.
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]:
- sim (np.ndarray): Final re-ranking score per document (hybrid similarity + rank-feature score).
- tksim (np.ndarray): Pure token-similarity score per document.
- vtsim (np.ndarray): Pure vector-similarity score per document.
If the preliminary results are empty (sres.ids is empty), three empty lists are returned.
"""
_, keywords = self.qryr.question(query)
vector_size = len(sres.query_vector)
vector_column = f"q_{vector_size}_vec"
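As a toy illustration of the combination the docstring describes (made-up similarity values, not real tokenizer or embedding output):

import numpy as np

tkweight, vtweight = 0.3, 0.7
tksim = np.array([0.10, 0.80, 0.40])     # token-overlap similarity per document
vtsim = np.array([0.90, 0.20, 0.60])     # cosine similarity per document
rank_fea = np.array([0.05, 0.00, 0.02])  # e.g. a pagerank-derived boost
sim = tkweight * tksim + vtweight * vtsim + rank_fea
print(sim)  # [0.71 0.38 0.56]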
@@ -446,6 +469,7 @@ class Dealer:
# run the search
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
# run the reranking
if rerank_mdl and sres.total > 0:
sim, tsim, vsim = self.rerank_by_model(rerank_mdl, sres, question, 1 - vector_similarity_weight, vector_similarity_weight, rank_feature=rank_feature)
else: