2025-03-24 11:19:28 +08:00
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import binascii
2025-06-12 19:43:24 +08:00
import logging
2025-03-24 11:19:28 +08:00
import re
2025-06-12 19:43:24 +08:00
import time
2025-03-24 11:19:28 +08:00
from copy import deepcopy
from timeit import default_timer as timer
2025-06-12 19:43:24 +08:00
from api import settings
2025-03-24 11:19:28 +08:00
from api . db import LLMType , ParserType , StatusEnum
2025-06-12 19:43:24 +08:00
from api . db . db_models import DB , Dialog
2025-03-24 11:19:28 +08:00
from api . db . services . common_service import CommonService
from api . db . services . knowledgebase_service import KnowledgebaseService
2025-06-12 19:43:24 +08:00
from api . db . services . llm_service import LLMBundle , TenantLLMService
2025-03-24 11:19:28 +08:00
from rag . app . resume import forbidden_select_fields4resume
from rag . app . tag import label_question
from rag . nlp . search import index_name
2025-06-12 19:43:24 +08:00
from rag . prompts import chunks_format , citation_prompt , kb_prompt , keyword_extraction , llm_id2llm_type , message_fit_in
from rag . utils import num_tokens_from_string , rmSpace
2025-06-07 13:00:07 +08:00
from . database import MINIO_CONFIG
2025-03-24 11:19:28 +08:00
class DialogService(CommonService):
    """Data-access service for ``Dialog`` records."""

    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
        """Return one page of a tenant's valid dialogs as a list of dicts.

        Args:
            tenant_id: Tenant whose dialogs are listed.
            page_number: 1-based page index passed to ``paginate``.
            items_per_page: Page size passed to ``paginate``.
            orderby: Field name resolved through ``cls.model.getter_by``.
            desc: Sort descending when truthy, ascending otherwise.
            id: Optional exact dialog id filter (ignored when falsy).
            name: Optional exact dialog name filter (ignored when falsy).
        """
        query = cls.model.select()
        if id:
            query = query.where(cls.model.id == id)
        if name:
            query = query.where(cls.model.name == name)
        query = query.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))

        sort_field = cls.model.getter_by(orderby)
        query = query.order_by(sort_field.desc() if desc else sort_field.asc())
        return list(query.paginate(page_number, items_per_page).dicts())
