fix(knowledgebase): fix private knowledgebase access-permission issue

- Change the document-parsing logic to use the knowledgebase creator as the tenant_id

parent d0f17da9b4
commit 04439e6640
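
In essence, the parsing path used to index chunks under the document creator's tenant; this commit resolves the tenant from the knowledgebase owner instead. A minimal sketch of the new resolution, mirroring the two queries in the KnowledgebaseService hunk below (resolve_tenant_id and the bare cursor argument are illustrative stand-ins, not names from the codebase):

    # Walk document -> knowledgebase and take the kb creator's id,
    # instead of the document's own created_by as before.
    def resolve_tenant_id(cursor, doc_id):
        cursor.execute("SELECT kb_id FROM document WHERE id = %s", (doc_id,))
        kb_id = cursor.fetchone()["kb_id"]
        cursor.execute("SELECT created_by FROM knowledgebase WHERE id = %s", (kb_id,))
        return cursor.fetchone()["created_by"]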
@@ -14,16 +14,12 @@
 # limitations under the License.
 #
 import datetime
-import json
-
+import xxhash
+import re
 from flask import request
 from flask_login import login_required, current_user
-
 from rag.app.qa import rmPrefix, beAdoc
-from rag.app.tag import label_question
 from rag.nlp import search, rag_tokenizer
-
-# from rag.prompts import keyword_extraction, cross_languages
 from rag.settings import PAGERANK_FLD
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
@@ -34,8 +30,6 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from api.db.services.document_service import DocumentService
 from api import settings
 from api.utils.api_utils import get_json_result
-import xxhash
-import re


 @manager.route("/list", methods=["POST"])  # noqa: F821
@@ -248,28 +242,6 @@ def create():
         return server_error_response(e)


-"""
-{
-    "similarity_threshold": 0.2,
-    "vector_similarity_weight": 0.30000000000000004,
-    "question": "香港",
-    "doc_ids": [],
-    "kb_id": "4b071030bc8e43f1bfb8b7831f320d2f",
-    "page": 1,
-    "size": 10
-},
-{
-    "similarity_threshold": 0.2,
-    "vector_similarity_weight": 0.30000000000000004,
-    "question": "显著优势",
-    "doc_ids": [],
-    "kb_id": "1848bc54384611f0b33e4e66786d0323",
-    "page": 1,
-    "size": 10
-}
-"""
-
-
 @manager.route("/retrieval_test", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("kb_id", "question")
@@ -285,13 +257,14 @@ def retrieval_test():
     doc_ids = req.get("doc_ids", [])
     similarity_threshold = float(req.get("similarity_threshold", 0.0))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
-    top = int(req.get("top_k", 1024))
+    top = int(req.get("top_k", 1024))  # the frontend never sends this parameter, so it defaults to 1024
     # langs = req.get("cross_languages", [])  # read the cross-language setting
     tenant_ids = []

     try:
         # look up the tenants the current user belongs to
         tenants = UserTenantService.query(user_id=current_user.id)
+
         # verify knowledgebase permissions
         for kb_id in kb_ids:
             for tenant in tenants:
@@ -300,6 +273,7 @@ def retrieval_test():
                     break
             else:
                 return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
+
         # fetch the knowledgebase info
         e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
         if not e:
@@ -236,7 +236,7 @@ def process_table_content(content_list):
     return new_content_list


-def perform_parse(doc_id, doc_info, file_info, embedding_config):
+def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
     """
     Core logic for performing document parsing.

@@ -244,6 +244,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
         doc_id (str): the document ID.
         doc_info (dict): dict with the document info (name, location, type, kb_id, parser_config, created_by).
         file_info (dict): dict with the file info (parent_id/bucket_name).
+        kb_info (dict): dict with the knowledgebase info (created_by).

     Returns:
         dict: dict with the parse result (success, chunk_count).
@@ -305,7 +306,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
     _, file_extension = os.path.splitext(file_location)
     file_type = doc_info["type"].lower()
     bucket_name = file_info["parent_id"]  # the storage bucket for the file is parent_id
-    tenant_id = doc_info["created_by"]  # use the document creator as the tenant_id
+    tenant_id = kb_info["created_by"]  # use the knowledgebase creator as the tenant_id

     # progress-update callback (calls the internal update function directly)
     def update_progress(prog=None, msg=None):
@@ -10,16 +10,6 @@ from nltk import word_tokenize
 from nltk.stem import PorterStemmer, WordNetLemmatizer


-# def get_project_base_directory():
-#     return os.path.abspath(
-#         os.path.join(
-#             os.path.dirname(os.path.realpath(__file__)),
-#             os.pardir,
-#             os.pardir,
-#         )
-#     )
-
-
 class RagTokenizer:
     def key_(self, line):
         return str(line.lower().encode("utf-8"))[2:-1]
@@ -741,11 +741,11 @@ class KnowledgebaseService:

     @classmethod
     def parse_document(cls, doc_id):
-        """Parse a document (invokes the parsing logic)"""
+        """Parse a document"""
         conn = None
         cursor = None
         try:
-            # 1. fetch the document and file info
+            # fetch the document and file info
             conn = cls._get_db_connection()
             cursor = conn.cursor(dictionary=True)

@@ -765,28 +765,40 @@ class KnowledgebaseService:
             f2d_query = "SELECT file_id FROM file2document WHERE document_id = %s"
             cursor.execute(f2d_query, (doc_id,))
             f2d_result = cursor.fetchone()

             if not f2d_result:
                 raise Exception("无法找到文件到文档的映射关系")
-            file_id = f2d_result["file_id"]

+            file_id = f2d_result["file_id"]
             file_query = "SELECT parent_id FROM file WHERE id = %s"
             cursor.execute(file_query, (file_id,))
             file_info = cursor.fetchone()

             if not file_info:
                 raise Exception("无法找到文件记录")

+            # fetch the knowledgebase creator's info
+            # query the document table by doc_id to get the kb_id
+            kb_id_query = "SELECT kb_id FROM document WHERE id = %s"
+            cursor.execute(kb_id_query, (doc_id,))
+            kb_id = cursor.fetchone()
+            # query the knowledgebase table by kb_id to get created_by
+            kb_query = "SELECT created_by FROM knowledgebase WHERE id = %s"
+            cursor.execute(kb_query, (kb_id["kb_id"],))
+            kb_info = cursor.fetchone()
+
             cursor.close()
             conn.close()
             conn = None  # make sure the connection is closed

-            # 2. mark the document as processing (using the parser module's function)
+            # mark the document as processing (using the parser module's function)
             _update_document_progress(doc_id, status="2", run="1", progress=0.0, message="开始解析")

-            # 3. call the background parse function
+            # call the background parse function
             embedding_config = cls.get_system_embedding_config()
-            parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config)
+            parse_result = perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info)

-            # 4. return the parse result
+            # return the parse result
             return parse_result

         except Exception as e:
@@ -1045,62 +1057,11 @@ class KnowledgebaseService:
             return False, f"连接测试失败: {message}"

         return True, f"连接成功: {message}"
-        # test passed; save or update the config in the database (disabled for now, to avoid conflicts)
-        # conn = None
-        # cursor = None
-        # try:
-        #     conn = cls._get_db_connection()
-        #     cursor = conn.cursor()
-
-        #     # check whether the TenantLLM record exists
-        #     check_query = """
-        #         SELECT id FROM tenant_llm
-        #         WHERE tenant_id = %s AND llm_name = %s
-        #     """
-        #     cursor.execute(check_query, (tenant_id, llm_name))
-        #     existing_config = cursor.fetchone()
-
-        #     now = datetime.now()
-        #     if existing_config:
-        #         # update the record
-        #         update_sql = """
-        #             UPDATE tenant_llm
-        #             SET api_key = %s, api_base = %s, max_tokens = %s, update_time = %s, update_date = %s
-        #             WHERE id = %s
-        #         """
-        #         update_params = (api_key, api_base, max_tokens, now, now.date(), existing_config[0])
-        #         cursor.execute(update_sql, update_params)
-        #         print(f"已更新 TenantLLM 记录 (ID: {existing_config[0]})")
-        #     else:
-        #         # insert a new record
-        #         insert_sql = """
-        #             INSERT INTO tenant_llm (tenant_id, llm_factory, model_type, llm_name, api_key, api_base, max_tokens, create_time, create_date, update_time, update_date, used_tokens)
-        #             VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
-        #         """
-        #         insert_params = (tenant_id, "VLLM", "embedding", llm_name, api_key, api_base, max_tokens, now, now.date(), now, now.date(), 0)  # used_tokens defaults to 0
-        #         cursor.execute(insert_sql, insert_params)
-        #         print(f"已创建新的 TenantLLM 记录")
-
-        #     conn.commit()  # commit the transaction
-        #     return True, "配置已成功保存"
-
-        # except Exception as e:
-        #     if conn:
-        #         conn.rollback()  # roll back on error
-        #     print(f"保存系统 Embedding 配置时数据库出错: {e}")
-        #     traceback.print_exc()
-        #     # return False and the error message to the route layer
-        #     return False, f"保存配置时数据库出错: {e}"
-        # finally:
-        #     if cursor:
-        #         cursor.close()
-        #     if conn and conn.is_connected():
-        #         conn.close()

     # sequential batch parse (core logic, runs in a background thread)
     @classmethod
     def _run_sequential_batch_parse(cls, kb_id):
-        """[internal] Run the batch parse sequentially, updating status in SEQUENTIAL_BATCH_TASKS"""
+        """Run the batch parse sequentially, updating status in SEQUENTIAL_BATCH_TASKS"""
         global SEQUENTIAL_BATCH_TASKS
         task_info = SEQUENTIAL_BATCH_TASKS.get(kb_id)
         if not task_info:
@@ -1118,7 +1079,6 @@ class KnowledgebaseService:
             cursor = conn.cursor(dictionary=True)

             # find the documents that need parsing
-            # note: this condition must match what the frontend expects
             query = """
                 SELECT id, name FROM document
                 WHERE kb_id = %s AND run != '3'
@@ -162,7 +162,7 @@ class Dealer:
         q_vec = matchDense.embedding_data
         # add the query-vector field to the returned fields
         src.append(f"q_{len(q_vec)}_vec")
-        # build the fusion expression: vector match weighted 95%, full text 5% (weights can be adjusted)
+        # build the fusion expression: vector match weighted 95%, full text 5%
         fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
         # assemble the hybrid query expressions
         matchExprs = [matchText, matchDense, fusionExpr]
@@ -210,10 +210,6 @@ class Dealer:
         keywords = list(kwds)  # convert to a list for the return value
         highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")  # fetch the highlighted content
         aggs = self.dataStore.getAggregation(res, "docnm_kwd")  # run an aggregation over document names
-        print(f"ids:{ids}")
-        print(f"keywords:{keywords}")
-        print(f"highlight:{highlight}")
-        print(f"aggs:{aggs}")
         return self.SearchResult(total=total, ids=ids, query_vector=q_vec, aggregation=aggs, highlight=highlight, field=self.dataStore.getFields(res, src), keywords=keywords)

     @staticmethod
@@ -322,6 +318,33 @@ class Dealer:
         return np.array(rank_fea) * 10.0 + pageranks

     def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", rank_feature: dict | None = None):
+        """
+        Rerank the initially retrieved results (sres).
+
+        This method combines several similarity/feature signals to compute a new ranking score for each result:
+        1. Token similarity: token-level matching between the query keywords and the document content.
+        2. Vector similarity: cosine similarity between the query vector and the document vectors.
+        3. Rank-feature scores: e.g. a document's PageRank value, or tag-feature scores related to the query.
+
+        The final ranking score is a weighted combination (or direct sum) of these scores.
+
+        Args:
+            sres (SearchResult): the initial retrieval result, holding the query vector, document IDs, field contents, etc.
+            query (str): the original user query string.
+            tkweight (float): weight of the token similarity in the hybrid similarity.
+            vtweight (float): weight of the vector similarity in the hybrid similarity.
+            cfield (str): field from which the main text content is taken for token matching; defaults to "content_ltks".
+            rank_feature (dict | None): query-side features used to compute the rank-feature scores,
+                e.g. {PAGERANK_FLD: 10} for a PageRank weight,
+                or a dict of other tags and their weights.
+
+        Returns:
+            tuple[np.ndarray, np.ndarray, np.ndarray]:
+                - sim (np.ndarray): the final rerank score per document (hybrid similarity + rank-feature score).
+                - tksim (np.ndarray): the pure token-similarity score per document.
+                - vtsim (np.ndarray): the pure vector-similarity score per document.
+            If the initial result set is empty (sres.ids is empty), three empty lists are returned.
+        """
         _, keywords = self.qryr.question(query)
         vector_size = len(sres.query_vector)
         vector_column = f"q_{vector_size}_vec"
@@ -446,6 +469,7 @@ class Dealer:
         # run the search
         sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature)

+        # run the reranking step
         if rerank_mdl and sres.total > 0:
             sim, tsim, vsim = self.rerank_by_model(rerank_mdl, sres, question, 1 - vector_similarity_weight, vector_similarity_weight, rank_feature=rank_feature)
         else:
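
The docstring added to Dealer.rerank above describes the final score as a weighted combination of token similarity, vector similarity, and rank-feature scores. A rough numeric sketch of that mix, with made-up arrays standing in for the values the real method derives from sres (the actual implementation also tokenizes the query and computes the similarities itself):

    import numpy as np

    tksim = np.array([0.8, 0.2, 0.5])     # token similarity per document
    vtsim = np.array([0.6, 0.9, 0.4])     # vector similarity per document
    rank_fea = np.array([0.1, 0.0, 0.3])  # rank-feature score, e.g. PageRank-based

    tkweight, vtweight = 0.3, 0.7         # rerank()'s default weights
    sim = tkweight * tksim + vtweight * vtsim + rank_fea
    print(sim.argsort()[::-1])            # document order by final rerank score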