#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import binascii
import time
from functools import partial
import re
from copy import deepcopy
from timeit import default_timer as timer

from agentic_reasoning import DeepResearcher
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import Dialog, DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from api import settings
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, \
    citation_prompt
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.tavily_conn import Tavily


class DialogService(CommonService):
    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id,
                 page_number, items_per_page, orderby, desc, id, name):
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where(
            (cls.model.tenant_id == tenant_id)
            & (cls.model.status == StatusEnum.VALID.value)
        )
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())
        chats = chats.paginate(page_number, items_per_page)
        return list(chats.dicts())
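

# A minimal usage sketch (illustrative values; `orderby` must be a column on
# the Dialog model, e.g. "create_time" in the default schema):
#
#     chats = DialogService.get_list("tenant_id", page_number=1, items_per_page=20,
#                                    orderby="create_time", desc=True, id=None, name=None)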


def chat_solo(dialog, messages, stream=True):
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    # Strip citation markers (##<n>$$) from the history and drop system-role messages.
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
           for m in messages if m["role"] != "system"]
    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans):]
            # Buffer tiny fragments: only emit once the delta reaches 16 tokens.
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
        # Flush whatever tail stayed below the 16-token threshold.
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
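

# Note: the generators in this module yield event dicts shaped roughly like
# {"answer": str, "reference": dict | list, "audio_binary": str | None, ...};
# "answer" carries the full accumulated text so far (not a delta), so clients
# can simply replace what they display on each event.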


def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids:
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

    embedding_model_name = embedding_list[0]

    retriever = settings.retrievaler

    # Use at most the last three user turns as retrieval queries.
    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    create_retriever_ts = timer()

    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embedding_model_name)

    bind_embedding_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    bind_llm_ts = timer()

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    # Try to answer with SQL if a field mapping is available; fall through to
    # normal retrieval when SQL generation fails.
    if field_map:
        logging.debug("Use SQL to retrieval: {}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return
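
    # For reference, prompt_config is a dict carrying the system-prompt template
    # and the placeholders it expects; an illustrative (not exhaustive) shape:
    #
    #     {
    #         "system": "...{knowledge}...",
    #         "parameters": [{"key": "knowledge", "optional": False}],
    #         "quote": True,
    #         "tts": False,
    #     }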

    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace(
                "{%s}" % p["key"], " ")

    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        # Condense the multi-turn history into one standalone question.
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]

    refine_question_ts = timer()

    rerank_mdl = None
    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

    bind_reranker_ts = timer()
    generate_keyword_ts = bind_reranker_ts
    thought = ""
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}

    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        knowledges = []
    else:
        if prompt_config.get("keyword", False):
            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
            generate_keyword_ts = timer()

        tenant_ids = list(set([kb.tenant_id for kb in kbs]))

        knowledges = []
        if prompt_config.get("reasoning", False):
            reasoner = DeepResearcher(chat_mdl,
                                      prompt_config,
                                      partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))

            for think in reasoner.thinking(kbinfos, " ".join(questions)):
                if isinstance(think, str):
                    thought = think
                    knowledges = [t for t in think.split("\n") if t]
                elif stream:
                    yield think
        else:
            kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                          dialog.similarity_threshold,
                                          dialog.vector_similarity_weight,
                                          doc_ids=attachments,
                                          top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
                                          rank_feature=label_question(" ".join(questions), kbs)
                                          )
            if prompt_config.get("tavily_api_key"):
                tav = Tavily(prompt_config["tavily_api_key"])
                tav_res = tav.retrieve_chunks(" ".join(questions))
                kbinfos["chunks"].extend(tav_res["chunks"])
                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
            if prompt_config.get("use_kg"):
                ck = settings.kg_retrievaler.retrieval(" ".join(questions),
                                                       tenant_ids,
                                                       dialog.kb_ids,
                                                       embd_mdl,
                                                       LLMBundle(dialog.tenant_id, LLMType.CHAT))
                if ck["content_with_weight"]:
                    # Put the knowledge-graph hit first so it leads the context.
                    kbinfos["chunks"].insert(0, ck)

        knowledges = kb_prompt(kbinfos, max_tokens)
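
    # At this point kbinfos merges every retrieval source (vector search,
    # Tavily web search, knowledge graph) into one structure:
    # {"total": int, "chunks": [...], "doc_aggs": [...]}, and kb_prompt() has
    # rendered its chunks into prompt-ready text blocks capped at max_tokens.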

    logging.debug(
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    retrieval_ts = timer()
    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    prompt4citation = ""
    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        prompt4citation = citation_prompt()
    # Filter out system-role messages (the system prompt is handled separately above).
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
                for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]

    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"],
            max_tokens - used_token_count)
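
    # Citation markers take the form ##<chunk_index>$$ (e.g. "... ##0$$"),
    # where the number indexes into kbinfos["chunks"]. decorate_answer() below
    # validates markers the model emitted or inserts them itself.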

    def decorate_answer(answer):
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
        refs = []
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
            think = ans[0] + "</think>"
            answer = ans[1]

        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            # Drop malformed citation markers, then insert citations if the model produced none.
            answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
            if not re.search(r"##[0-9]+\$\$", answer):
                answer, idx = retriever.insert_citations(answer,
                                                         [ck["content_ltks"]
                                                          for ck in kbinfos["chunks"]],
                                                         [ck["vector"]
                                                          for ck in kbinfos["chunks"]],
                                                         embd_mdl,
                                                         tkweight=1 - dialog.vector_similarity_weight,
                                                         vtweight=dialog.vector_similarity_weight)
            else:
                idx = set([])
                for r in re.finditer(r"##([0-9]+)\$\$", answer):
                    i = int(r.group(1))
                    if i < len(kbinfos["chunks"]):
                        idx.add(i)

            # Keep only the documents that were actually cited.
            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [
                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs

            refs = deepcopy(kbinfos)
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
        finish_chat_ts = timer()

        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

        prompt += "\n\n### Query:\n%s" % " ".join(questions)
        prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt), "created_at": time.time()}

    if stream:
        last_ans = ""  # the full answer as of the last emission
        answer = ""    # the full answer accumulated so far
        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
            # If a reasoning thought was produced, strip its markers from the stream.
            if thought:
                ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            # Compute the newly added text fragment (delta).
            delta_ans = ans[len(last_ans):]
            # If the delta is too small (< 16 tokens), skip this emission to
            # avoid flooding the client with tiny fragments.
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            # Emit the accumulated answer (thought included) plus audio for the new fragment.
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res


def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    user_prompt = """
Table name: {};
The table's fields are as follows:
{}

The questions are as follows:
{}
Please write the SQL, only the SQL, without any other explanations or text.
""".format(
        index_name(tenant_id),
        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
        question
    )
    tried_times = 0

    def get_table():
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
            "temperature": 0.06})
        sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        # Normalize the generated SQL: drop newlines, anything before the
        # SELECT, redundant spaces, and anything after a semicolon or ```.
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;；]|```).*", "", sql)
        if sql[:len("select")] != "select":
            return None, None
        # Unless it is an aggregation query, force doc_id/docnm_kwd into the
        # selected columns so results can be traced back to documents.
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[:len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
    # Give the model one chance to correct itself, feeding back the failing SQL
    # and the database error.
    if tbl.get("error") and tried_times <= 2:
        user_prompt = """
Table name: {};
The table's fields are as follows:
{}

The questions are as follows:
{}
Please write the SQL, only the SQL, without any other explanations or text.

The SQL you provided last time is as follows:
{}

The error issued by the database is as follows:
{}

Please correct the error and write the SQL again, only the SQL, without any other explanations or text.
""".format(
            index_name(tenant_id),
            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
            question, sql, tbl["error"]
        )
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))

    logging.debug("GET table: {}".format(tbl))
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None

    docid_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(
        len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # compose Markdown table
    columns = "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"],
                                                                          tbl["columns"][i]["name"])) for i in
                              column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
           ("|------|" if docid_idx and doc_name_idx else "")

    rows = ["|" +
            "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
            "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    if quota:
        # Append a citation marker per row so the answer can be referenced.
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    else:
        rows = "\n".join(rows)
    # Strip time-of-day suffixes from date values.
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {
            "answer": "\n".join([columns, line, rows]),
            "reference": {"chunks": [], "doc_aggs": []},
            "prompt": sys_prompt
        }

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                                   doc_aggs.items()]},
        "prompt": sys_prompt
    }
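

# For illustration, with quota=True the composed Markdown answer looks like
# (hypothetical columns and data):
#
#     |Name|Phone|Source|
#     |------|------|------|
#     |Alice|555-0100| ##0$$ |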


def tts(tts_mdl, text):
    if not tts_mdl or not text:
        return
    bin = b""
    for chunk in tts_mdl.tts(text):
        bin += chunk
    return binascii.hexlify(bin).decode("utf-8")
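

# The hex string returned above can be decoded back to raw audio bytes on the
# consumer side with binascii.unhexlify().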


def ask(question, kb_ids, tenant_id):
    """
    Handle a user search request: retrieve relevant information from the
    knowledge bases and generate an answer.

    Args:
        question (str): the user's question or query.
        kb_ids (list): IDs of the knowledge bases to search.
        tenant_id (str): tenant ID, used for access control and resource isolation.

    Flow:
        1. Fetch the specified knowledge bases.
        2. Determine the embedding model to use.
        3. Pick a retriever based on the KB type (standard or knowledge graph).
        4. Initialize the embedding and chat models.
        5. Retrieve the relevant document chunks.
        6. Format the retrieved content as context.
        7. Build the system prompt.
        8. Generate the answer and add citation markers.
        9. Stream the generated answer back.

    Returns:
        generator: yields dicts containing the answer and its references.
    """
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    # Embedding model: turns text into vectors for similarity search.
    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    # Chat model: generates the answer.
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    # Maximum token length of the chat model, used to bound the context.
    max_tokens = chat_mdl.max_length
    # Collect the (deduplicated) tenant IDs of all knowledge bases.
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    # Retrieve the relevant document chunks.
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
                                  1, 12, 0.1, 0.3, aggs=False,
                                  rank_feature=label_question(question, kbs)
                                  )
    # Format the retrieval results as prompt context, capped at the model's token limit.
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
Role: You're a smart assistant. Your name is Miss R.
Task: Summarize the information from knowledge bases and answer the user's question.
Requirements and restrictions:
  - DO NOT make things up, especially for numbers.
  - If the information from the knowledge bases is irrelevant to the user's question, JUST SAY: Sorry, no relevant information provided.
  - Answer in markdown format text.
  - Answer in the language of the user's question.

### Information from knowledge bases
%s

The above is information from knowledge bases.
""" % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    # After generation completes, insert citation markers into the answer.
    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retriever.insert_citations(answer,
                                                 [ck["content_ltks"]
                                                  for ck in kbinfos["chunks"]],
                                                 [ck["vector"]
                                                  for ck in kbinfos["chunks"]],
                                                 embd_mdl,
                                                 tkweight=0.7,
                                                 vtweight=0.3)
        # Keep only the documents that were actually cited.
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        refs["chunks"] = chunks_format(refs)
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)
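

# A minimal consumption sketch (hypothetical question and IDs; `ask` is a
# generator, so callers iterate and forward each event, e.g. over SSE):
#
#     for event in ask("What is RAGFlow?", ["kb_id"], "tenant_id"):
#         print(event["answer"])  # cumulative answer so far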