feat(write): add a dedicated Q&A route for document writing mode

- Added the /writechat route, used for Q&A calls in document writing mode
This commit is contained in:
zstar 2025-06-03 23:42:47 +08:00
parent d1ed2019e9
commit cfab4bc7bf
5 changed files with 326 additions and 343 deletions
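For reference, below is a minimal client-side sketch of how the new /writechat endpoint could be exercised. This is an assumption-laden sketch, not part of the commit: it assumes the route is mounted under /v1/conversation (as the frontend api config suggests), that a valid login token can be sent in the Authorization header, and that BASE_URL, TOKEN, and the knowledge-base ID are placeholders. The server streams the same `data: {...}` SSE events that the route's stream() generator yields.

```python
import json

import requests

# Hypothetical values -- replace with a real host, auth token, and kb ids.
BASE_URL = "http://localhost:9380"
TOKEN = "<login-session-token>"


def write_chat(question: str, kb_ids: list) -> str:
    """Call POST /v1/conversation/writechat and accumulate the streamed answer."""
    resp = requests.post(
        f"{BASE_URL}/v1/conversation/writechat",
        json={"question": question, "kb_ids": kb_ids},
        headers={"Authorization": TOKEN},
        stream=True,
    )
    resp.raise_for_status()
    answer = ""
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue  # skip blank separators between SSE events
        payload = json.loads(line[len("data:"):])
        data = payload.get("data")
        if isinstance(data, dict):  # incremental (cumulative) answer chunk
            answer = data.get("answer", answer)
        elif data is True:  # final "done" event
            break
    return answer


if __name__ == "__main__":
    print(write_chat("Summarize the deployment steps", ["<kb-id>"]))
```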


@ -36,7 +36,7 @@ from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.tag import label_question from rag.app.tag import label_question
@manager.route('/set', methods=['POST']) # noqa: F821 @manager.route("/set", methods=["POST"]) # noqa: F821
@login_required @login_required
def set_conversation(): def set_conversation():
req = request.json req = request.json
@ -50,8 +50,7 @@ def set_conversation():
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
e, conv = ConversationService.get_by_id(conv_id) e, conv = ConversationService.get_by_id(conv_id)
if not e: if not e:
return get_data_error_result(message="Fail to update a conversation!")
conv = conv.to_dict() conv = conv.to_dict()
return get_json_result(data=conv) return get_json_result(data=conv)
except Exception as e: except Exception as e:
@ -61,38 +60,30 @@ def set_conversation():
e, dia = DialogService.get_by_id(req["dialog_id"]) e, dia = DialogService.get_by_id(req["dialog_id"])
if not e: if not e:
return get_data_error_result(message="Dialog not found") return get_data_error_result(message="Dialog not found")
conv = {"id": conv_id, "dialog_id": req["dialog_id"], "name": req.get("name", "New conversation"), "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]}
ConversationService.save(**conv) ConversationService.save(**conv)
return get_json_result(data=conv) return get_json_result(data=conv)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/get', methods=['GET']) # noqa: F821 @manager.route("/get", methods=["GET"]) # noqa: F821
@login_required @login_required
def get(): def get():
conv_id = request.args["conversation_id"] conv_id = request.args["conversation_id"]
try: try:
e, conv = ConversationService.get_by_id(conv_id) e, conv = ConversationService.get_by_id(conv_id)
if not e: if not e:
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
avatar =None avatar = None
for tenant in tenants: for tenant in tenants:
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id) dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
if dialog and len(dialog)>0: if dialog and len(dialog) > 0:
avatar = dialog[0].icon avatar = dialog[0].icon
break break
else: else:
return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
def get_value(d, k1, k2): def get_value(d, k1, k2):
return d.get(k1, d.get(k2)) return d.get(k1, d.get(k2))
@ -100,26 +91,29 @@ def get():
for ref in conv.reference: for ref in conv.reference:
if isinstance(ref, list): if isinstance(ref, list):
continue continue
ref["chunks"] = [{ ref["chunks"] = [
"id": get_value(ck, "chunk_id", "id"), {
"content": get_value(ck, "content", "content_with_weight"), "id": get_value(ck, "chunk_id", "id"),
"document_id": get_value(ck, "doc_id", "document_id"), "content": get_value(ck, "content", "content_with_weight"),
"document_name": get_value(ck, "docnm_kwd", "document_name"), "document_id": get_value(ck, "doc_id", "document_id"),
"dataset_id": get_value(ck, "kb_id", "dataset_id"), "document_name": get_value(ck, "docnm_kwd", "document_name"),
"image_id": get_value(ck, "image_id", "img_id"), "dataset_id": get_value(ck, "kb_id", "dataset_id"),
"positions": get_value(ck, "positions", "position_int"), "image_id": get_value(ck, "image_id", "img_id"),
} for ck in ref.get("chunks", [])] "positions": get_value(ck, "positions", "position_int"),
}
for ck in ref.get("chunks", [])
]
conv = conv.to_dict() conv = conv.to_dict()
conv["avatar"]=avatar conv["avatar"] = avatar
return get_json_result(data=conv) return get_json_result(data=conv)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route("/getsse/<dialog_id>", methods=["GET"])  # type: ignore # noqa: F821
def getsse(dialog_id):
token = request.headers.get("Authorization").split()
if len(token) != 2: if len(token) != 2:
return get_data_error_result(message='Authorization is not valid!"') return get_data_error_result(message='Authorization is not valid!"')
token = token[1] token = token[1]
@ -131,13 +125,14 @@ def getsse(dialog_id):
if not e: if not e:
return get_data_error_result(message="Dialog not found!") return get_data_error_result(message="Dialog not found!")
conv = conv.to_dict() conv = conv.to_dict()
conv["avatar"]= conv["icon"] conv["avatar"] = conv["icon"]
del conv["icon"] del conv["icon"]
return get_json_result(data=conv) return get_json_result(data=conv)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/rm', methods=['POST']) # noqa: F821
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required @login_required
def rm(): def rm():
conv_ids = request.json["conversation_ids"] conv_ids = request.json["conversation_ids"]
@ -151,28 +146,21 @@ def rm():
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id): if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
break break
else: else:
return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
ConversationService.delete_by_id(cid) ConversationService.delete_by_id(cid)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/list', methods=['GET']) # noqa: F821 @manager.route("/list", methods=["GET"]) # noqa: F821
@login_required @login_required
def list_convsersation(): def list_convsersation():
dialog_id = request.args["dialog_id"] dialog_id = request.args["dialog_id"]
try: try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id): if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
convs = [d.to_dict() for d in convs] convs = [d.to_dict() for d in convs]
return get_json_result(data=convs) return get_json_result(data=convs)
@ -180,7 +168,7 @@ def list_convsersation():
return server_error_response(e) return server_error_response(e)
@manager.route('/completion', methods=['POST']) # noqa: F821 @manager.route("/completion", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("conversation_id", "messages") @validate_request("conversation_id", "messages")
def completion(): def completion():
@ -207,25 +195,30 @@ def completion():
if not conv.reference: if not conv.reference:
conv.reference = [] conv.reference = []
else: else:
def get_value(d, k1, k2): def get_value(d, k1, k2):
return d.get(k1, d.get(k2)) return d.get(k1, d.get(k2))
for ref in conv.reference: for ref in conv.reference:
if isinstance(ref, list): if isinstance(ref, list):
continue continue
ref["chunks"] = [{ ref["chunks"] = [
"id": get_value(ck, "chunk_id", "id"), {
"content": get_value(ck, "content", "content_with_weight"), "id": get_value(ck, "chunk_id", "id"),
"document_id": get_value(ck, "doc_id", "document_id"), "content": get_value(ck, "content", "content_with_weight"),
"document_name": get_value(ck, "docnm_kwd", "document_name"), "document_id": get_value(ck, "doc_id", "document_id"),
"dataset_id": get_value(ck, "kb_id", "dataset_id"), "document_name": get_value(ck, "docnm_kwd", "document_name"),
"image_id": get_value(ck, "image_id", "img_id"), "dataset_id": get_value(ck, "kb_id", "dataset_id"),
"positions": get_value(ck, "positions", "position_int"), "image_id": get_value(ck, "image_id", "img_id"),
} for ck in ref.get("chunks", [])] "positions": get_value(ck, "positions", "position_int"),
}
for ck in ref.get("chunks", [])
]
if not conv.reference: if not conv.reference:
conv.reference = [] conv.reference = []
conv.reference.append({"chunks": [], "doc_aggs": []}) conv.reference.append({"chunks": [], "doc_aggs": []})
def stream(): def stream():
nonlocal dia, msg, req, conv nonlocal dia, msg, req, conv
try: try:
@ -235,9 +228,7 @@ def completion():
ConversationService.update_by_id(conv.id, conv.to_dict()) ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield "data:" + json.dumps({"code": 500, "message": str(e), yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
if req.get("stream", True): if req.get("stream", True):
@ -259,7 +250,32 @@ def completion():
return server_error_response(e) return server_error_response(e)
# Q&A endpoint for the document writing mode
@manager.route("/writechat", methods=["POST"])  # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def writechat():
req = request.json
uid = current_user.id
def stream():
nonlocal req, uid
try:
for ans in ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
@manager.route("/tts", methods=["POST"]) # noqa: F821
@login_required @login_required
def tts(): def tts():
req = request.json req = request.json
@ -281,9 +297,7 @@ def tts():
for chunk in tts_mdl.tts(txt): for chunk in tts_mdl.tts(txt):
yield chunk yield chunk
except Exception as e: except Exception as e:
yield ("data:" + json.dumps({"code": 500, "message": str(e), yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
"data": {"answer": "**ERROR**: " + str(e)}},
ensure_ascii=False)).encode('utf-8')
resp = Response(stream_audio(), mimetype="audio/mpeg") resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache") resp.headers.add_header("Cache-Control", "no-cache")
@ -293,7 +307,7 @@ def tts():
return resp return resp
@manager.route('/delete_msg', methods=['POST']) # noqa: F821 @manager.route("/delete_msg", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("conversation_id", "message_id") @validate_request("conversation_id", "message_id")
def delete_msg(): def delete_msg():
@ -316,7 +330,7 @@ def delete_msg():
return get_json_result(data=conv) return get_json_result(data=conv)
@manager.route('/thumbup', methods=['POST']) # noqa: F821 @manager.route("/thumbup", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("conversation_id", "message_id") @validate_request("conversation_id", "message_id")
def thumbup(): def thumbup():
@ -343,7 +357,7 @@ def thumbup():
return get_json_result(data=conv) return get_json_result(data=conv)
@manager.route('/ask', methods=['POST']) # noqa: F821 @manager.route("/ask", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("question", "kb_ids") @validate_request("question", "kb_ids")
def ask_about(): def ask_about():
@ -356,9 +370,7 @@ def ask_about():
for ans in ask(req["question"], req["kb_ids"], uid): for ans in ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e: except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream") resp = Response(stream(), mimetype="text/event-stream")
@ -369,7 +381,7 @@ def ask_about():
return resp return resp
@manager.route('/mindmap', methods=['POST']) # noqa: F821 @manager.route("/mindmap", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("question", "kb_ids") @validate_request("question", "kb_ids")
def mindmap(): def mindmap():
@ -382,10 +394,7 @@ def mindmap():
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT) chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
question = req["question"] question = req["question"]
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, 0.3, 0.3, aggs=False, rank_feature=label_question(question, [kb]))
mindmap = MindMapExtractor(chat_mdl) mindmap = MindMapExtractor(chat_mdl)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]]) mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
mind_map = mind_map.output mind_map = mind_map.output
@ -394,7 +403,7 @@ def mindmap():
return get_json_result(data=mind_map) return get_json_result(data=mind_map)
@manager.route('/related_questions', methods=['POST']) # noqa: F821 @manager.route("/related_questions", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("question") @validate_request("question")
def related_questions(): def related_questions():
@ -425,8 +434,17 @@ Reason:
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results. - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
""" """
ans = chat_mdl.chat(
    prompt,
    [
        {
            "role": "user",
            "content": f"""
Keywords: {question}
Related search terms:
""",
        }
    ],
    {"temperature": 0.9},
)
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]) return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])


@ -30,8 +30,7 @@ from api import settings
from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, citation_prompt
from rag.utils import rmSpace, num_tokens_from_string from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.tavily_conn import Tavily from rag.utils.tavily_conn import Tavily
@ -41,17 +40,13 @@ class DialogService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
chats = cls.model.select() chats = cls.model.select()
if id: if id:
chats = chats.where(cls.model.id == id) chats = chats.where(cls.model.id == id)
if name: if name:
chats = chats.where(cls.model.name == name) chats = chats.where(cls.model.name == name)
chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
if desc: if desc:
chats = chats.order_by(cls.model.getter_by(orderby).desc()) chats = chats.order_by(cls.model.getter_by(orderby).desc())
else: else:
@ -72,13 +67,12 @@ def chat_solo(dialog, messages, stream=True):
tts_mdl = None tts_mdl = None
if prompt_config.get("tts"): if prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS) tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
for m in messages if m["role"] != "system"]
if stream: if stream:
last_ans = "" last_ans = ""
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting): for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ans answer = ans
delta_ans = ans[len(last_ans):] delta_ans = ans[len(last_ans) :]
if num_tokens_from_string(delta_ans) < 16: if num_tokens_from_string(delta_ans) < 16:
continue continue
last_ans = answer last_ans = answer
@ -159,9 +153,8 @@ def chat(dialog, messages, stream=True, **kwargs):
if p["key"] not in kwargs and not p["optional"]: if p["key"] not in kwargs and not p["optional"]:
raise KeyError("Miss parameter: " + p["key"]) raise KeyError("Miss parameter: " + p["key"])
if p["key"] not in kwargs: if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace( prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
"{%s}" % p["key"], " ")
# 不再使用多轮对话优化 # 不再使用多轮对话优化
# if len(questions) > 1 and prompt_config.get("refine_multiturn"): # if len(questions) > 1 and prompt_config.get("refine_multiturn"):
# questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] # questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
@ -190,7 +183,7 @@ def chat(dialog, messages, stream=True, **kwargs):
tenant_ids = list(set([kb.tenant_id for kb in kbs])) tenant_ids = list(set([kb.tenant_id for kb in kbs]))
knowledges = [] knowledges = []
# Deep-research reasoning is no longer used
# if prompt_config.get("reasoning", False): # if prompt_config.get("reasoning", False):
# reasoner = DeepResearcher(chat_mdl, # reasoner = DeepResearcher(chat_mdl,
@ -226,17 +219,24 @@ def chat(dialog, messages, stream=True, **kwargs):
# kbinfos["chunks"].insert(0, ck) # kbinfos["chunks"].insert(0, ck)
# knowledges = kb_prompt(kbinfos, max_tokens) # knowledges = kb_prompt(kbinfos, max_tokens)
kbinfos = retriever.retrieval(
    " ".join(questions),
    embd_mdl,
    tenant_ids,
    dialog.kb_ids,
    1,
    dialog.top_n,
    dialog.similarity_threshold,
    dialog.vector_similarity_weight,
    doc_ids=attachments,
    top=dialog.top_k,
    aggs=False,
    rerank_mdl=rerank_mdl,
    rank_feature=label_question(" ".join(questions), kbs),
)
knowledges = kb_prompt(kbinfos, max_tokens)
logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
retrieval_ts = timer() retrieval_ts = timer()
if not knowledges and prompt_config.get("empty_response"): if not knowledges and prompt_config.get("empty_response"):
@ -252,22 +252,19 @@ def chat(dialog, messages, stream=True, **kwargs):
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
prompt4citation = citation_prompt() prompt4citation = citation_prompt()
# Filter out messages with the system role (the system message was already handled separately above)
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95)) used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
assert len(msg) >= 2, f"message_fit_in has bug: {msg}" assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
prompt = msg[0]["content"] prompt = msg[0]["content"]
if "max_tokens" in gen_conf: if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min( gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
gen_conf["max_tokens"],
max_tokens - used_token_count)
def decorate_answer(answer): def decorate_answer(answer):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
refs = [] refs = []
image_markdowns = []  # Holds Markdown strings for cited images
ans = answer.split("</think>") ans = answer.split("</think>")
think = "" think = ""
if len(ans) == 2: if len(ans) == 2:
@ -275,29 +272,29 @@ def chat(dialog, messages, stream=True, **kwargs):
answer = ans[1] answer = ans[1]
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL) answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
cited_chunk_indices = set()  # Indices of the chunks that were cited
if not re.search(r"##[0-9]+\$\$", answer): if not re.search(r"##[0-9]+\$\$", answer):
answer, idx = retriever.insert_citations(
    answer,
    [ck["content_ltks"] for ck in kbinfos["chunks"]],
    [ck["vector"] for ck in kbinfos["chunks"]],
    embd_mdl,
    tkweight=1 - dialog.vector_similarity_weight,
    vtweight=dialog.vector_similarity_weight,
)
cited_chunk_indices = idx  # Indices returned by insert_citations
else: else:
idx = set([]) idx = set([])
for r in re.finditer(r"##([0-9]+)\$\$", answer): for r in re.finditer(r"##([0-9]+)\$\$", answer):
i = int(r.group(1)) i = int(r.group(1))
if i < len(kbinfos["chunks"]): if i < len(kbinfos["chunks"]):
idx.add(i) idx.add(i)
cited_chunk_indices = idx  # Indices extracted from the ##...$$ citation markers
# Extract image info for the cited chunk indices and build Markdown for them
cited_doc_ids = set()
processed_image_urls = set()  # Avoid adding the same image more than once
print(f"DEBUG: cited_chunk_indices = {cited_chunk_indices}")
for i in cited_chunk_indices: for i in cited_chunk_indices:
i_int = int(i) i_int = int(i)
@ -312,11 +309,10 @@ def chat(dialog, messages, stream=True, **kwargs):
# Build the Markdown string; the alt text can simply be "image" or the chunk ID
alt_text = f"image_chunk_{chunk.get('chunk_id', i_int)}"
image_markdowns.append(f"\n![{alt_text}]({img_url})")
processed_image_urls.add(img_url)  # Mark as processed
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [ recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs: if not recall_docs:
recall_docs = kbinfos["doc_aggs"] recall_docs = kbinfos["doc_aggs"]
kbinfos["doc_aggs"] = recall_docs kbinfos["doc_aggs"] = recall_docs
@ -325,7 +321,7 @@ def chat(dialog, messages, stream=True, **kwargs):
for c in refs["chunks"]: for c in refs["chunks"]:
if c.get("vector"): if c.get("vector"):
del c["vector"] del c["vector"]
# Append the image Markdown strings to the end of the answer
if image_markdowns: if image_markdowns:
answer += "".join(image_markdowns) answer += "".join(image_markdowns)
@ -347,30 +343,30 @@ def chat(dialog, messages, stream=True, **kwargs):
prompt += "\n\n### Query:\n%s" % " ".join(questions) prompt += "\n\n### Query:\n%s" % " ".join(questions)
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms" prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
return {"answer": think+answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()} return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
if stream: if stream:
last_ans = "" # 记录上一次返回的完整回答 last_ans = "" # 记录上一次返回的完整回答
answer = "" # 当前累计的完整回答 answer = "" # 当前累计的完整回答
for ans in chat_mdl.chat_streamly(prompt+prompt4citation, msg[1:], gen_conf): for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
# 如果存在思考过程(thought),移除相关标记 # 如果存在思考过程(thought),移除相关标记
if thought: if thought:
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL) ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
answer = ans answer = ans
# 计算新增的文本片段(delta) # 计算新增的文本片段(delta)
delta_ans = ans[len(last_ans):] delta_ans = ans[len(last_ans) :]
# 如果新增token太少(小于16),跳过本次返回(避免频繁发送小片段) # 如果新增token太少(小于16),跳过本次返回(避免频繁发送小片段)
if num_tokens_from_string(delta_ans) < 16: if num_tokens_from_string(delta_ans) < 16:
continue continue
last_ans = answer last_ans = answer
# 返回当前累计回答(包含思考过程)+新增片段) # 返回当前累计回答(包含思考过程)+新增片段)
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
delta_ans = answer[len(last_ans):] delta_ans = answer[len(last_ans) :]
if delta_ans: if delta_ans:
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(thought+answer) yield decorate_answer(thought + answer)
else: else:
answer = chat_mdl.chat(prompt+prompt4citation, msg[1:], gen_conf) answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]") user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer)) logging.debug("User: {}|Assistant: {}".format(user_content, answer))
res = decorate_answer(answer) res = decorate_answer(answer)
@ -388,27 +384,22 @@ Table of database fields are as follows:
Question are as follows: Question are as follows:
{} {}
Please write the SQL, only SQL, without any other explanations or text. Please write the SQL, only SQL, without any other explanations or text.
""".format( """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
index_name(tenant_id),
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
question
)
tried_times = 0 tried_times = 0
def get_table(): def get_table():
nonlocal sys_prompt, user_prompt, question, tried_times nonlocal sys_prompt, user_prompt, question, tried_times
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL) sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}") logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower()) sql = re.sub(r"[\r\n]+", " ", sql.lower())
sql = re.sub(r".*select ", "select ", sql.lower()) sql = re.sub(r".*select ", "select ", sql.lower())
sql = re.sub(r" +", " ", sql) sql = re.sub(r" +", " ", sql)
sql = re.sub(r"([;]|```).*", "", sql) sql = re.sub(r"([;]|```).*", "", sql)
if sql[:len("select ")] != "select ": if sql[: len("select ")] != "select ":
return None, None return None, None
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()): if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
if sql[:len("select *")] != "select *": if sql[: len("select *")] != "select *":
sql = "select doc_id,docnm_kwd," + sql[6:] sql = "select doc_id,docnm_kwd," + sql[6:]
else: else:
flds = [] flds = []
@ -445,11 +436,7 @@ Please write the SQL, only SQL, without any other explanations or text.
{} {}
Please correct the error and write SQL again, only SQL, without any other explanations or text. Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format( """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
index_name(tenant_id),
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
question, sql, tbl["error"]
)
tbl, sql = get_table() tbl, sql = get_table()
logging.debug("TRY it again: {}".format(sql)) logging.debug("TRY it again: {}".format(sql))
@ -457,24 +444,18 @@ Please write the SQL, only SQL, without any other explanations or text.
if tbl.get("error") or len(tbl["rows"]) == 0: if tbl.get("error") or len(tbl["rows"]) == 0:
return None return None
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
# compose Markdown table
columns = (
    "|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
)
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
rows = [r for r in rows if re.sub(r"[ |]+", "", r)] rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
if quota: if quota:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
@ -484,11 +465,7 @@ Please write the SQL, only SQL, without any other explanations or text.
if not docid_idx or not doc_name_idx: if not docid_idx or not doc_name_idx:
logging.warning("SQL missing field: " + sql) logging.warning("SQL missing field: " + sql)
return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
docid_idx = list(docid_idx)[0] docid_idx = list(docid_idx)[0]
doc_name_idx = list(doc_name_idx)[0] doc_name_idx = list(doc_name_idx)[0]
@ -499,10 +476,11 @@ Please write the SQL, only SQL, without any other explanations or text.
doc_aggs[r[docid_idx]]["count"] += 1 doc_aggs[r[docid_idx]]["count"] += 1
return {
    "answer": "\n".join([columns, line, rows]),
    "reference": {
        "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
        "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
    },
    "prompt": sys_prompt,
}
@ -518,12 +496,12 @@ def tts(tts_mdl, text):
def ask(question, kb_ids, tenant_id):
"""
Handle a user search request: retrieve relevant information from the knowledge bases and generate an answer.
Args:
    question (str): The user's question or query.
    kb_ids (list): A list of knowledge base IDs specifying which knowledge bases to search.
    tenant_id (str): The tenant ID, used for access control and resource isolation.
Flow:
    1. Fetch information about the specified knowledge bases.
    2. Determine the embedding model to use.
@ -534,11 +512,11 @@ def ask(question, kb_ids, tenant_id):
    7. Build the system prompt.
    8. Generate the answer and add citation markers.
    9. Stream the generated answer back.
Returns:
    generator: A generator that yields dicts containing the answer and reference information.
"""
kbs = KnowledgebaseService.get_by_ids(kb_ids) kbs = KnowledgebaseService.get_by_ids(kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs])) embedding_list = list(set([kb.embd_id for kb in kbs]))
@ -552,27 +530,24 @@ def ask(question, kb_ids, tenant_id):
max_tokens = chat_mdl.max_length max_tokens = chat_mdl.max_length
# Get the tenant IDs of all knowledge bases and deduplicate them
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
# Call the retriever to fetch relevant document chunks
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
# Format the retrieval results into the prompt, making sure the model's maximum token limit is not exceeded
knowledges = kb_prompt(kbinfos, max_tokens)
prompt = """ prompt = """
Role: You're a smart assistant. Your name is Miss R. 角色你是一个聪明的助手
Task: Summarize the information from knowledge bases and answer user's question. 任务总结知识库中的信息并回答用户的问题
Requirements and restriction: 要求与限制
- DO NOT make things up, especially for numbers. - 绝不要捏造内容尤其是数字
- If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided. - 如果知识库中的信息与用户问题无关**只需回答对不起未提供相关信息
- Answer with markdown format text. - 使用Markdown格式进行回答
- Answer in language of user's question. - 使用用户提问所用的语言作答
- DO NOT make things up, especially for numbers. - 绝不要捏造内容尤其是数字
### Information from knowledge bases ### 来自知识库的信息
%s %s
The above is information from knowledge bases. 以上是来自知识库的信息
""" % "\n".join(knowledges) """ % "\n".join(knowledges)
msg = [{"role": "user", "content": question}] msg = [{"role": "user", "content": question}]
@ -580,17 +555,9 @@ def ask(question, kb_ids, tenant_id):
# After generation completes, add citation markers to the answer
def decorate_answer(answer): def decorate_answer(answer):
nonlocal knowledges, kbinfos, prompt nonlocal knowledges, kbinfos, prompt
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs: if not recall_docs:
recall_docs = kbinfos["doc_aggs"] recall_docs = kbinfos["doc_aggs"]
kbinfos["doc_aggs"] = recall_docs kbinfos["doc_aggs"] = recall_docs
@ -608,4 +575,4 @@ def ask(question, kb_ids, tenant_id):
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}): for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
answer = ans answer = ans
yield {"answer": answer, "reference": {}} yield {"answer": answer, "reference": {}}
yield decorate_answer(answer) yield decorate_answer(answer)


@ -0,0 +1,123 @@
import { Authorization } from '@/constants/authorization';
import { IAnswer } from '@/interfaces/database/chat';
import { IKnowledge } from '@/interfaces/database/knowledge';
import kbService from '@/services/knowledge-service';
import api from '@/utils/api';
import { getAuthorization } from '@/utils/authorization-util';
import { useQuery } from '@tanstack/react-query';
import { EventSourceParserStream } from 'eventsource-parser/stream';
import { useCallback, useRef, useState } from 'react';
// Fetch the knowledge base list
export const useFetchKnowledgeList = (
shouldFilterListWithoutDocument: boolean = false,
): {
list: IKnowledge[];
loading: boolean;
} => {
const { data, isFetching: loading } = useQuery({
queryKey: ['fetchKnowledgeList'],
initialData: [],
gcTime: 0,
queryFn: async () => {
const { data } = await kbService.getList();
const list = data?.data?.kbs ?? [];
return shouldFilterListWithoutDocument
? list.filter((x: IKnowledge) => x.chunk_num > 0)
: list;
},
});
return { list: data, loading };
};
// Send a Q&A message and consume the SSE response
export const useSendMessageWithSse = (url: string = api.writeChat) => {
const [answer, setAnswer] = useState<IAnswer>({} as IAnswer);
const [done, setDone] = useState(true);
const timer = useRef<any>();
const sseRef = useRef<AbortController>();
const initializeSseRef = useCallback(() => {
sseRef.current = new AbortController();
}, []);
const resetAnswer = useCallback(() => {
if (timer.current) {
clearTimeout(timer.current);
}
timer.current = setTimeout(() => {
setAnswer({} as IAnswer);
clearTimeout(timer.current);
}, 1000);
}, []);
const send = useCallback(
async (
body: any,
controller?: AbortController,
): Promise<{ response: Response; data: ResponseType } | undefined> => {
initializeSseRef();
try {
setDone(false);
const response = await fetch(url, {
method: 'POST',
headers: {
[Authorization]: getAuthorization(),
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
signal: controller?.signal || sseRef.current?.signal,
});
const res = response.clone().json();
const reader = response?.body
?.pipeThrough(new TextDecoderStream())
.pipeThrough(new EventSourceParserStream())
.getReader();
while (true) {
const x = await reader?.read();
if (x) {
const { done, value } = x;
if (done) {
console.info('done');
resetAnswer();
break;
}
try {
const val = JSON.parse(value?.data || '');
const d = val?.data;
if (typeof d !== 'boolean') {
console.info('data:', d);
setAnswer({
...d,
conversationId: body?.conversation_id,
});
}
} catch (e) {
console.warn(e);
}
}
}
console.info('done?');
setDone(true);
resetAnswer();
return { data: await res, response };
} catch (e) {
setDone(true);
resetAnswer();
console.warn(e);
}
},
[initializeSseRef, url, resetAnswer],
);
const stopOutputMessage = useCallback(() => {
sseRef.current?.abort();
}, []);
return { send, answer, done, setDone, resetAnswer, stopOutputMessage };
};


@ -1,5 +1,6 @@
import HightLightMarkdown from '@/components/highlight-markdown'; import HightLightMarkdown from '@/components/highlight-markdown';
import { useTranslate } from '@/hooks/common-hooks'; import { useTranslate } from '@/hooks/common-hooks';
import { useFetchKnowledgeList } from '@/hooks/write-hooks';
import { DeleteOutlined } from '@ant-design/icons'; import { DeleteOutlined } from '@ant-design/icons';
import { import {
Button, Button,
@ -18,7 +19,6 @@ import {
Space, Space,
Typography, Typography,
} from 'antd'; } from 'antd';
import axios from 'axios';
import { import {
AlignmentType, AlignmentType,
Document, Document,
@ -84,6 +84,10 @@ const Write = () => {
const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBaseItem[]>([]); const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBaseItem[]>([]);
const [isLoadingKbs, setIsLoadingKbs] = useState(false); const [isLoadingKbs, setIsLoadingKbs] = useState(false);
// Use the useFetchKnowledgeList hook to fetch real data
const { list: knowledgeList, loading: isLoadingKnowledgeList } =
useFetchKnowledgeList(true);
const getInitialDefaultTemplateDefinitions = useCallback( const getInitialDefaultTemplateDefinitions = useCallback(
(): TemplateItem[] => [ (): TemplateItem[] => [
{ {
@ -167,55 +171,18 @@ const Write = () => {
loadOrInitializeTemplates(); loadOrInitializeTemplates();
}, [loadOrInitializeTemplates]); }, [loadOrInitializeTemplates]);
// Sync the knowledgeList data into the knowledgeBases state
useEffect(() => {
  if (knowledgeList && knowledgeList.length > 0) {
    setKnowledgeBases(
      knowledgeList.map((kb) => ({
        id: kb.id,
        name: kb.name,
      })),
    );
    setIsLoadingKbs(isLoadingKnowledgeList);
  }
}, [knowledgeList, isLoadingKnowledgeList]);
useEffect(() => {
  const fetchKbs = async () => {
    const authorization = localStorage.getItem('Authorization');
    if (!authorization) {
      setKnowledgeBases([]);
      return;
    }
    setIsLoadingKbs(true);
    try {
      await new Promise((resolve) => {
        setTimeout(resolve, 500);
});
const mockKbs: KnowledgeBaseItem[] = [
{
id: 'kb_default',
name: t('defaultKnowledgeBase', { defaultValue: '默认知识库' }),
},
{
id: 'kb_tech',
name: t('technicalDocsKnowledgeBase', {
defaultValue: '技术文档知识库',
}),
},
{
id: 'kb_product',
name: t('productInfoKnowledgeBase', {
defaultValue: '产品信息知识库',
}),
},
{
id: 'kb_marketing',
name: t('marketingMaterialsKB', { defaultValue: '市场营销材料库' }),
},
{
id: 'kb_legal',
name: t('legalDocumentsKB', { defaultValue: '法律文件库' }),
},
];
setKnowledgeBases(mockKbs);
} catch (error) {
console.error('获取知识库失败:', error);
message.error(t('fetchKnowledgeBaseFailed'));
setKnowledgeBases([]);
} finally {
setIsLoadingKbs(false);
}
};
fetchKbs();
}, [t]);
useEffect(() => { useEffect(() => {
const loadDraftContent = () => { const loadDraftContent = () => {
@ -284,6 +251,7 @@ const Write = () => {
} }
}; };
// Delete a template
const handleDeleteTemplate = (templateId: string) => { const handleDeleteTemplate = (templateId: string) => {
try { try {
const updatedTemplates = templates.filter((t) => t.id !== templateId); const updatedTemplates = templates.filter((t) => t.id !== templateId);
@ -339,102 +307,6 @@ const Write = () => {
setIsAiLoading(false); setIsAiLoading(false);
return; return;
} }
const conversationId =
Math.random().toString(36).substring(2) + Date.now().toString(36);
await axios.post(
'v1/conversation/set',
{
name: '文档撰写对话',
is_new: true,
conversation_id: conversationId,
message: [{ role: 'assistant', content: '新对话' }],
},
{ headers: { authorization }, signal: controller.signal },
);
const combinedQuestion = `${aiQuestion}\n\n${t('currentDocumentContextLabel')}:\n${originalContent}`;
let lastReceivedContent = '';
const response = await axios.post(
'/v1/conversation/completion',
{
conversation_id: conversationId,
messages: [{ role: 'user', content: combinedQuestion }],
knowledge_base_ids:
selectedKnowledgeBases.length > 0
? selectedKnowledgeBases
: undefined,
similarity_threshold: similarityThreshold,
keyword_similarity_weight: keywordSimilarityWeight,
temperature: modelTemperature,
},
{
timeout: aiAssistantConfig.api.timeout,
headers: { authorization },
signal: controller.signal,
},
);
if (response.data) {
const lines = response.data
.split('\n')
.filter((line: string) => line.trim());
for (let i = 0; i < lines.length; i++) {
try {
const jsonStr = lines[i].replace('data:', '').trim();
const jsonData = JSON.parse(jsonStr);
if (jsonData.code === 0 && jsonData.data?.answer) {
const answerChunk = jsonData.data.answer;
const cleanedAnswerChunk = answerChunk
.replace(/<think>[\s\S]*?<\/think>/g, '')
.trim();
const hasUnclosedThink =
cleanedAnswerChunk.includes('<think>') &&
(!cleanedAnswerChunk.includes('</think>') ||
cleanedAnswerChunk.indexOf('<think>') >
cleanedAnswerChunk.lastIndexOf('</think>'));
if (cleanedAnswerChunk && !hasUnclosedThink) {
const incrementalContent = cleanedAnswerChunk.substring(
lastReceivedContent.length,
);
if (incrementalContent) {
lastReceivedContent = cleanedAnswerChunk;
let newFullContent,
newCursorPosAfterInsertion = cursorPosition;
if (initialCursorPos !== null && showCursorIndicator) {
newFullContent =
beforeCursor + cleanedAnswerChunk + afterCursor;
newCursorPosAfterInsertion =
initialCursorPos + cleanedAnswerChunk.length;
} else {
newFullContent = originalContent + cleanedAnswerChunk;
newCursorPosAfterInsertion = newFullContent.length;
}
setContent(newFullContent);
setCursorPosition(newCursorPosAfterInsertion);
setTimeout(() => {
if (textAreaRef.current) {
textAreaRef.current.focus();
textAreaRef.current.setSelectionRange(
newCursorPosAfterInsertion!,
newCursorPosAfterInsertion!,
);
}
}, 0);
}
}
}
} catch (parseErr) {
console.error('解析单行数据失败:', parseErr);
}
if (i < lines.length - 1)
await new Promise((resolve) => {
setTimeout(resolve, 10);
});
}
}
await axios.post(
'/v1/conversation/rm',
{ conversation_ids: [conversationId], dialog_id: dialogId },
{ headers: { authorization } },
);
} catch (error: any) { } catch (error: any) {
console.error('AI助手处理失败:', error); console.error('AI助手处理失败:', error);
if (error.code === 'ECONNABORTED' || error.name === 'AbortError') { if (error.code === 'ECONNABORTED' || error.name === 'AbortError') {
@ -455,6 +327,7 @@ const Write = () => {
} }
}; };
// Export to Word
const handleSave = () => { const handleSave = () => {
const selectedTemplateItem = templates.find( const selectedTemplateItem = templates.find(
(item) => item.id === selectedTemplate, (item) => item.id === selectedTemplate,


@ -100,6 +100,8 @@ export default {
getExternalConversation: `${api_host}/api/conversation`, getExternalConversation: `${api_host}/api/conversation`,
completeExternalConversation: `${api_host}/api/completion`, completeExternalConversation: `${api_host}/api/completion`,
uploadAndParseExternal: `${api_host}/api/document/upload_and_parse`, uploadAndParseExternal: `${api_host}/api/document/upload_and_parse`,
// Q&A API used by the document writing mode
writeChat: `${api_host}/conversation/writechat`,
// file manager // file manager
listFile: `${api_host}/file/list`, listFile: `${api_host}/file/list`,