fix(knowledgebase): 修复私人知识库访问权限问题
- 修改文档解析逻辑,使用知识库创建者作为 tenant_id
This commit is contained in:
parent
d0f17da9b4
commit
04439e6640
|
@ -14,16 +14,12 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
import datetime
|
||||
import json
|
||||
|
||||
import xxhash
|
||||
import re
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from rag.app.qa import rmPrefix, beAdoc
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
|
||||
# from rag.prompts import keyword_extraction, cross_languages
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils import rmSpace
|
||||
from api.db import LLMType, ParserType
|
||||
|
@ -34,8 +30,6 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
|||
from api.db.services.document_service import DocumentService
|
||||
from api import settings
|
||||
from api.utils.api_utils import get_json_result
|
||||
import xxhash
|
||||
import re
|
||||
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
|
@ -248,28 +242,6 @@ def create():
|
|||
return server_error_response(e)
|
||||
|
||||
|
||||
"""
|
||||
{
|
||||
"similarity_threshold": 0.2,
|
||||
"vector_similarity_weight": 0.30000000000000004,
|
||||
"question": "香港",
|
||||
"doc_ids": [],
|
||||
"kb_id": "4b071030bc8e43f1bfb8b7831f320d2f",
|
||||
"page": 1,
|
||||
"size": 10
|
||||
},
|
||||
{
|
||||
"similarity_threshold": 0.2,
|
||||
"vector_similarity_weight": 0.30000000000000004,
|
||||
"question": "显著优势",
|
||||
"doc_ids": [],
|
||||
"kb_id": "1848bc54384611f0b33e4e66786d0323",
|
||||
"page": 1,
|
||||
"size": 10
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
@manager.route("/retrieval_test", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id", "question")
|
||||
|
@ -285,13 +257,14 @@ def retrieval_test():
|
|||
doc_ids = req.get("doc_ids", [])
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = int(req.get("top_k", 1024))
|
||||
top = int(req.get("top_k", 1024)) # 此参数前端请求不会携带,默认即1024
|
||||
# langs = req.get("cross_languages", []) # 获取跨语言设定
|
||||
tenant_ids = []
|
||||
|
||||
try:
|
||||
# 查询当前用户所属的租户
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
|
||||
# 验证知识库权限
|
||||
for kb_id in kb_ids:
|
||||
for tenant in tenants:
|
||||
|
@ -300,6 +273,7 @@ def retrieval_test():
|
|||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
# 获取知识库信息
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||
if not e:
|
||||
|
|
|
@ -236,7 +236,7 @@ def process_table_content(content_list):
|
|||
return new_content_list
|
||||
|
||||
|
||||
def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
||||
def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
|
||||
"""
|
||||
执行文档解析的核心逻辑
|
||||
|
||||
|
@ -244,6 +244,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
|||
doc_id (str): 文档ID.
|
||||
doc_info (dict): 包含文档信息的字典 (name, location, type, kb_id, parser_config, created_by).
|
||||
file_info (dict): 包含文件信息的字典 (parent_id/bucket_name).
|
||||
kb_info (dict): 包含知识库信息的字典 (created_by).
|
||||
|
||||
Returns:
|
||||
dict: 包含解析结果的字典 (success, chunk_count).
|
||||
|
@ -305,7 +306,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
|||
_, file_extension = os.path.splitext(file_location)
|
||||
file_type = doc_info["type"].lower()
|
||||
bucket_name = file_info["parent_id"] # 文件存储的桶是 parent_id
|
||||
tenant_id = doc_info["created_by"] # 文档创建者作为 tenant_id
|
||||
tenant_id = kb_info["created_by"] # 知识库创建者作为 tenant_id
|
||||
|
||||
# 进度更新回调 (直接调用内部更新函数)
|
||||
def update_progress(prog=None, msg=None):
|
||||
|
|
|
@ -10,16 +10,6 @@ from nltk import word_tokenize
|
|||
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
||||
|
||||
|
||||
# def get_project_base_directory():
|
||||
# return os.path.abspath(
|
||||
# os.path.join(
|
||||
# os.path.dirname(os.path.realpath(__file__)),
|
||||
# os.pardir,
|
||||
# os.pardir,
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
class RagTokenizer:
|
||||
def key_(self, line):
|
||||
return str(line.lower().encode("utf-8"))[2:-1]
|
||||
|
|
|
@ -741,11 +741,11 @@ class KnowledgebaseService:
|
|||
|
||||
@classmethod
|
||||
def parse_document(cls, doc_id):
|
||||
"""解析文档(调用解析逻辑)"""
|
||||
"""解析文档"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 1. 获取文档和文件信息
|
||||
# 获取文档和文件信息
|
||||
conn = cls._get_db_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
|
@ -765,28 +765,40 @@ class KnowledgebaseService:
|
|||
f2d_query = "SELECT file_id FROM file2document WHERE document_id = %s"
|
||||
cursor.execute(f2d_query, (doc_id,))
|
||||
f2d_result = cursor.fetchone()
|
||||
|
||||
if not f2d_result:
|
||||
raise Exception("无法找到文件到文档的映射关系")
|
||||
file_id = f2d_result["file_id"]
|
||||
|
||||
file_id = f2d_result["file_id"]
|
||||
file_query = "SELECT parent_id FROM file WHERE id = %s"
|
||||
cursor.execute(file_query, (file_id,))
|
||||
file_info = cursor.fetchone()
|
||||
|
||||
if not file_info:
|
||||
raise Exception("无法找到文件记录")
|
||||
|
||||
# 获取知识库创建人信息
|
||||
# 根据doc_id查询document这张表,得到kb_id
|
||||
kb_id_query = "SELECT kb_id FROM document WHERE id = %s"
|
||||
cursor.execute(kb_id_query, (doc_id,))
|
||||
kb_id = cursor.fetchone()
|
||||
# 根据kb_id查询knowledgebase这张表,得到created_by
|
||||
kb_query = "SELECT created_by FROM knowledgebase WHERE id = %s"
|
||||
cursor.execute(kb_query, (kb_id["kb_id"],))
|
||||
kb_info = cursor.fetchone()
|
||||
|
||||
cursor.close()
|
||||
conn.close()
|
||||
conn = None # 确保连接已关闭
|
||||
|
||||
# 2. 更新文档状态为处理中 (使用 parser 模块的函数)
|
||||
# 更新文档状态为处理中 (使用 parser 模块的函数)
|
||||
_update_document_progress(doc_id, status="2", run="1", progress=0.0, message="开始解析")
|
||||
|
||||
# 3. 调用后台解析函数
|
||||
# 调用后台解析函数
|
||||
embedding_config = cls.get_system_embedding_config()
|
||||
parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config)
|
||||
parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info)
|
||||
|
||||
# 4. 返回解析结果
|
||||
# 返回解析结果
|
||||
return parse_result
|
||||
|
||||
except Exception as e:
|
||||
|
@ -1045,62 +1057,11 @@ class KnowledgebaseService:
|
|||
return False, f"连接测试失败: {message}"
|
||||
|
||||
return True, f"连接成功: {message}"
|
||||
# 测试通过,保存或更新配置到数据库(先不保存,以防冲突)
|
||||
# conn = None
|
||||
# cursor = None
|
||||
# try:
|
||||
# conn = cls._get_db_connection()
|
||||
# cursor = conn.cursor()
|
||||
|
||||
# # 检查 TenantLLM 记录是否存在
|
||||
# check_query = """
|
||||
# SELECT id FROM tenant_llm
|
||||
# WHERE tenant_id = %s AND llm_name = %s
|
||||
# """
|
||||
# cursor.execute(check_query, (tenant_id, llm_name))
|
||||
# existing_config = cursor.fetchone()
|
||||
|
||||
# now = datetime.now()
|
||||
# if existing_config:
|
||||
# # 更新记录
|
||||
# update_sql = """
|
||||
# UPDATE tenant_llm
|
||||
# SET api_key = %s, api_base = %s, max_tokens = %s, update_time = %s, update_date = %s
|
||||
# WHERE id = %s
|
||||
# """
|
||||
# update_params = (api_key, api_base, max_tokens, now, now.date(), existing_config[0])
|
||||
# cursor.execute(update_sql, update_params)
|
||||
# print(f"已更新 TenantLLM 记录 (ID: {existing_config[0]})")
|
||||
# else:
|
||||
# # 插入新记录
|
||||
# insert_sql = """
|
||||
# INSERT INTO tenant_llm (tenant_id, llm_factory, model_type, llm_name, api_key, api_base, max_tokens, create_time, create_date, update_time, update_date, used_tokens)
|
||||
# VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
# """
|
||||
# insert_params = (tenant_id, "VLLM", "embedding", llm_name, api_key, api_base, max_tokens, now, now.date(), now, now.date(), 0) # used_tokens 默认为 0
|
||||
# cursor.execute(insert_sql, insert_params)
|
||||
# print(f"已创建新的 TenantLLM 记录")
|
||||
|
||||
# conn.commit() # 提交事务
|
||||
# return True, "配置已成功保存"
|
||||
|
||||
# except Exception as e:
|
||||
# if conn:
|
||||
# conn.rollback() # 出错时回滚
|
||||
# print(f"保存系统 Embedding 配置时数据库出错: {e}")
|
||||
# traceback.print_exc()
|
||||
# # 返回 False 和错误信息给路由层
|
||||
# return False, f"保存配置时数据库出错: {e}"
|
||||
# finally:
|
||||
# if cursor:
|
||||
# cursor.close()
|
||||
# if conn and conn.is_connected():
|
||||
# conn.close()
|
||||
|
||||
# 顺序批量解析 (核心逻辑,在后台线程运行)
|
||||
@classmethod
|
||||
def _run_sequential_batch_parse(cls, kb_id):
|
||||
"""【内部方法】顺序执行批量解析,并在 SEQUENTIAL_BATCH_TASKS 中更新状态"""
|
||||
"""顺序执行批量解析,并在 SEQUENTIAL_BATCH_TASKS 中更新状态"""
|
||||
global SEQUENTIAL_BATCH_TASKS
|
||||
task_info = SEQUENTIAL_BATCH_TASKS.get(kb_id)
|
||||
if not task_info:
|
||||
|
@ -1118,7 +1079,6 @@ class KnowledgebaseService:
|
|||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询需要解析的文档
|
||||
# 注意:这里的条件要和前端期望的一致
|
||||
query = """
|
||||
SELECT id, name FROM document
|
||||
WHERE kb_id = %s AND run != '3'
|
||||
|
|
|
@ -162,7 +162,7 @@ class Dealer:
|
|||
q_vec = matchDense.embedding_data
|
||||
# 在返回字段中加入查询向量字段
|
||||
src.append(f"q_{len(q_vec)}_vec")
|
||||
# 创建融合表达式:设置向量匹配为95%,全文为5%(可以调整权重)
|
||||
# 创建融合表达式:设置向量匹配为95%,全文为5%
|
||||
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
|
||||
# 构建混合查询表达式
|
||||
matchExprs = [matchText, matchDense, fusionExpr]
|
||||
|
@ -210,10 +210,6 @@ class Dealer:
|
|||
keywords = list(kwds) # 转为列表格式返回
|
||||
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight") # 获取高亮内容
|
||||
aggs = self.dataStore.getAggregation(res, "docnm_kwd") # 执行基于文档名的聚合分析
|
||||
print(f"ids:{ids}")
|
||||
print(f"keywords:{keywords}")
|
||||
print(f"highlight:{highlight}")
|
||||
print(f"aggs:{aggs}")
|
||||
return self.SearchResult(total=total, ids=ids, query_vector=q_vec, aggregation=aggs, highlight=highlight, field=self.dataStore.getFields(res, src), keywords=keywords)
|
||||
|
||||
@staticmethod
|
||||
|
@ -322,6 +318,33 @@ class Dealer:
|
|||
return np.array(rank_fea) * 10.0 + pageranks
|
||||
|
||||
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", rank_feature: dict | None = None):
|
||||
"""
|
||||
对初步检索到的结果 (sres) 进行重排序。
|
||||
|
||||
该方法结合了多种相似度/特征来计算每个结果的新排序分数:
|
||||
1. 文本相似度 (Token Similarity): 基于查询关键词与文档内容的词元匹配。
|
||||
2. 向量相似度 (Vector Similarity): 基于查询向量与文档向量的余弦相似度。
|
||||
3. 排序特征分数 (Rank Feature Scores): 如文档的 PageRank 值或与查询相关的标签特征得分。
|
||||
|
||||
最终的排序分数是这几种分数的加权组合(或直接相加)。
|
||||
|
||||
Args:
|
||||
sres (SearchResult): 初步检索的结果对象,包含查询向量、文档ID、字段内容等。
|
||||
query (str): 原始用户查询字符串。
|
||||
tkweight (float): 文本相似度在混合相似度计算中的权重。
|
||||
vtweight (float): 向量相似度在混合相似度计算中的权重。
|
||||
cfield (str): 用于提取主要文本内容以进行词元匹配的字段名,默认为 "content_ltks"。
|
||||
rank_feature (dict | None): 用于计算排序特征分数的查询侧特征,
|
||||
例如 {PAGERANK_FLD: 10} 表示 PageRank 权重,
|
||||
或包含其他标签及其权重的字典。
|
||||
|
||||
Returns:
|
||||
tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
- sim (np.ndarray): 每个文档的最终重排序分数 (混合相似度 + 排序特征分数)。
|
||||
- tksim (np.ndarray): 每个文档的纯文本相似度分数。
|
||||
- vtsim (np.ndarray): 每个文档的纯向量相似度分数。
|
||||
如果初步检索结果为空 (sres.ids is empty),则返回三个空列表。
|
||||
"""
|
||||
_, keywords = self.qryr.question(query)
|
||||
vector_size = len(sres.query_vector)
|
||||
vector_column = f"q_{vector_size}_vec"
|
||||
|
@ -446,6 +469,7 @@ class Dealer:
|
|||
# 执行搜索操作
|
||||
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
|
||||
|
||||
# 执行重排序操作
|
||||
if rerank_mdl and sres.total > 0:
|
||||
sim, tsim, vsim = self.rerank_by_model(rerank_mdl, sres, question, 1 - vector_similarity_weight, vector_similarity_weight, rank_feature=rank_feature)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue