2025-03-24 11:19:28 +08:00
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api . db import StatusEnum , TenantPermission
from api . db . db_models import Knowledgebase , DB , Tenant , User , UserTenant , Document
from api . db . services . common_service import CommonService
from peewee import fn
class KnowledgebaseService(CommonService):
    """Database access helpers for the ``Knowledgebase`` model.

    Every public method is a classmethod wrapped in ``DB.connection_context()``.
    """

    model = Knowledgebase

    @classmethod
    @DB.connection_context()
    def is_parsed_done(cls, kb_id):
        """Check whether every document in a knowledge base is ready for chat.

        Args:
            kb_id: Knowledge base ID.

        Returns:
            ``(True, None)`` when no document blocks chat creation, otherwise
            ``(False, error_message)`` describing the first blocking document.
        """
        from api.db import TaskStatus
        from api.db.services.document_service import DocumentService

        kbs = cls.query(id=kb_id)
        if not kbs:
            return False, "Knowledge base not found"
        kb = kbs[0]

        broken_states = (TaskStatus.CANCEL.value, TaskStatus.FAIL.value)

        # Walk every page of documents. The previous implementation only looked
        # at the first 1000, so larger knowledge bases were checked partially.
        page_number, page_size, seen = 1, 1000, 0
        while True:
            # NOTE(review): assumes the second value returned by get_by_kb_id is
            # the total match count before pagination — confirm in DocumentService.
            docs, total = DocumentService.get_by_kb_id(
                kb_id, page_number, page_size, "create_time", True, "")
            for doc in docs:
                state = doc['run']
                # A document that is still running blocks chat creation until it finishes.
                if state == TaskStatus.RUNNING.value:
                    return False, f"Document '{doc['name']}' in dataset '{kb.name}' is still being parsed. Please wait until all documents are parsed before starting a chat."
                # Failed or cancelled parses never finish on their own; waiting
                # would not help, so ask for a re-parse instead.
                if state in broken_states:
                    return False, f"Document '{doc['name']}' in dataset '{kb.name}' failed to parse or its parsing was cancelled. Please re-parse it before starting a chat."
                # A document that was never parsed and has no chunks is unusable.
                if state == TaskStatus.UNSTART.value and doc['chunk_num'] == 0:
                    return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat."
            seen += len(docs)
            # Stop on a short/empty page or once every counted row has been seen;
            # either condition alone guarantees termination.
            if len(docs) < page_size or seen >= total:
                break
            page_number += 1
        return True, None

    @classmethod
    @DB.connection_context()
    def list_documents_by_ids(cls, kb_ids):
        """Return the IDs of all documents belonging to the given knowledge bases.

        Args:
            kb_ids: Iterable of knowledge base IDs.

        Returns:
            A list of document IDs (one per joined ``Document`` row).
        """
        query = (
            cls.model.select(Document.id.alias("document_id"))
            .join(Document, on=(cls.model.id == Document.kb_id))
            .where(cls.model.id.in_(kb_ids))
        )
        return [row["document_id"] for row in query.dicts()]

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
                          page_number, items_per_page,
                          orderby, desc, keywords,
                          parser_id=None
                          ):
        """Return one page of the knowledge bases visible to a user.

        A knowledge base is visible when its status is VALID and either its
        ``tenant_id`` equals ``user_id`` or it is TEAM-shared and belongs to one
        of ``joined_tenant_ids``.

        Args:
            joined_tenant_ids: IDs of the tenants the user has joined.
            user_id: The user's ID, matched against ``tenant_id``.
            page_number: Page number passed to ``paginate``.
            items_per_page: Page size passed to ``paginate``.
            orderby: Name of the field to sort by (resolved via ``getter_by``).
            desc: Sort descending when truthy, ascending otherwise.
            keywords: Optional substring matched against the lower-cased name.
            parser_id: Optional parser ID to filter on.

        Returns:
            ``(rows, total)``: the row dicts for the requested page and the
            number of matching rows before pagination.
        """
        fields = [
            cls.model.id,  # knowledge base ID
            cls.model.avatar,  # knowledge base avatar
            cls.model.name,  # knowledge base name
            cls.model.language,  # knowledge base language
            cls.model.description,  # knowledge base description
            cls.model.permission,  # knowledge base permission
            cls.model.doc_num,  # number of documents
            cls.model.token_num,  # number of tokens
            cls.model.chunk_num,  # number of chunks
            cls.model.parser_id,  # parser ID
            cls.model.embd_id,  # embedding ID
            User.nickname,  # nickname of the user whose id equals the KB's tenant_id
            User.avatar.alias('tenant_avatar'),  # that user's avatar, aliased to avoid clashing with the KB avatar
            cls.model.update_time  # last update time
        ]
        # Team-shared KBs of joined tenants, or the user's own KBs; only VALID ones.
        visible = (
            ((cls.model.tenant_id.in_(joined_tenant_ids)
              & (cls.model.permission == TenantPermission.TEAM.value))
             | (cls.model.tenant_id == user_id))
            & (cls.model.status == StatusEnum.VALID.value)
        )
        kbs = cls.model.select(*fields).join(
            User, on=(cls.model.tenant_id == User.id)).where(visible)
        # Optional case-insensitive name search; chained where() ANDs the clause.
        if keywords:
            kbs = kbs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
        # Optional parser filter.
        if parser_id:
            kbs = kbs.where(cls.model.parser_id == parser_id)
        order_field = cls.model.getter_by(orderby)
        kbs = kbs.order_by(order_field.desc() if desc else order_field.asc())
        # Count before paginating so the total covers all matching rows.
        count = kbs.count()
        kbs = kbs.paginate(page_number, items_per_page)
        return list(kbs.dicts()), count

    @classmethod
    @DB.connection_context()
    def get_kb_ids(cls, tenant_id):
        """Return the IDs of all knowledge bases owned by ``tenant_id``.

        NOTE(review): no status filter is applied, so non-VALID rows are
        included as well — confirm this is intended.
        """
        query = cls.model.select(cls.model.id).where(cls.model.tenant_id == tenant_id)
        return [kb.id for kb in query]

    @classmethod
    @DB.connection_context()
    def get_detail(cls, kb_id):
        """Return the detail fields of a VALID knowledge base as a dict.

        The knowledge base is joined with its tenant, and the tenant must also
        be VALID.

        Returns:
            A dict of the selected fields, or ``None`` when nothing matches.
        """
        fields = [
            cls.model.id,
            cls.model.embd_id,
            cls.model.avatar,
            cls.model.name,
            cls.model.language,
            cls.model.description,
            cls.model.permission,
            cls.model.doc_num,
            cls.model.token_num,
            cls.model.chunk_num,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.pagerank]
        kbs = cls.model.select(*fields).join(Tenant, on=(
            (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
            (cls.model.id == kb_id),
            (cls.model.status == StatusEnum.VALID.value)
        )
        if not kbs:
            return None
        return kbs[0].to_dict()

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        """Deep-merge ``config`` into a knowledge base's ``parser_config`` and save it.

        Merge rules, applied recursively:
          * keys missing from the stored config are added as-is;
          * nested dicts are merged key by key;
          * lists are unioned: stored items first, new items appended, and
            duplicates dropped with a stable order;
          * any other value overwrites the stored one.

        Args:
            id: Knowledge base ID.
            config: Partial parser configuration to merge in.

        Raises:
            LookupError: if no knowledge base has this ``id``.
            TypeError: if ``config`` tries to merge a dict or list into a stored
                value of a different type.
        """
        found, kb = cls.get_by_id(id)
        if not found:
            raise LookupError(f"knowledgebase({id}) not found.")

        def _union(existing, extra):
            # list(set(...)) lost ordering and failed on unhashable items.
            combined = existing + extra
            try:
                return list(dict.fromkeys(combined))
            except TypeError:
                # Unhashable items (e.g. dicts): fall back to an equality scan.
                merged = []
                for item in combined:
                    if item not in merged:
                        merged.append(item)
                return merged

        def _merge(old, new):
            for key, value in new.items():
                if key not in old:
                    old[key] = value
                elif isinstance(value, dict):
                    # Explicit check instead of assert, which is stripped under -O.
                    if not isinstance(old[key], dict):
                        raise TypeError(f"parser_config[{key!r}] is not a dict; cannot merge a dict into it.")
                    _merge(old[key], value)
                elif isinstance(value, list):
                    if not isinstance(old[key], list):
                        raise TypeError(f"parser_config[{key!r}] is not a list; cannot merge a list into it.")
                    old[key] = _union(old[key], value)
                else:
                    old[key] = value

        # A missing parser_config used to crash on the first membership test.
        parser_config = kb.parser_config if kb.parser_config is not None else {}
        _merge(parser_config, config)
        cls.update_by_id(id, {"parser_config": parser_config})

    @classmethod
    @DB.connection_context()
    def get_field_map(cls, ids):
        """Merge the ``field_map`` entries of the given knowledge bases.

        Knowledge bases without a ``parser_config`` or without a ``field_map``
        key are skipped. On key collisions, later knowledge bases override
        earlier ones (``dict.update`` semantics).
        """
        merged = {}
        for kb in cls.get_by_ids(ids):
            cfg = kb.parser_config
            if not cfg or "field_map" not in cfg:
                continue
            merged.update(cfg["field_map"])
        return merged

    @classmethod
    @DB.connection_context()
    def get_by_name(cls, kb_name, tenant_id):
        """Look up a VALID knowledge base by exact name within a tenant.

        Returns:
            ``(True, model_instance)`` for the first match, else ``(False, None)``.
        """
        kb = cls.model.select().where(
            (cls.model.name == kb_name)
            & (cls.model.tenant_id == tenant_id)
            & (cls.model.status == StatusEnum.VALID.value)
        )
        if not kb:
            return False, None
        return True, kb[0]

    @classmethod
    @DB.connection_context()
    def get_all_ids(cls):
        """Return the IDs of every knowledge base row (no status filter)."""
        rows = cls.model.select(cls.model.id).dicts()
        return [row["id"] for row in rows]

    @classmethod
    @DB.connection_context()
    def get_list(cls, joined_tenant_ids, user_id,
                 page_number, items_per_page, orderby, desc, id, name):
        """Return one page of visible knowledge bases as dicts, without a total count.

        Visibility matches ``get_by_tenant_ids``: status VALID and either owned by
        ``user_id`` or TEAM-shared in one of ``joined_tenant_ids``. ``id`` and
        ``name``, when truthy, are applied as exact-match filters.
        """
        kbs = cls.model.select()
        if id:
            kbs = kbs.where(cls.model.id == id)
        if name:
            kbs = kbs.where(cls.model.name == name)
        kbs = kbs.where(
            ((cls.model.tenant_id.in_(joined_tenant_ids)
              & (cls.model.permission == TenantPermission.TEAM.value))
             | (cls.model.tenant_id == user_id))
            & (cls.model.status == StatusEnum.VALID.value)
        )
        order_field = cls.model.getter_by(orderby)
        kbs = kbs.order_by(order_field.desc() if desc else order_field.asc())
        kbs = kbs.paginate(page_number, items_per_page)
        return list(kbs.dicts())

    @classmethod
    @DB.connection_context()
    def accessible(cls, kb_id, user_id):
        """Return True if ``user_id`` has a ``UserTenant`` row for the tenant owning ``kb_id``.

        NOTE(review): neither the knowledge base status nor the UserTenant
        status is checked here — confirm this is intended.
        """
        docs = cls.model.select(cls.model.id).join(
            UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
        ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
        return bool(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_kb_by_id(cls, kb_id, user_id):
        """Return a list holding at most one row dict for ``kb_id``.

        The row is returned only if ``user_id`` belongs (via ``UserTenant``) to
        the owning tenant.
        """
        kbs = cls.model.select().join(
            UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
        ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
        return list(kbs.dicts())

    @classmethod
    @DB.connection_context()
    def get_kb_by_name(cls, kb_name, user_id):
        """Return a list holding at most one row dict named ``kb_name``.

        The row is returned only if ``user_id`` belongs (via ``UserTenant``) to
        the owning tenant.
        """
        kbs = cls.model.select().join(
            UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
        ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
        return list(kbs.dicts())

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, kb_id, user_id):
        """Return True only if ``user_id`` is the creator (``created_by``) of ``kb_id``."""
        docs = cls.model.select(cls.model.id).where(
            cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
        return bool(docs.dicts())