def chat_solo(dialog, messages, stream=True):
    """Chat with the dialog's LLM directly, without knowledge-base retrieval.

    Args:
        dialog: Dialog record; reads ``llm_id``, ``tenant_id``, ``prompt_config``
            and ``llm_setting``.
        messages (list[dict]): Conversation history with ``role``/``content`` keys.
        stream (bool): Yield incremental answers when True, a single answer otherwise.

    Yields:
        dict: ``answer`` (full answer text so far), ``reference`` (always empty),
        ``audio_binary`` (hex TTS audio of the new text, or None), ``prompt``
        (empty) and ``created_at`` (``time.time()``).
    """
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)

    # The system prompt is passed separately, so drop system messages; also strip
    # citation markers such as "##3$$" left in earlier answers.
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]

    if stream:
        last_ans = ""
        # Initialized so an empty stream does not hit a NameError below.
        answer = ""
        delta_ans = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans):]
            # Batch small increments so clients/TTS are not flooded with tiny fragments.
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
            # This fragment has been sent; clear it so the tail flush below does not
            # emit it (and synthesize its audio) a second time.
            delta_ans = ""
        # Flush text that arrived after the last yield but was below the batch size.
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
def chat(dialog, messages, stream=True, **kwargs):
    """Answer the last user message using the dialog's knowledge bases (RAG chat).

    Delegates to ``chat_solo`` when the dialog has no knowledge bases. Otherwise it
    first tries ``use_sql`` when the knowledge bases expose a field map, then
    retrieves chunks, formats ``prompt_config["system"]`` with ``kwargs`` (plus the
    retrieved ``knowledge``), and queries the chat model.

    Args:
        dialog: Dialog record; reads ``kb_ids``, ``llm_id``, ``tenant_id``,
            ``prompt_config``, ``llm_setting``, ``rerank_id``, ``top_n``, ``top_k``,
            ``similarity_threshold`` and ``vector_similarity_weight``.
        messages (list[dict]): Conversation history; the last entry must have role
            "user". A ``doc_ids`` entry on it restricts retrieval and overrides
            ``kwargs["doc_ids"]``.
        stream (bool): Yield incremental answers when True.
        **kwargs: Values for ``{key}`` placeholders in the system prompt. ``doc_ids``
            (comma-separated string) restricts retrieval; ``quote`` toggles citations.

    Yields:
        dict: Partial answers while streaming, then the final dict built by
        ``decorate_answer`` (answer with citations/images, references, and a prompt
        trace with timing). Error and empty-response paths yield a single dict.

    Raises:
        AssertionError: If the last message is not from the user.
        LookupError: If the embedding model could not be bound.
        KeyError: If a non-optional prompt parameter is missing from ``kwargs``.
    """
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids:
        # No knowledge base attached: plain LLM conversation.
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()

    # Read the model config under the type that matches the configured LLM.
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    # Token budget for the prompt; 8192 when the config has no value.
    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    # The question is embedded with a single model, so every KB must share one.
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

    embedding_model_name = embedding_list[0]

    retriever = settings.retrievaler

    # Last (up to) three user messages; trimmed to the latest one further below.
    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    # Optional document restriction; the last message's "doc_ids" wins over kwargs.
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    create_retriever_ts = timer()

    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embedding_model_name)
    bind_embedding_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    bind_llm_ts = timer()

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    # try to use sql if field mapping is good to go
    if field_map:
        logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return

    # Validate prompt parameters; "knowledge" is filled from retrieval below.
    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            # Blank out an unused optional placeholder.
            # NOTE(review): this writes into dialog.prompt_config["system"] in place.
            prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
    # Only the latest user question is used from here on.
    questions = questions[-1:]

    refine_question_ts = timer()

    rerank_mdl = None
    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

    bind_reranker_ts = timer()
    generate_keyword_ts = bind_reranker_ts
    # NOTE(review): never reassigned in this function, so it stays "" and the
    # "<think>" stripping in the streaming loop below never runs.
    thought = ""
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}

    # Retrieve only when the system prompt actually declares a "knowledge" parameter.
    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        knowledges = []
    else:
        if prompt_config.get("keyword", False):
            # Append LLM-extracted keywords to the query text.
            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
            generate_keyword_ts = timer()

        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
        knowledges = []
        kbinfos = retriever.retrieval(
            " ".join(questions),
            embd_mdl,
            tenant_ids,
            dialog.kb_ids,
            1,
            dialog.top_n,
            dialog.similarity_threshold,
            dialog.vector_similarity_weight,
            doc_ids=attachments,
            top=dialog.top_k,
            aggs=False,
            rerank_mdl=rerank_mdl,
            rank_feature=label_question(" ".join(questions), kbs),
        )
        # Format retrieved chunks into prompt text bounded by max_tokens.
        knowledges = kb_prompt(kbinfos, max_tokens)

    logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    retrieval_ts = timer()
    if not knowledges and prompt_config.get("empty_response"):
        # Nothing retrieved: answer with the configured canned response.
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    # NOTE(review): gen_conf aliases dialog.llm_setting, so the max_tokens cap below
    # is written into the dialog's settings in place.
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    # Citation instructions are appended only when knowledge exists and quoting is on.
    prompt4citation = ""
    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        prompt4citation = citation_prompt()
    # Drop system-role messages (the system message was built separately above) and
    # strip old "##N$$" citation markers from history.
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
    # Fit the messages into 95% of the model budget (presumably trimming history).
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]

    # Cap the generation length to the tokens left after the prompt.
    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)

    def decorate_answer(answer):
        """Build the final payload: citations, inline images, references, timing trace."""
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
        refs = []
        # Separate a "...</think>" reasoning prefix from the visible answer.
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
            think = ans[0] + "</think>"
            answer = ans[1]
        cited_chunk_indices = set()
        inserted_images = {}
        processed_image_urls = set()
        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            # Collect the indices of the cited chunks.
            if not re.search(r"##[0-9]+\$\$", answer):
                # The LLM emitted no markers: let the retriever insert citations.
                answer, idx = retriever.insert_citations(
                    answer,
                    [ck["content_ltks"] for ck in kbinfos["chunks"]],
                    [ck["vector"] for ck in kbinfos["chunks"]],
                    embd_mdl,
                    tkweight=1 - dialog.vector_similarity_weight,
                    vtweight=dialog.vector_similarity_weight,
                )
                cited_chunk_indices = idx
            else:
                # Keep only markers that point at an existing chunk.
                for r in re.finditer(r"##([0-9]+)\$\$", answer):
                    i = int(r.group(1))
                    if i < len(kbinfos["chunks"]):
                        cited_chunk_indices.add(i)

            # Image insertion: after a citation marker, show the cited chunk's image.
            def insert_image_markdown(match):
                idx = int(match.group(1))
                if idx >= len(kbinfos["chunks"]):
                    return match.group(0)
                chunk = kbinfos["chunks"][idx]
                img_path = chunk.get("image_id")
                if not img_path:
                    return match.group(0)
                protocol = "https" if MINIO_CONFIG.get("secure", False) else "http"
                img_url = f"{protocol}://{MINIO_CONFIG['visit_point']}/{img_path}"
                # Each image is inserted at most once per answer.
                if img_url in processed_image_urls:
                    return match.group(0)
                processed_image_urls.add(img_url)
                inserted_images[idx] = img_url
                # Insert the image, capping its maximum width.
                return f'{match.group(0)}\n\n<img src="{img_url}" alt="{img_url}" style="max-width:800px;">'

            # Substitute every citation marker through the image inserter.
            answer = re.sub(r"##(\d+)\$\$", insert_image_markdown, answer)
            # Keep only the documents that were actually cited (all if none were).
            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in cited_chunk_indices])
            recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs
            # Strip embedding vectors from the returned references.
            refs = deepcopy(kbinfos)
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]

        # Special hint when the provider reports a bad API key.
        if "invalid key" in answer.lower() or "invalid api" in answer.lower():
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"

        # Append per-stage timing (ms) to the prompt trace.
        finish_chat_ts = timer()
        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
        prompt += "\n\n### Query:\n%s" % " ".join(questions)
        prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n  - Check LLM: {check_llm_time_cost:.1f}ms\n  - Create retriever: {create_retriever_time_cost:.1f}ms\n  - Bind embedding: {bind_embedding_time_cost:.1f}ms\n  - Bind LLM: {bind_llm_time_cost:.1f}ms\n  - Tune question: {refine_question_time_cost:.1f}ms\n  - Bind reranker: {bind_reranker_time_cost:.1f}ms\n  - Generate keyword: {generate_keyword_time_cost:.1f}ms\n  - Retrieval: {retrieval_time_cost:.1f}ms\n  - Generate answer: {generate_result_time_cost:.1f}ms"

        # Two trailing spaces before "\n" render as Markdown line breaks.
        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt), "created_at": time.time()}

    if stream:
        last_ans = ""  # full answer as of the last yield
        answer = ""  # full answer accumulated so far
        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
            # If there is a thinking process (thought), strip the think tags.
            if thought:
                ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            # Newly added text since the last yield.
            delta_ans = ans[len(last_ans):]
            # Skip tiny increments (< 16 tokens) to avoid sending many small fragments.
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            # Yield the accumulated answer (plus thought) and audio for the new fragment.
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        # Flush the tail that never reached the batching threshold.
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    """Answer *question* by having the LLM write SQL against the tenant's index.

    The generated SQL is normalized and, for non-aggregate queries, forced to also
    select ``doc_id``/``docnm_kwd``. It is executed through
    ``settings.retrievaler.sql_retrieval``. When the database reports an error, the
    LLM gets one chance to correct its SQL.

    Args:
        question (str): The user question.
        field_map (dict): Field name -> description, shown to the LLM as the schema.
        tenant_id (str): Tenant whose index (``index_name(tenant_id)``) is queried.
        chat_mdl: Chat model used to generate the SQL.
        quota (bool): Citation flag. NOTE(review): both modes currently append the
            same ``##i$$`` citation cell, so this flag has no effect; confirm intent.

    Returns:
        dict | None: ``{"answer": <markdown table>, "reference": {...}, "prompt": ...}``,
        or None when no usable SQL or no rows were obtained.
    """
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    table_name = index_name(tenant_id)
    fields_desc = "\n".join([f"{k}: {v}" for k, v in field_map.items()])
    user_prompt = """
Table name: {};
Table of database fields are as follows:
{}

Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(table_name, fields_desc, question)
    tried_times = 0

    def get_table():
        """Ask the LLM for SQL (reads the current user_prompt) and run it.

        Returns (result, sql), or (None, None) if the reply is not a SELECT.
        """
        nonlocal tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
        sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        # Normalize: one line, start at "select", single spaces, cut at ";"/"；"/fence.
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;；]|```).*", "", sql)
        if not sql.startswith("select "):
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            # Non-aggregate query: also select doc id/name so rows can be cited.
            if not sql.startswith("select *"):
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                # Expand "*" into at most 12 allowed fields.
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None

    # Single correction round: feed the failing SQL and the DB error back to the LLM.
    if tbl.get("error") and tried_times <= 2:
        user_prompt = """
        Table name: {};
        Table of database fields are as follows:
        {}

        Question are as follows:
        {}
        Please write the SQL, only SQL, without any other explanations or text.


        The SQL error you provided last time is as follows:
        {}

        Error issued by database as follows:
        {}

        Please correct the error and write SQL again, only SQL, without any other explanations or text.
        """.format(table_name, fields_desc, question, sql, tbl["error"])
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))
        # The corrected reply may again not be a SELECT; without this check the
        # tbl.get() call below raised AttributeError on None.
        if tbl is None:
            return None

    logging.debug("GET table: {}".format(tbl))
    if tbl.get("error") or not tbl.get("rows"):
        return None

    docid_idx = {ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"}
    doc_name_idx = {ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"}
    column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # compose Markdown table
    # Header names come from field_map, minus "/..." suffixes and full-width
    # parenthesized notes; a Source column is added when doc_id was selected.
    columns = "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx else "|")

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx else "")

    rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
    # Drop rows that contain nothing but separators and spaces.
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    # Each row gets a "##i$$" citation cell (same output whether or not quota is set).
    rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    # Trim the time-of-day part from ISO timestamps at the end of a cell.
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    # Count rows per document for the reference aggregates.
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {
            "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
            "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
        },
        "prompt": sys_prompt,
    }
def tts(tts_mdl, text):
    """Synthesize *text* with *tts_mdl* and return the audio as a hex string.

    Args:
        tts_mdl: Model whose ``tts(text)`` yields bytes-like audio chunks, or None.
        text (str): Text to synthesize.

    Returns:
        str | None: Lowercase hex encoding of the concatenated audio, or None when
        no model is configured or *text* is empty.
    """
    if not tts_mdl or not text:
        return None
    # Join the streamed chunks once instead of repeated ``+=`` (quadratic copying);
    # also avoids shadowing the ``bin`` builtin.
    audio = b"".join(tts_mdl.tts(text))
    return binascii.hexlify(audio).decode("utf-8")
def ask(question, kb_ids, tenant_id):
    """Answer a search question from the given knowledge bases, streaming the result.

    Args:
        question (str): The user's question or query.
        kb_ids (list): IDs of the knowledge bases to search.
        tenant_id (str): Tenant whose embedding and chat models are used.

    Flow:
        1. Load the knowledge bases and determine their embedding model.
        2. Pick the plain or knowledge-graph retriever (KG only if every KB is KG).
        3. Bind the embedding and chat models.
        4. Retrieve relevant chunks and format them as prompt context.
        5. Stream the LLM answer, then add citation markers to the final answer.

    Yields:
        dict: ``{"answer": <text so far>, "reference": {}}`` while streaming, then a
        final dict with citation markers in ``answer`` and the retrieved chunks and
        document aggregates in ``reference``. If the knowledge bases do not share
        exactly one embedding model, a single error dict is yielded instead.
    """
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    # The question is embedded with one model, so all KBs must share it. With no KBs
    # embedding_list[0] would raise IndexError. Same guard and message as chat().
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": {}}
        return

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
    # Embedding model used to vectorize the question.
    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    # Chat model used to generate the answer.
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    # Model context length, used to bound how much knowledge enters the prompt.
    max_tokens = chat_mdl.max_length
    # De-duplicated tenant IDs owning the knowledge bases.
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    # Lower similarity threshold than the original 0.1, chosen for better results.
    similarity_threshold = 0.01
    # Retrieve relevant chunks.
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, similarity_threshold, 0.3, aggs=False, rank_feature=label_question(question, kbs))
    # Format the retrieved chunks as prompt text within the model's token limit.
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
    角色：你是一个聪明的助手。
    任务：总结知识库中的信息并回答用户的问题。
    要求与限制：
      - 绝不要捏造内容，尤其是数字。
      - 如果知识库中的信息与用户问题无关，**只需回答：对不起，未提供相关信息。
      - 使用Markdown格式进行回答。
      - 使用用户提问所用的语言作答。
      - 绝不要捏造内容，尤其是数字。

    ### 来自知识库的信息
    %s

    以上是来自知识库的信息。

    """ % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    # After generation completes, add citation markers to the answer.
    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
        # Keep only the documents that were cited (all of them if none were).
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        # Strip embedding vectors from the returned references.
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        refs["chunks"] = chunks_format(refs)
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)