Merge pull request #144 from zstar1003/dev

feat(write): refactor the document-writing interface, decouple the Q&A endpoint, and support streaming output
zstar committed on 2025-06-04 18:57:56 +08:00 (via GitHub)
Commit e4a4786ca3
8 changed files with 838 additions and 414 deletions

File: api/apps/conversation_app.py

@@ -30,13 +30,14 @@ from api.db.services.dialog_service import DialogService, chat, ask
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService
from api import settings
from api.db.services.write_service import write_dialog
from api.utils.api_utils import get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.tag import label_question


@manager.route("/set", methods=["POST"])  # type: ignore # noqa: F821
@login_required
def set_conversation():
    req = request.json
@@ -50,8 +51,7 @@ def set_conversation():
            return get_data_error_result(message="Conversation not found!")
        e, conv = ConversationService.get_by_id(conv_id)
        if not e:
            return get_data_error_result(message="Fail to update a conversation!")
        conv = conv.to_dict()
        return get_json_result(data=conv)
    except Exception as e:
@@ -61,38 +61,30 @@ def set_conversation():
        e, dia = DialogService.get_by_id(req["dialog_id"])
        if not e:
            return get_data_error_result(message="Dialog not found")
        conv = {"id": conv_id, "dialog_id": req["dialog_id"], "name": req.get("name", "New conversation"), "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]}
        ConversationService.save(**conv)
        return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)


@manager.route("/get", methods=["GET"])  # type: ignore # noqa: F821
@login_required
def get():
    conv_id = request.args["conversation_id"]
    try:
        e, conv = ConversationService.get_by_id(conv_id)
        if not e:
            return get_data_error_result(message="Conversation not found!")
        tenants = UserTenantService.query(user_id=current_user.id)
        avatar = None
        for tenant in tenants:
            dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
            if dialog and len(dialog) > 0:
                avatar = dialog[0].icon
                break
        else:
            return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)

        def get_value(d, k1, k2):
            return d.get(k1, d.get(k2))

@@ -100,26 +92,29 @@ def get():
        for ref in conv.reference:
            if isinstance(ref, list):
                continue
            ref["chunks"] = [
                {
                    "id": get_value(ck, "chunk_id", "id"),
                    "content": get_value(ck, "content", "content_with_weight"),
                    "document_id": get_value(ck, "doc_id", "document_id"),
                    "document_name": get_value(ck, "docnm_kwd", "document_name"),
                    "dataset_id": get_value(ck, "kb_id", "dataset_id"),
                    "image_id": get_value(ck, "image_id", "img_id"),
                    "positions": get_value(ck, "positions", "position_int"),
                }
                for ck in ref.get("chunks", [])
            ]

        conv = conv.to_dict()
        conv["avatar"] = avatar
        return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)


@manager.route("/getsse/<dialog_id>", methods=["GET"])  # type: ignore # noqa: F821
def getsse(dialog_id):
    token = request.headers.get("Authorization").split()
    if len(token) != 2:
        return get_data_error_result(message='Authorization is not valid!"')
    token = token[1]
@@ -131,13 +126,14 @@ def getsse(dialog_id):
        if not e:
            return get_data_error_result(message="Dialog not found!")
        conv = conv.to_dict()
        conv["avatar"] = conv["icon"]
        del conv["icon"]
        return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)


@manager.route("/rm", methods=["POST"])  # type: ignore # noqa: F821
@login_required
def rm():
    conv_ids = request.json["conversation_ids"]
@@ -151,28 +147,21 @@ def rm():
                if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
                    break
            else:
                return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
            ConversationService.delete_by_id(cid)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)


@manager.route("/list", methods=["GET"])  # type: ignore # noqa: F821
@login_required
def list_convsersation():
    dialog_id = request.args["dialog_id"]
    try:
        if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
            return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
        convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
        convs = [d.to_dict() for d in convs]
        return get_json_result(data=convs)
@@ -180,7 +169,7 @@ def list_convsersation():
        return server_error_response(e)


@manager.route("/completion", methods=["POST"])  # type: ignore # noqa: F821
@login_required
@validate_request("conversation_id", "messages")
def completion():
@@ -207,25 +196,30 @@ def completion():
        if not conv.reference:
            conv.reference = []
        else:

            def get_value(d, k1, k2):
                return d.get(k1, d.get(k2))

            for ref in conv.reference:
                if isinstance(ref, list):
                    continue
                ref["chunks"] = [
                    {
                        "id": get_value(ck, "chunk_id", "id"),
                        "content": get_value(ck, "content", "content_with_weight"),
                        "document_id": get_value(ck, "doc_id", "document_id"),
                        "document_name": get_value(ck, "docnm_kwd", "document_name"),
                        "dataset_id": get_value(ck, "kb_id", "dataset_id"),
                        "image_id": get_value(ck, "image_id", "img_id"),
                        "positions": get_value(ck, "positions", "position_int"),
                    }
                    for ck in ref.get("chunks", [])
                ]

        if not conv.reference:
            conv.reference = []
        conv.reference.append({"chunks": [], "doc_aggs": []})

    def stream():
        nonlocal dia, msg, req, conv
        try:
@@ -235,9 +229,7 @@ def completion():
            ConversationService.update_by_id(conv.id, conv.to_dict())
        except Exception as e:
            traceback.print_exc()
            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

    if req.get("stream", True):
@@ -259,7 +251,32 @@ def completion():
        return server_error_response(e)


# Q&A endpoint used by the document-writing mode
@manager.route("/writechat", methods=["POST"])  # type: ignore # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def writechat():
    req = request.json
    uid = current_user.id

    def stream():
        nonlocal req, uid
        try:
            for ans in write_dialog(req["question"], req["kb_ids"], uid):
                yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
        except Exception as e:
            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

    resp = Response(stream(), mimetype="text/event-stream")
    resp.headers.add_header("Cache-control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
    return resp
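
For reference, a minimal sketch of consuming this stream from a script. The base URL, the /v1/conversation mount point (as in upstream RAGFlow), and the token format are assumptions of this sketch, not part of the diff:

import json

import requests

# Assumed deployment details: host/port, mount point, and a valid session
# token -- adjust all three for your setup.
URL = "http://localhost:9380/v1/conversation/writechat"
HEADERS = {"Authorization": "<session-token>"}  # token format depends on your auth setup
BODY = {"question": "Draft an introduction for this report.", "kb_ids": ["<kb-id>"]}

with requests.post(URL, headers=HEADERS, json=BODY, stream=True) as resp:
    for raw in resp.iter_lines(decode_unicode=True):
        # Each SSE frame is a "data:<json>" line; blank lines separate frames.
        if not raw or not raw.startswith("data:"):
            continue
        payload = json.loads(raw[len("data:"):])
        if payload["code"] != 0:
            raise RuntimeError(payload["message"])
        data = payload["data"]
        if data is True:  # final sentinel frame emitted by stream()
            break
        # "answer" holds the cumulative text so far, not a delta.
        print(data["answer"])

The X-Accel-Buffering: no header set above is what keeps nginx-style reverse proxies from buffering these frames.
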
@manager.route("/tts", methods=["POST"]) # type: ignore # noqa: F821
@login_required @login_required
def tts(): def tts():
req = request.json req = request.json
@ -281,9 +298,7 @@ def tts():
for chunk in tts_mdl.tts(txt): for chunk in tts_mdl.tts(txt):
yield chunk yield chunk
except Exception as e: except Exception as e:
yield ("data:" + json.dumps({"code": 500, "message": str(e), yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
"data": {"answer": "**ERROR**: " + str(e)}},
ensure_ascii=False)).encode('utf-8')
resp = Response(stream_audio(), mimetype="audio/mpeg") resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache") resp.headers.add_header("Cache-Control", "no-cache")
@ -293,7 +308,7 @@ def tts():
return resp return resp
@manager.route('/delete_msg', methods=['POST']) # noqa: F821 @manager.route("/delete_msg", methods=["POST"]) # type: ignore # noqa: F821
@login_required @login_required
@validate_request("conversation_id", "message_id") @validate_request("conversation_id", "message_id")
def delete_msg(): def delete_msg():
@ -316,7 +331,7 @@ def delete_msg():
return get_json_result(data=conv) return get_json_result(data=conv)
@manager.route('/thumbup', methods=['POST']) # noqa: F821 @manager.route("/thumbup", methods=["POST"]) # type: ignore # noqa: F821
@login_required @login_required
@validate_request("conversation_id", "message_id") @validate_request("conversation_id", "message_id")
def thumbup(): def thumbup():
@ -343,7 +358,7 @@ def thumbup():
return get_json_result(data=conv) return get_json_result(data=conv)
@manager.route('/ask', methods=['POST']) # noqa: F821 @manager.route("/ask", methods=["POST"]) # type: ignore # noqa: F821
@login_required @login_required
@validate_request("question", "kb_ids") @validate_request("question", "kb_ids")
def ask_about(): def ask_about():
@ -356,9 +371,7 @@ def ask_about():
for ans in ask(req["question"], req["kb_ids"], uid): for ans in ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e: except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream") resp = Response(stream(), mimetype="text/event-stream")
@ -369,7 +382,7 @@ def ask_about():
return resp return resp
@manager.route('/mindmap', methods=['POST']) # noqa: F821 @manager.route("/mindmap", methods=["POST"]) # type: ignore # noqa: F821
@login_required @login_required
@validate_request("question", "kb_ids") @validate_request("question", "kb_ids")
def mindmap(): def mindmap():
@ -382,10 +395,7 @@ def mindmap():
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT) chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
question = req["question"] question = req["question"]
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, 0.3, 0.3, aggs=False, rank_feature=label_question(question, [kb]))
0.3, 0.3, aggs=False,
rank_feature=label_question(question, [kb])
)
mindmap = MindMapExtractor(chat_mdl) mindmap = MindMapExtractor(chat_mdl)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]]) mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
mind_map = mind_map.output mind_map = mind_map.output
@ -394,7 +404,7 @@ def mindmap():
return get_json_result(data=mind_map) return get_json_result(data=mind_map)
@manager.route('/related_questions', methods=['POST']) # noqa: F821 @manager.route("/related_questions", methods=["POST"]) # type: ignore # noqa: F821
@login_required @login_required
@validate_request("question") @validate_request("question")
def related_questions(): def related_questions():
@ -425,8 +435,17 @@ Reason:
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results. - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
""" """
ans = chat_mdl.chat(prompt, [{"role": "user", "content": f""" ans = chat_mdl.chat(
prompt,
[
{
"role": "user",
"content": f"""
Keywords: {question} Keywords: {question}
Related search terms: Related search terms:
"""}], {"temperature": 0.9}) """,
}
],
{"temperature": 0.9},
)
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]) return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])

File: api/db/services/dialog_service.py

@@ -30,8 +30,7 @@ from api import settings
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, citation_prompt
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.tavily_conn import Tavily
@@ -41,17 +40,13 @@ class DialogService(CommonService):
    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
@@ -72,13 +67,12 @@ def chat_solo(dialog, messages, stream=True):
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
    if stream:
        last_ans = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans) :]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
@@ -159,9 +153,8 @@ def chat(dialog, messages, stream=True, **kwargs):
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
    # Multi-turn question refinement is no longer used
    # if len(questions) > 1 and prompt_config.get("refine_multiturn"):
    #     questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
@@ -190,7 +183,7 @@ def chat(dialog, messages, stream=True, **kwargs):
        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
        knowledges = []

        # Reasoning (deep research) is no longer used
        # if prompt_config.get("reasoning", False):
        #     reasoner = DeepResearcher(chat_mdl,
@@ -226,17 +219,24 @@ def chat(dialog, messages, stream=True, **kwargs):
        #             kbinfos["chunks"].insert(0, ck)
        #     knowledges = kb_prompt(kbinfos, max_tokens)
        kbinfos = retriever.retrieval(
            " ".join(questions),
            embd_mdl,
            tenant_ids,
            dialog.kb_ids,
            1,
            dialog.top_n,
            dialog.similarity_threshold,
            dialog.vector_similarity_weight,
            doc_ids=attachments,
            top=dialog.top_k,
            aggs=False,
            rerank_mdl=rerank_mdl,
            rank_feature=label_question(" ".join(questions), kbs),
        )
        knowledges = kb_prompt(kbinfos, max_tokens)
        logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    retrieval_ts = timer()
    if not knowledges and prompt_config.get("empty_response"):
@@ -252,22 +252,19 @@ def chat(dialog, messages, stream=True, **kwargs):
    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        prompt4citation = citation_prompt()
    # Filter out system-role messages (the system message was already handled separately above)
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]

    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)

    def decorate_answer(answer):
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
        refs = []
        image_markdowns = []  # Markdown strings for cited images
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
@@ -275,29 +272,29 @@ def chat(dialog, messages, stream=True, **kwargs):
            answer = ans[1]
        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
            cited_chunk_indices = set()  # indices of the cited chunks
            if not re.search(r"##[0-9]+\$\$", answer):
                answer, idx = retriever.insert_citations(
                    answer,
                    [ck["content_ltks"] for ck in kbinfos["chunks"]],
                    [ck["vector"] for ck in kbinfos["chunks"]],
                    embd_mdl,
                    tkweight=1 - dialog.vector_similarity_weight,
                    vtweight=dialog.vector_similarity_weight,
                )
                cited_chunk_indices = idx  # indices returned by insert_citations
            else:
                idx = set([])
                for r in re.finditer(r"##([0-9]+)\$\$", answer):
                    i = int(r.group(1))
                    if i < len(kbinfos["chunks"]):
                        idx.add(i)
                cited_chunk_indices = idx  # indices extracted from the ##...$$ markers

            # Collect image info for the cited chunks and build Markdown
            cited_doc_ids = set()
            processed_image_urls = set()  # avoid adding the same image twice
            print(f"DEBUG: cited_chunk_indices = {cited_chunk_indices}")
            for i in cited_chunk_indices:
                i_int = int(i)
@@ -312,11 +309,10 @@ def chat(dialog, messages, stream=True, **kwargs):
                        # Build the Markdown string; the alt text can simply be "image" or the chunk ID
                        alt_text = f"image_chunk_{chunk.get('chunk_id', i_int)}"
                        image_markdowns.append(f"\n![{alt_text}]({img_url})")
                        processed_image_urls.add(img_url)  # mark this URL as processed

            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs
@@ -325,7 +321,7 @@ def chat(dialog, messages, stream=True, **kwargs):
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]
            # Append the image Markdown strings to the end of the answer
            if image_markdowns:
                answer += "".join(image_markdowns)
@@ -347,30 +343,30 @@ def chat(dialog, messages, stream=True, **kwargs):
            prompt += "\n\n### Query:\n%s" % " ".join(questions)
        prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt), "created_at": time.time()}

    if stream:
        last_ans = ""  # the complete answer returned last time
        answer = ""  # the cumulative answer so far
        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
            # If there is a thought process, strip the <think> markers
            if thought:
                ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            # Compute the newly added text fragment (delta)
            delta_ans = ans[len(last_ans) :]
            # Skip this round if the delta is too small (fewer than 16 tokens), to avoid sending tiny fragments
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            # Yield the cumulative answer (thought process included) plus audio for the delta
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        delta_ans = answer[len(last_ans) :]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
@@ -388,27 +384,22 @@ Table of database fields are as follows:
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
    tried_times = 0

    def get_table():
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
        sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;]|```).*", "", sql)
        if sql[: len("select ")] != "select ":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[: len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
@@ -445,11 +436,7 @@ Please write the SQL, only SQL, without any other explanations or text.
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
            tbl, sql = get_table()
            logging.debug("TRY it again: {}".format(sql))
@@ -457,24 +444,18 @@ Please write the SQL, only SQL, without any other explanations or text.
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None

    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # compose Markdown table
    columns = (
        "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
    )
    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
    rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    if quota:
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
@@ -484,11 +465,7 @@ Please write the SQL, only SQL, without any other explanations or text.
    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
@@ -499,10 +476,11 @@ Please write the SQL, only SQL, without any other explanations or text.
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {
            "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
            "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
        },
        "prompt": sys_prompt,
    }
@@ -518,12 +496,12 @@ def tts(tts_mdl, text):
def ask(question, kb_ids, tenant_id):
    """
    Handle a user search request: retrieve relevant information from the knowledge bases and generate an answer.

    Args:
        question (str): the user's question or query
        kb_ids (list): IDs of the knowledge bases to search
        tenant_id (str): tenant ID, used for access control and resource isolation

    Flow:
        1. Fetch the specified knowledge bases
        2. Determine the embedding model to use
@@ -534,11 +512,11 @@ def ask(question, kb_ids, tenant_id):
        7. Build the system prompt
        8. Generate the answer and add citation markers
        9. Stream the generated answer back

    Returns:
        generator: yields dicts containing the answer and reference information
    """
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
@@ -552,27 +530,24 @@ def ask(question, kb_ids, tenant_id):
    max_tokens = chat_mdl.max_length
    # Collect the tenant IDs of all knowledge bases, deduplicated
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))

    # Run the retriever to fetch relevant document chunks
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
    # Format the retrieval results into the prompt, keeping within the model's max token limit
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
角色:你是一个聪明的助手。
任务:总结知识库中的信息并回答用户的问题。
要求与限制:
  - 绝不要捏造内容,尤其是数字。
  - 如果知识库中的信息与用户问题无关,**只需回答:对不起,未提供相关信息。**
  - 使用Markdown格式进行回答。
  - 使用用户提问所用的语言作答。
  - 绝不要捏造内容,尤其是数字。

### 来自知识库的信息
%s

以上是来自知识库的信息。
""" % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]
@@ -580,17 +555,9 @@ def ask(question, kb_ids, tenant_id):
    # Add citation markers to the answer once generation completes
    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
@@ -608,4 +575,4 @@ def ask(question, kb_ids, tenant_id):
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)
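
The tail of ask() fixes the generator contract the refactor relies on: every streamed frame carries the cumulative answer with an empty reference, and only the final decorate_answer frame carries citations. A small consumer sketch with placeholder IDs (the question and IDs are not from this diff, and running it requires a configured environment):

frames = list(ask("What do these documents cover?", ["<kb-id>"], "<tenant-id>"))
*partials, final = frames
assert all(f["reference"] == {} for f in partials)  # streaming frames: text only
print(final["answer"])                      # decorated answer with ##n$$ citation markers
print(final["reference"].get("doc_aggs"))   # cited documents, if any
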

File: api/db/services/write_service.py (new)

@@ -0,0 +1,91 @@
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from rag.app.tag import label_question
from rag.prompts import kb_prompt


def write_dialog(question, kb_ids, tenant_id):
    """
    Handle a user search request: retrieve relevant information from the knowledge bases and generate an answer.

    Args:
        question (str): the user's question or query
        kb_ids (list): IDs of the knowledge bases to search
        tenant_id (str): tenant ID, used for access control and resource isolation

    Flow:
        1. Fetch the specified knowledge bases
        2. Determine the embedding model to use
        3. Choose a retriever based on the knowledge-base type (standard or knowledge-graph)
        4. Initialize the embedding and chat models
        5. Run retrieval to get relevant document chunks
        6. Format the knowledge-base content as context
        7. Build the system prompt
        8. Generate the answer and add citation markers
        9. Stream the generated answer back

    Returns:
        generator: yields dicts containing the answer and reference information
    """
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    # Initialize the embedding model, used to turn text into vector representations
    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    # Initialize the chat model, used to generate the answer
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    # The chat model's maximum token length, used to bound the context size
    max_tokens = chat_mdl.max_length
    # Collect the tenant IDs of all knowledge bases, deduplicated
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    # Run the retriever to fetch relevant document chunks
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
    # Format the retrieval results into the prompt, keeping within the model's max token limit
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
角色:你是一个聪明的助手。
任务:总结知识库中的信息并回答用户的问题。
要求与限制:
  - 绝不要捏造内容,尤其是数字。
  - 如果知识库中的信息与用户问题无关,**只需回答:对不起,未提供相关信息。**
  - 使用Markdown格式进行回答。
  - 使用用户提问所用的语言作答。
  - 绝不要捏造内容,尤其是数字。

### 来自知识库的信息
%s

以上是来自知识库的信息。
""" % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    # Citation decoration (disabled for now): would add citation markers once generation completes
    # def decorate_answer(answer):
    #     nonlocal knowledges, kbinfos, prompt
    #     answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
    #     idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
    #     recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
    #     if not recall_docs:
    #         recall_docs = kbinfos["doc_aggs"]
    #     kbinfos["doc_aggs"] = recall_docs
    #     refs = deepcopy(kbinfos)
    #     for c in refs["chunks"]:
    #         if c.get("vector"):
    #             del c["vector"]
    #     if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
    #         answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
    #     refs["chunks"] = chunks_format(refs)
    #     return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    # yield decorate_answer(answer)
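
write_dialog mirrors ask() but, with decorate_answer commented out, it never emits a final citation frame: every yielded dict has an empty reference, so the last frame simply holds the full text. A minimal driver sketch under those assumptions (placeholder IDs, configured environment required):

from api.db.services.write_service import write_dialog

# Placeholder IDs; each frame is {"answer": <cumulative text>, "reference": {}}.
last = None
for frame in write_dialog("总结这些文档的要点", ["<kb-id>"], "<tenant-id>"):
    last = frame
    print(f"\r{len(frame['answer'])} chars streamed", end="")

if last is not None:
    print("\n" + last["answer"])  # the complete answer is just the last frame
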

File: storage-solid SVG icon (new asset)

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg fill="#000000" width="800px" height="800px" viewBox="0 0 36 36" version="1.1" preserveAspectRatio="xMidYMid meet" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>storage-solid</title>
<path class="clr-i-solid clr-i-solid-path-1" d="M17.91,18.28c8.08,0,14.66-1.74,15.09-3.94V8.59c-.43,2.2-7,3.94-15.09,3.94A39.4,39.4,0,0,1,6.25,11V9a39.4,39.4,0,0,0,11.66,1.51C26,10.53,32.52,8.79,33,6.61h0C32.8,3.2,23.52,2.28,18,2.28S3,3.21,3,6.71V29.29c0,3.49,9.43,4.43,15,4.43s15-.93,15-4.43V24.09C32.57,26.28,26,28,17.91,28A39.4,39.4,0,0,1,6.25,26.52v-2A39.4,39.4,0,0,0,17.91,26C26,26,32.57,24.28,33,22.09V16.34c-.43,2.2-7,3.94-15.09,3.94A39.4,39.4,0,0,1,6.25,18.77v-2A39.4,39.4,0,0,0,17.91,18.28Z"></path>
<rect x="0" y="0" width="36" height="36" fill-opacity="0"/>
</svg>


File: @/hooks/write-hooks (new)

@@ -0,0 +1,123 @@
import { Authorization } from '@/constants/authorization';
import { IAnswer } from '@/interfaces/database/chat';
import { IKnowledge } from '@/interfaces/database/knowledge';
import kbService from '@/services/knowledge-service';
import api from '@/utils/api';
import { getAuthorization } from '@/utils/authorization-util';
import { useQuery } from '@tanstack/react-query';
import { EventSourceParserStream } from 'eventsource-parser/stream';
import { useCallback, useRef, useState } from 'react';
// Fetch the knowledge-base list
export const useFetchKnowledgeList = (
shouldFilterListWithoutDocument: boolean = false,
): {
list: IKnowledge[];
loading: boolean;
} => {
const { data, isFetching: loading } = useQuery({
queryKey: ['fetchKnowledgeList'],
initialData: [],
gcTime: 0,
queryFn: async () => {
const { data } = await kbService.getList();
const list = data?.data?.kbs ?? [];
return shouldFilterListWithoutDocument
? list.filter((x: IKnowledge) => x.chunk_num > 0)
: list;
},
});
return { list: data, loading };
};
// Send a Q&A message over SSE
export const useSendMessageWithSse = (url: string = api.writeChat) => {
const [answer, setAnswer] = useState<IAnswer>({} as IAnswer);
const [done, setDone] = useState(true);
const timer = useRef<any>();
const sseRef = useRef<AbortController>();
const initializeSseRef = useCallback(() => {
sseRef.current = new AbortController();
}, []);
const resetAnswer = useCallback(() => {
if (timer.current) {
clearTimeout(timer.current);
}
timer.current = setTimeout(() => {
setAnswer({} as IAnswer);
clearTimeout(timer.current);
}, 1000);
}, []);
const send = useCallback(
async (
body: any,
controller?: AbortController,
): Promise<{ response: Response; data: ResponseType } | undefined> => {
initializeSseRef();
try {
setDone(false);
const response = await fetch(url, {
method: 'POST',
headers: {
[Authorization]: getAuthorization(),
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
signal: controller?.signal || sseRef.current?.signal,
});
const res = response.clone().json();
const reader = response?.body
?.pipeThrough(new TextDecoderStream())
.pipeThrough(new EventSourceParserStream())
.getReader();
while (true) {
const x = await reader?.read();
if (x) {
const { done, value } = x;
if (done) {
console.info('done');
resetAnswer();
break;
}
try {
const val = JSON.parse(value?.data || '');
const d = val?.data;
if (typeof d !== 'boolean') {
console.info('data:', d);
setAnswer({
...d,
conversationId: body?.conversation_id,
});
}
} catch (e) {
console.warn(e);
}
}
}
console.info('done?');
setDone(true);
resetAnswer();
return { data: await res, response };
} catch (e) {
setDone(true);
resetAnswer();
console.warn(e);
}
},
[initializeSseRef, url, resetAnswer],
);
const stopOutputMessage = useCallback(() => {
sseRef.current?.abort();
}, []);
return { send, answer, done, setDone, resetAnswer, stopOutputMessage };
};

File: Chinese (zh) locale strings

@@ -591,6 +591,7 @@
      decisions: '决定事项',
      actionItems: '行动项',
      nextMeeting: '下次会议',
      noTemplatesAvailable: "没有可用模板",
      // Model configuration
      modelConfigurationTitle: "模型配置",
      knowledgeBaseLabel: "知识库",
@@ -601,6 +602,7 @@
      fetchKnowledgeBaseFailed: "获取知识库列表失败",
      defaultKnowledgeBase: "默认知识库",
      technicalDocsKnowledgeBase: "技术文档知识库",
      aiRequestFailedError: "问答模型请求失败",
    },
    setting: {
      profile: '概要',

File: Write page component (under @/pages/write)

@@ -1,8 +1,9 @@
import HightLightMarkdown from '@/components/highlight-markdown';
import { useTranslate } from '@/hooks/common-hooks';
import {
  useFetchKnowledgeList,
  useSendMessageWithSse,
} from '@/hooks/write-hooks';
import { DeleteOutlined } from '@ant-design/icons';
import {
@@ -22,7 +23,6 @@ import {
  Space,
  Typography,
} from 'antd';
import {
  AlignmentType,
  Document,
@@ -32,7 +32,7 @@ import {
  TextRun,
} from 'docx';
import { saveAs } from 'file-saver';
import { marked, Token, Tokens } from 'marked';
import { useCallback, useEffect, useRef, useState } from 'react';

const { Sider, Content } = Layout;
@@ -53,22 +53,28 @@ interface KnowledgeBaseItem {
  name: string;
}

type MarkedHeadingToken = Tokens.Heading;
type MarkedParagraphToken = Tokens.Paragraph;
type MarkedListItem = Tokens.ListItem;
type MarkedListToken = Tokens.List;
type MarkedSpaceToken = Tokens.Space;

// Insertion-point marker, defined so it can be recognized and removed in onChange
// const INSERTION_MARKER = '【AI内容将插入此处】';
const INSERTION_MARKER = ''; // kept as an empty string so no visible marker is shown

const Write = () => {
  const { t } = useTranslate('write');
  const [content, setContent] = useState('');
  const [aiQuestion, setAiQuestion] = useState('');
  const [isAiLoading, setIsAiLoading] = useState(false);
  const [dialogId] = useState('');
  // cursorPosition stores the insertion position the user set by clicking
  const [cursorPosition, setCursorPosition] = useState<number | null>(null);
  // showCursorIndicator only controls whether INSERTION_MARKER is shown in the document;
  // once a cursor position is set it should stay true, unless the content is cleared or actively reset
  const [showCursorIndicator, setShowCursorIndicator] = useState(false);
  const textAreaRef = useRef<any>(null); // ref type for Ant Design Input.TextArea
  const [templates, setTemplates] = useState<TemplateItem[]>([]);
  const [isTemplateModalVisible, setIsTemplateModalVisible] = useState(false);
@@ -87,6 +93,27 @@ const Write = () => {
  const [modelTemperature, setModelTemperature] = useState<number>(0.7);
  const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBaseItem[]>([]);
  const [isLoadingKbs, setIsLoadingKbs] = useState(false);
  const [isStreaming, setIsStreaming] = useState(false); // whether the AI is currently streaming output

  // New streaming state: currentStreamedAiOutput directly receives the
  // accumulated content returned by useSendMessageWithSse
  const [currentStreamedAiOutput, setCurrentStreamedAiOutput] = useState('');
  // Keep the content before/after the AI insertion point, and the insertion position,
  // in refs to avoid stale-closure issues during streaming updates
  const contentBeforeAiInsertionRef = useRef('');
  const contentAfterAiInsertionRef = useRef('');
  const aiInsertionStartPosRef = useRef<number | null>(null);

  // Fetch real data with the useFetchKnowledgeList hook
  const { list: knowledgeList, loading: isLoadingKnowledgeList } =
    useFetchKnowledgeList(true);

  // Streaming message hook
  const {
    send: sendMessage,
    answer,
    done,
    stopOutputMessage,
  } = useSendMessageWithSse();

  const getInitialDefaultTemplateDefinitions = useCallback(
    (): TemplateItem[] => [
@@ -171,70 +198,102 @@ const Write = () => {
    loadOrInitializeTemplates();
  }, [loadOrInitializeTemplates]);

  // Sync knowledgeList data into the knowledgeBases state
  useEffect(() => {
    if (knowledgeList && knowledgeList.length > 0) {
      setKnowledgeBases(
        knowledgeList.map((kb) => ({
          id: kb.id,
          name: kb.name,
        })),
      );
      setIsLoadingKbs(isLoadingKnowledgeList);
    }
  }, [knowledgeList, isLoadingKnowledgeList]);

  // --- Streaming response handling ---
  // Phase 1: accumulate AI output fragments for live display (including <think> tags).
  // This effect keeps currentStreamedAiOutput up to date with the full streamed content, <think> tags included.
  useEffect(() => {
    if (isStreaming && answer && answer.answer) {
      setCurrentStreamedAiOutput(answer.answer);
    }
  }, [isStreaming, answer]);

  // Phase 2: when streaming finishes (done === true),
  // this effect performs cleanup and commits the final content.
  useEffect(() => {
    if (done) {
      setIsStreaming(false);
      setIsAiLoading(false);

      // Process the final streamed AI output before committing.
      // This must use currentStreamedAiOutput: it is the content accumulated during streaming,
      // <think> tags included; answer.answer may already have been cleaned up inside the hook
      // by the time done fires, so it cannot be relied on for the raw tagged content.
      let processedAiOutput = currentStreamedAiOutput;
      if (processedAiOutput) {
        // Remove <think>...</think> blocks, content included
        processedAiOutput = processedAiOutput.replace(
          /<think>.*?<\/think>/gs,
          '',
        );
      }

      // Splice the final AI content (with <think> tags removed) into the document
      setContent((prevContent) => {
        if (aiInsertionStartPosRef.current !== null) {
          // Use the initial content stored in the refs plus the processed AI output
          const finalContent =
            contentBeforeAiInsertionRef.current +
            processedAiOutput +
            contentAfterAiInsertionRef.current;
          return finalContent;
        }
        return prevContent;
      });

      // After the AI answer completes, move the caret to the end of the new content
      if (
        textAreaRef.current?.resizableTextArea?.textArea &&
        aiInsertionStartPosRef.current !== null
      ) {
        const newCursorPos =
          aiInsertionStartPosRef.current + processedAiOutput.length;
        textAreaRef.current.resizableTextArea.textArea.selectionStart =
          newCursorPos;
        textAreaRef.current.resizableTextArea.textArea.selectionEnd =
          newCursorPos;
        textAreaRef.current.resizableTextArea.textArea.focus();
        setCursorPosition(newCursorPos);
      }

      // Clean up the temporary streaming state and refs
      setCurrentStreamedAiOutput('');
      contentBeforeAiInsertionRef.current = '';
      contentAfterAiInsertionRef.current = '';
      aiInsertionStartPosRef.current = null;
      setShowCursorIndicator(true);
    }
  }, [done, currentStreamedAiOutput]); // depend on done and currentStreamedAiOutput so the latest output is available when done fires

  // Watch currentStreamedAiOutput and update the main content in real time for streaming display
  useEffect(() => {
    if (isStreaming && aiInsertionStartPosRef.current !== null) {
      // Live-update the editor content, keeping the <think> tag content visible
      setContent(
        contentBeforeAiInsertionRef.current +
          currentStreamedAiOutput +
          contentAfterAiInsertionRef.current,
      );
      // Move cursorPosition along with the AI output (based on the raw length, <think> tags included)
      setCursorPosition(
        aiInsertionStartPosRef.current + currentStreamedAiOutput.length,
      );
    }
  }, [currentStreamedAiOutput, isStreaming, aiInsertionStartPosRef]);
useEffect(() => {
- const fetchDialogs = async () => {
-   try {
-     const authorization = localStorage.getItem('Authorization');
-     if (!authorization) return;
-     const response = await axios.get('/v1/dialog', {
-       headers: { authorization },
-     });
-     if (response.data?.data?.length > 0)
-       setDialogId(response.data.data[0].id);
-   } catch (error) {
-     console.error('获取对话列表失败:', error);
-   }
- };
  const loadDraftContent = () => {
    try {
      const draftContent = localStorage.getItem('writeDraftContent');

@@ -250,13 +309,13 @@ const Write = () => {
      console.error('加载暂存内容失败:', error);
    }
  };
- fetchDialogs();
  if (localStorage.getItem(LOCAL_STORAGE_INIT_FLAG_KEY) === 'true') {
    loadDraftContent();
  }
}, [content, selectedTemplate, templates]);
useEffect(() => {
  // Debounced save, to avoid writing to localStorage too frequently
  const timer = setTimeout(
    () => localStorage.setItem('writeDraftContent', content),
    1000,

@@ -302,6 +361,7 @@ const Write = () => {
  }
};

// Delete a template
const handleDeleteTemplate = (templateId: string) => {
  try {
    const updatedTemplates = templates.filter((t) => t.id !== templateId);

@@ -326,6 +386,51 @@ const Write = () => {
  }
};
// Helper for building the context content around the insertion point
const getContextContent = (
  cursorPos: number,
  currentDocumentContent: string,
  maxLength: number = 4000,
) => {
  // Note: currentDocumentContent is the full editor content at the time the AI
  // question is asked, not contentBeforeAiInsertionRef + contentAfterAiInsertionRef,
  // because it may contain the marker
  const beforeCursor = currentDocumentContent.substring(0, cursorPos);
  const afterCursor = currentDocumentContent.substring(cursorPos);

  // Use a clearly visible insertion-point marker; it is meant for the AI, not the user
  const insertMarker = '[AI 内容插入点]';
  const availableLength = maxLength - insertMarker.length;

  if (currentDocumentContent.length <= availableLength) {
    return {
      beforeCursor,
      afterCursor,
      contextContent: beforeCursor + insertMarker + afterCursor,
    };
  }

  const halfLength = Math.floor(availableLength / 2);
  let finalBefore = beforeCursor;
  let finalAfter = afterCursor;

  // If the leading half is too long, truncate it and prepend an ellipsis
  if (beforeCursor.length > halfLength) {
    finalBefore =
      '...' + beforeCursor.substring(beforeCursor.length - halfLength + 3);
  }
  // If the trailing half is too long, truncate it and append an ellipsis
  if (afterCursor.length > halfLength) {
    finalAfter = afterCursor.substring(0, halfLength - 3) + '...';
  }

  return {
    beforeCursor,
    afterCursor,
    contextContent: finalBefore + insertMarker + finalAfter,
  };
};
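
A sketch of how the truncation behaves on an oversized document; the numbers are hypothetical, and the call assumes the helper has been lifted to module scope for the example:

// A 10,000-character document with the cursor at position 6,000:
const doc = 'a'.repeat(6000) + 'b'.repeat(4000);
const { contextContent } = getContextContent(6000, doc);

// availableLength = 4000 - '[AI 内容插入点]'.length = 3990, so each side keeps at most
// floor(3990 / 2) = 1995 characters: '...' + 1992 'a's, the marker, 1992 'b's + '...'
console.log(contextContent.length); // 4000 -- never more than maxLength
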
const handleAiQuestionSubmit = async (
  e: React.KeyboardEvent<HTMLTextAreaElement>,
) => {

@@ -335,128 +440,106 @@ const Write = () => {
    message.warning(t('enterYourQuestion'));
    return;
  }
- if (!dialogId) {
-   message.error(t('noDialogFound'));
-   return;
- }
- setIsAiLoading(true);
- const initialCursorPos = cursorPosition;
- const originalContent = content;
- let beforeCursor = '',
-   afterCursor = '';
- if (initialCursorPos !== null && showCursorIndicator) {
-   beforeCursor = originalContent.substring(0, initialCursorPos);
-   afterCursor = originalContent.substring(initialCursorPos);
- }
- const controller = new AbortController();
- const timeoutId = setTimeout(
-   () => controller.abort(),
-   aiAssistantConfig.api.timeout || 30000,
- );

  // Check that at least one knowledge base is selected
  if (selectedKnowledgeBases.length === 0) {
    message.warning('请至少选择一个知识库');
    return;
  }

  // If the AI is currently streaming, stop it and handle the new question
  if (isStreaming) {
    stopOutputMessage(); // abort the current streamed response
    setIsStreaming(false); // set to false immediately to interrupt the stream
    setIsAiLoading(false); // make sure the loading state stops as well

    // On interruption, strip the <think> tags from the stream right away and update the main content.
    // currentStreamedAiOutput is used as the baseline for the interrupted content,
    // since it is what is actually shown on screen, <think> tags included.
    const contentToCleanOnInterrupt =
      contentBeforeAiInsertionRef.current +
      currentStreamedAiOutput +
      contentAfterAiInsertionRef.current;
    const cleanedContent = contentToCleanOnInterrupt.replace(
      /<think>.*?<\/think>/gs,
      '',
    );
    setContent(cleanedContent);
    setCurrentStreamedAiOutput(''); // drop the stale streamed output
    contentBeforeAiInsertionRef.current = ''; // reset the refs
    contentAfterAiInsertionRef.current = '';
    aiInsertionStartPosRef.current = null;
    message.info('已中断上一次AI回答,正在处理新问题...');
    // Pause briefly so state updates settle before handling the new question, avoiding a race condition
    await new Promise((resolve) => {
      setTimeout(resolve, 100);
    });
  }

  // If the current cursor position is invalid, remind the user to set one
  if (cursorPosition === null) {
    message.warning('请先点击文本框以设置AI内容插入位置。');
    return;
  }

  // Capture the static content before and after the AI insertion point into refs
  const currentCursorPos = cursorPosition;
  // At this point content is the actual editor content (possibly including INSERTION_MARKER);
  // since INSERTION_MARKER is empty, it is simply the current main content
  contentBeforeAiInsertionRef.current = content.substring(
    0,
    currentCursorPos,
  );
  contentAfterAiInsertionRef.current = content.substring(currentCursorPos);
  aiInsertionStartPosRef.current = currentCursorPos; // record the exact insertion start position

  setIsAiLoading(true);
  setIsStreaming(true); // mark that the AI has started streaming
  setCurrentStreamedAiOutput(''); // clear the previously accumulated output, ready for the new stream

  try {
    const authorization = localStorage.getItem('Authorization');
    if (!authorization) {
      message.error(t('loginRequiredError'));
      setIsAiLoading(false);
      setIsStreaming(false); // stop the streaming flag
      // Clean up the temporary state on failure as well
      setCurrentStreamedAiOutput('');
      contentBeforeAiInsertionRef.current = '';
      contentAfterAiInsertionRef.current = '';
      aiInsertionStartPosRef.current = null;
      return;
    }
-   const conversationId =
-     Math.random().toString(36).substring(2) + Date.now().toString(36);
-   await axios.post(
-     'v1/conversation/set',
-     {
-       dialog_id: dialogId,
-       name: '文档撰写对话',
-       is_new: true,
-       conversation_id: conversationId,
-       message: [{ role: 'assistant', content: '新对话' }],
-     },
-     { headers: { authorization }, signal: controller.signal },
-   );
-   const combinedQuestion = `${aiQuestion}\n\n${t('currentDocumentContextLabel')}:\n${originalContent}`;
-   let lastReceivedContent = '';
-   const response = await axios.post(
-     '/v1/conversation/completion',
-     {
-       conversation_id: conversationId,
-       messages: [{ role: 'user', content: combinedQuestion }],
-       knowledge_base_ids:
-         selectedKnowledgeBases.length > 0
-           ? selectedKnowledgeBases
-           : undefined,
-       similarity_threshold: similarityThreshold,
-       keyword_similarity_weight: keywordSimilarityWeight,
-       temperature: modelTemperature,
-     },
-     {
-       timeout: aiAssistantConfig.api.timeout,
-       headers: { authorization },
-       signal: controller.signal,
-     },
-   );
-   if (response.data) {
-     const lines = response.data
-       .split('\n')
-       .filter((line: string) => line.trim());
-     for (let i = 0; i < lines.length; i++) {
-       try {
-         const jsonStr = lines[i].replace('data:', '').trim();
-         const jsonData = JSON.parse(jsonStr);
-         if (jsonData.code === 0 && jsonData.data?.answer) {
-           const answerChunk = jsonData.data.answer;
-           const cleanedAnswerChunk = answerChunk
-             .replace(/<think>[\s\S]*?<\/think>/g, '')
-             .trim();
-           const hasUnclosedThink =
-             cleanedAnswerChunk.includes('<think>') &&
-             (!cleanedAnswerChunk.includes('</think>') ||
-               cleanedAnswerChunk.indexOf('<think>') >
-                 cleanedAnswerChunk.lastIndexOf('</think>'));
-           if (cleanedAnswerChunk && !hasUnclosedThink) {
-             const incrementalContent = cleanedAnswerChunk.substring(
-               lastReceivedContent.length,
-             );
-             if (incrementalContent) {
-               lastReceivedContent = cleanedAnswerChunk;
-               let newFullContent,
-                 newCursorPosAfterInsertion = cursorPosition;
-               if (initialCursorPos !== null && showCursorIndicator) {
-                 newFullContent =
-                   beforeCursor + cleanedAnswerChunk + afterCursor;
-                 newCursorPosAfterInsertion =
-                   initialCursorPos + cleanedAnswerChunk.length;
-               } else {
-                 newFullContent = originalContent + cleanedAnswerChunk;
-                 newCursorPosAfterInsertion = newFullContent.length;
-               }
-               setContent(newFullContent);
-               setCursorPosition(newCursorPosAfterInsertion);
-               setTimeout(() => {
-                 if (textAreaRef.current) {
-                   textAreaRef.current.focus();
-                   textAreaRef.current.setSelectionRange(
-                     newCursorPosAfterInsertion!,
-                     newCursorPosAfterInsertion!,
-                   );
-                 }
-               }, 0);
-             }
-           }
-         }
-       } catch (parseErr) {
-         console.error('解析单行数据失败:', parseErr);
-       }
-       if (i < lines.length - 1)
-         await new Promise((resolve) => {
-           setTimeout(resolve, 10);
-         });
-     }
-   }
-   await axios.post(
-     '/v1/conversation/rm',
-     { conversation_ids: [conversationId], dialog_id: dialogId },
-     { headers: { authorization } },
-   );

    // Build the request: send the surrounding context along with the question
    let questionWithContext = aiQuestion;

    // Include the context only when the user has set an insertion position
    if (aiInsertionStartPosRef.current !== null) {
      // The content passed to getContextContent should be the full editor content (marker included)
      const { contextContent } = getContextContent(
        aiInsertionStartPosRef.current,
        content,
      );
      questionWithContext = `${aiQuestion}\n\n上下文内容\n${contextContent}`;
    }

    // Send the streaming request
    await sendMessage({
      question: questionWithContext,
      kb_ids: selectedKnowledgeBases,
      dialog_id: dialogId,
      similarity_threshold: similarityThreshold,
      keyword_similarity_weight: keywordSimilarityWeight,
      temperature: modelTemperature,
    });

    setAiQuestion(''); // clear the input box
    // Refocus the text area -- the main editor, not the AI question box
    if (textAreaRef.current?.resizableTextArea?.textArea) {
      textAreaRef.current.resizableTextArea.textArea.focus();
    }
  } catch (error: any) {
    console.error('AI助手处理失败:', error);
    if (error.code === 'ECONNABORTED' || error.name === 'AbortError') {

@@ -469,14 +552,13 @@ const Write = () => {
      message.error(t('aiRequestFailedError'));
    }
  } finally {
-   clearTimeout(timeoutId);
-   setIsAiLoading(false);
-   setAiQuestion('');
-   if (textAreaRef.current) textAreaRef.current.focus();
    // isAiLoading is updated when done fires or in the error handler; it is not forced to false here.
    // The temporary state is fully cleared only once isStreaming has completely ended.
  }
  }
};
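
The handler now delegates transport to sendMessage / answer / done / stopOutputMessage, which this diff does not show. A minimal sketch of the shape such a hook could take, assuming an SSE-style response; the hook name, URL handling, and frame format here are assumptions, not the project's actual hook:

import { useCallback, useRef, useState } from 'react';

interface SseAnswer {
  answer: string; // accumulated text, <think> tags included
}

export const useWriteChatSse = (url: string) => {
  const [answer, setAnswer] = useState<SseAnswer>({ answer: '' });
  const [done, setDone] = useState(true);
  const controllerRef = useRef<AbortController>();

  const sendMessage = useCallback(
    async (body: Record<string, unknown>) => {
      const controller = new AbortController();
      controllerRef.current = controller;
      setDone(false);
      const response = await fetch(url, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
          Authorization: localStorage.getItem('Authorization') ?? '',
        },
        body: JSON.stringify(body),
        signal: controller.signal,
      });
      const reader = response.body?.getReader();
      const decoder = new TextDecoder();
      while (reader) {
        const { value, done: finished } = await reader.read();
        if (finished) break;
        // Each SSE frame looks like "data:{...}"; parse the JSON payloads
        for (const line of decoder
          .decode(value, { stream: true })
          .split('\n')) {
          if (!line.startsWith('data:')) continue;
          try {
            const payload = JSON.parse(line.slice(5));
            if (payload?.data?.answer) {
              setAnswer({ answer: payload.data.answer });
            }
          } catch {
            // ignore partial frames split across chunks
          }
        }
      }
      setDone(true);
    },
    [url],
  );

  const stopOutputMessage = useCallback(() => {
    controllerRef.current?.abort();
    setDone(true);
  }, []);

  return { sendMessage, answer, done, stopOutputMessage };
};
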
// Export to Word
const handleSave = () => {
  const selectedTemplateItem = templates.find(
    (item) => item.id === selectedTemplate,

@@ -674,33 +756,125 @@ const Write = () => {
  }
};
- const renderEditor = () => (
-   <Input.TextArea
-     ref={textAreaRef}
-     style={{
-       height: '100%',
-       width: '100%',
-       border: 'none',
-       padding: 24,
-       fontSize: 16,
-       resize: 'none',
-     }}
-     value={content}
-     onChange={(e) => setContent(e.target.value)}
-     onClick={(e) => {
-       const target = e.target as HTMLTextAreaElement;
-       setCursorPosition(target.selectionStart);
-       setShowCursorIndicator(true);
-     }}
-     onKeyUp={(e) => {
-       const target = e.target as HTMLTextAreaElement;
-       setCursorPosition(target.selectionStart);
-       setShowCursorIndicator(true);
-     }}
-     placeholder={t('writePlaceholder')}
-     autoSize={false}
-   />
- );

  // Reworked editor render function that adds the cursor marker
  const renderEditor = () => {
    let displayContent = content; // by default, show the main content state

    // While the AI is streaming, build the displayed content dynamically
    if (isStreaming && aiInsertionStartPosRef.current !== null) {
      // During live display, keep the <think> tag content
      displayContent =
        contentBeforeAiInsertionRef.current +
        currentStreamedAiOutput +
        contentAfterAiInsertionRef.current;
    } else if (showCursorIndicator && cursorPosition !== null) {
      // Not streaming, but a cursor has been set: show the insertion marker
      // (since INSERTION_MARKER is an empty string, this adds no visible marker)
      const beforeCursor = content.substring(0, cursorPosition);
      const afterCursor = content.substring(cursorPosition);
      displayContent = beforeCursor + INSERTION_MARKER + afterCursor;
    }

    return (
      <div style={{ position: 'relative', height: '100%', width: '100%' }}>
        <Input.TextArea
          ref={textAreaRef}
          style={{
            height: '100%',
            width: '100%',
            border: 'none',
            padding: 24,
            fontSize: 16,
            resize: 'none',
          }}
          value={displayContent} // use the dynamic displayContent
          onChange={(e) => {
            const currentInputValue = e.target.value; // full current content of the input box
            const newCursorSelectionStart = e.target.selectionStart;
            let finalContent = currentInputValue;
            let finalCursorPosition = newCursorSelectionStart;

            // If the user types while the AI is streaming, interrupt the AI output
            // and "freeze" the current content (removing the <think> tags)
            if (isStreaming) {
              stopOutputMessage(); // abort the SSE connection
              setIsStreaming(false); // stop streaming
              setIsAiLoading(false); // stop the loading state

              // currentInputValue already contains everything streamed so far
              // (including <think> tags); remove the <think> tags
              const contentWithoutThinkTags = currentInputValue.replace(
                /<think>.*?<\/think>/gs,
                '',
              );
              finalContent = contentWithoutThinkTags;

              // Recompute the cursor position, since removing the <think> tags may
              // have changed the content length. If the cursor sits inside or after
              // the removed <think> region, it has to move back by the removed length
              const originalLength = currentInputValue.length;
              const cleanedLength = finalContent.length;
              if (
                newCursorSelectionStart > (aiInsertionStartPosRef.current || 0)
              ) {
                // aiInsertionStartPosRef.current is taken as the start of the AI content
                finalCursorPosition =
                  newCursorSelectionStart - (originalLength - cleanedLength);
                // Make sure the cursor does not run past the end of the new content
                if (finalCursorPosition > cleanedLength) {
                  finalCursorPosition = cleanedLength;
                }
              } else {
                finalCursorPosition = newCursorSelectionStart; // cursor is before the AI insertion point, no adjustment needed
              }

              // Clean up streaming-related temporary state and refs
              setCurrentStreamedAiOutput('');
              contentBeforeAiInsertionRef.current = '';
              contentAfterAiInsertionRef.current = '';
              aiInsertionStartPosRef.current = null;
            }

            // If the content contains INSERTION_MARKER, remove it
            // (with INSERTION_MARKER being the empty string, this block has little effect)
            const markerIndex = finalContent.indexOf(INSERTION_MARKER); // check the already-processed finalContent
            if (markerIndex !== -1) {
              const contentWithoutMarker = finalContent.replace(
                INSERTION_MARKER,
                '',
              );
              finalContent = contentWithoutMarker;
              if (newCursorSelectionStart > markerIndex) {
                // newCursorSelectionStart here is still the original value, compared against markerIndex
                finalCursorPosition =
                  finalCursorPosition - INSERTION_MARKER.length;
              }
            }

            setContent(finalContent); // update the main content state
            setCursorPosition(finalCursorPosition); // update the cursor-position state
            // The DOM cursor is not set directly here: this runs right after setState,
            // before the DOM has updated; Ant Design's Input.TextArea positions the
            // cursor itself once value updates
            setShowCursorIndicator(true); // user input means a cursor position is set; keep showing the marker
          }}
          onClick={(e) => {
            const target = e.target as HTMLTextAreaElement;
            setCursorPosition(target.selectionStart);
            setShowCursorIndicator(true); // clicking sets the cursor position and shows the marker
            target.focus(); // make sure the textarea is focused right after the click
          }}
          onKeyUp={(e) => {
            const target = e.target as HTMLTextAreaElement;
            setCursorPosition(target.selectionStart);
            setShowCursorIndicator(true); // key-up sets the cursor position and shows the marker
          }}
          placeholder={t('writePlaceholder')}
          autoSize={false}
        />
      </div>
    );
  };
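
The cursor adjustment in onChange above is plain length arithmetic; a worked example with invented values:

// Suppose the streamed text contained a <think> block when the user typed,
// freezing the content. Removing the markup shortens the string:
const raw = 'Hello<think>plan</think>ok world'; // length 32
const cleaned = raw.replace(/<think>.*?<\/think>/gs, ''); // 'Hellook world', length 13
const removed = raw.length - cleaned.length; // 19

// A cursor at position 26 (just after 'ok') moves back by `removed`:
const adjusted = 26 - removed; // 7 -- still just after 'ok' in the cleaned text
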
const renderPreview = () => (
  <div
    style={{

@@ -712,6 +886,7 @@ const Write = () => {
    }}
  >
    <HightLightMarkdown>
      {/* In preview mode the <think> tags are normally not shown, so no special handling is needed here */}
      {content || t('previewPlaceholder')}
    </HightLightMarkdown>
  </div>

@@ -754,10 +929,10 @@ const Write = () => {
  return (
    <Layout
      style={{
-       height: 'calc(100vh - 80px)',
        display: 'flex',
        flexDirection: 'row',
        overflow: 'hidden',
        flexGrow: 1,
      }}
    >
      <Sider

@@ -773,7 +948,7 @@ const Write = () => {
      <div
        style={{
          padding: '16px 16px 0 16px',
-         height: '70%',
          height: '65%',
          minHeight: '250px',
          display: 'flex',
          flexDirection: 'column',

@@ -1023,11 +1198,6 @@ const Write = () => {
        }}
        style={{ flexShrink: 0 }}
      >
-       {isAiLoading && (
-         <div style={{ textAlign: 'center', marginBottom: 8 }}>
-           {t('aiLoadingMessage')}...
-         </div>
-       )}
        <Input.TextArea
          placeholder={t('askAI')}
          autoSize={{ minRows: 2, maxRows: 5 }}

@@ -1036,6 +1206,50 @@ const Write = () => {
          onKeyDown={handleAiQuestionSubmit}
          disabled={isAiLoading}
        />
        {/* Insertion-position hint, or a notice while the AI is answering -- now always visible */}
        {isStreaming ? ( // while the AI is answering, this notice takes priority
          <div
            style={{
              fontSize: '12px',
              color: '#faad14', // warning color
              padding: '6px 10px',
              backgroundColor: '#fffbe6',
              borderRadius: '4px',
              border: '1px solid #ffe58f',
            }}
          >
            AI正在生成回答...
          </div>
        ) : // the AI is not answering
        cursorPosition !== null ? ( // a cursor position has been set
          <div
            style={{
              fontSize: '12px',
              color: '#666',
              padding: '6px 10px',
              backgroundColor: '#e6f7ff',
              borderRadius: '4px',
              border: '1px solid #91d5ff',
            }}
          >
            💡 AI回答将插入到文档光标位置 ( {cursorPosition} )
          </div>
        ) : (
          // no cursor position has been set yet
          <div
            style={{
              fontSize: '12px',
              color: '#f5222d', // error color, to alert the user
              padding: '6px 10px',
              backgroundColor: '#fff1f0',
              borderRadius: '4px',
              border: '1px solid #ffccc7',
            }}
          >
            👆 AI内容插入位置
          </div>
        )}
      </Card>
    </Flex>
  </Content>
View File
@@ -100,6 +100,8 @@ export default {
  getExternalConversation: `${api_host}/api/conversation`,
  completeExternalConversation: `${api_host}/api/completion`,
  uploadAndParseExternal: `${api_host}/api/document/upload_and_parse`,
  // Q&A API for the document-writing mode
  writeChat: `${api_host}/conversation/writechat`,
  // file manager
  listFile: `${api_host}/file/list`,
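
For illustration, a client call against the new writeChat route might look like the sketch below. The payload fields mirror the sendMessage() call above; the exact request and response schema of the endpoint, and api_host resolving to /v1, are assumptions here:

// Hypothetical caller for the document-writing Q&A endpoint.
async function askWriteChat(question: string, kbIds: string[], dialogId: string) {
  const response = await fetch('/v1/conversation/writechat', {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      Authorization: localStorage.getItem('Authorization') ?? '',
    },
    body: JSON.stringify({
      question,
      kb_ids: kbIds,
      dialog_id: dialogId,
    }),
  });

  // Read the streamed response chunk by chunk (SSE-style "data:{...}" frames assumed)
  const reader = response.body!.getReader();
  const decoder = new TextDecoder();
  for (;;) {
    const { value, done } = await reader.read();
    if (done) break;
    console.log(decoder.decode(value, { stream: true }));
  }
}
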