Merge pull request #144 from zstar1003/dev
feat(write): 重构文档写作接口,实现问答接口解耦,并支持流式输出
This commit is contained in:
commit
e4a4786ca3
|
@ -30,13 +30,14 @@ from api.db.services.dialog_service import DialogService, chat, ask
|
|||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle, TenantService
|
||||
from api import settings
|
||||
from api.db.services.write_service import write_dialog
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
from rag.app.tag import label_question
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@manager.route("/set", methods=["POST"]) # type: ignore # noqa: F821
|
||||
@login_required
|
||||
def set_conversation():
|
||||
req = request.json
|
||||
|
@ -50,8 +51,7 @@ def set_conversation():
|
|||
return get_data_error_result(message="Conversation not found!")
|
||||
e, conv = ConversationService.get_by_id(conv_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Fail to update a conversation!")
|
||||
return get_data_error_result(message="Fail to update a conversation!")
|
||||
conv = conv.to_dict()
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
|
@ -61,38 +61,30 @@ def set_conversation():
|
|||
e, dia = DialogService.get_by_id(req["dialog_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found")
|
||||
conv = {
|
||||
"id": conv_id,
|
||||
"dialog_id": req["dialog_id"],
|
||||
"name": req.get("name", "New conversation"),
|
||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
|
||||
}
|
||||
conv = {"id": conv_id, "dialog_id": req["dialog_id"], "name": req.get("name", "New conversation"), "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]}
|
||||
ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/get', methods=['GET']) # noqa: F821
|
||||
@manager.route("/get", methods=["GET"]) # type: ignore # type: ignore # noqa: F821
|
||||
@login_required
|
||||
def get():
|
||||
conv_id = request.args["conversation_id"]
|
||||
try:
|
||||
|
||||
e, conv = ConversationService.get_by_id(conv_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
avatar =None
|
||||
avatar = None
|
||||
for tenant in tenants:
|
||||
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
|
||||
if dialog and len(dialog)>0:
|
||||
if dialog and len(dialog) > 0:
|
||||
avatar = dialog[0].icon
|
||||
break
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of conversation authorized for this operation.',
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
def get_value(d, k1, k2):
|
||||
return d.get(k1, d.get(k2))
|
||||
|
@ -100,7 +92,8 @@ def get():
|
|||
for ref in conv.reference:
|
||||
if isinstance(ref, list):
|
||||
continue
|
||||
ref["chunks"] = [{
|
||||
ref["chunks"] = [
|
||||
{
|
||||
"id": get_value(ck, "chunk_id", "id"),
|
||||
"content": get_value(ck, "content", "content_with_weight"),
|
||||
"document_id": get_value(ck, "doc_id", "document_id"),
|
||||
|
@ -108,18 +101,20 @@ def get():
|
|||
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
|
||||
"image_id": get_value(ck, "image_id", "img_id"),
|
||||
"positions": get_value(ck, "positions", "position_int"),
|
||||
} for ck in ref.get("chunks", [])]
|
||||
}
|
||||
for ck in ref.get("chunks", [])
|
||||
]
|
||||
|
||||
conv = conv.to_dict()
|
||||
conv["avatar"]=avatar
|
||||
conv["avatar"] = avatar
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/getsse/<dialog_id>', methods=['GET']) # type: ignore # noqa: F821
|
||||
def getsse(dialog_id):
|
||||
|
||||
token = request.headers.get('Authorization').split()
|
||||
@manager.route("/getsse/<dialog_id>", methods=["GET"]) # type: ignore # noqa: F821
|
||||
def getsse(dialog_id):
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_data_error_result(message='Authorization is not valid!"')
|
||||
token = token[1]
|
||||
|
@ -131,13 +126,14 @@ def getsse(dialog_id):
|
|||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
conv = conv.to_dict()
|
||||
conv["avatar"]= conv["icon"]
|
||||
conv["avatar"] = conv["icon"]
|
||||
del conv["icon"]
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
|
||||
@manager.route("/rm", methods=["POST"]) # type: ignore # type: ignore # noqa: F821
|
||||
@login_required
|
||||
def rm():
|
||||
conv_ids = request.json["conversation_ids"]
|
||||
|
@ -151,28 +147,21 @@ def rm():
|
|||
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of conversation authorized for this operation.',
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
ConversationService.delete_by_id(cid)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@manager.route("/list", methods=["GET"]) # type: ignore # noqa: F821
|
||||
@login_required
|
||||
def list_convsersation():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
try:
|
||||
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of dialog authorized for this operation.',
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
convs = ConversationService.query(
|
||||
dialog_id=dialog_id,
|
||||
order_by=ConversationService.model.create_time,
|
||||
reverse=True)
|
||||
return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
|
||||
|
||||
convs = [d.to_dict() for d in convs]
|
||||
return get_json_result(data=convs)
|
||||
|
@ -180,7 +169,7 @@ def list_convsersation():
|
|||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@manager.route("/completion", methods=["POST"]) # type: ignore # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
|
@ -207,13 +196,15 @@ def completion():
|
|||
if not conv.reference:
|
||||
conv.reference = []
|
||||
else:
|
||||
|
||||
def get_value(d, k1, k2):
|
||||
return d.get(k1, d.get(k2))
|
||||
|
||||
for ref in conv.reference:
|
||||
if isinstance(ref, list):
|
||||
continue
|
||||
ref["chunks"] = [{
|
||||
ref["chunks"] = [
|
||||
{
|
||||
"id": get_value(ck, "chunk_id", "id"),
|
||||
"content": get_value(ck, "content", "content_with_weight"),
|
||||
"document_id": get_value(ck, "doc_id", "document_id"),
|
||||
|
@ -221,11 +212,14 @@ def completion():
|
|||
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
|
||||
"image_id": get_value(ck, "image_id", "img_id"),
|
||||
"positions": get_value(ck, "positions", "position_int"),
|
||||
} for ck in ref.get("chunks", [])]
|
||||
}
|
||||
for ck in ref.get("chunks", [])
|
||||
]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
|
@ -235,9 +229,7 @@ def completion():
|
|||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
if req.get("stream", True):
|
||||
|
@ -259,7 +251,32 @@ def completion():
|
|||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/tts', methods=['POST']) # noqa: F821
|
||||
# 用于文档撰写模式的问答调用
|
||||
@manager.route("/writechat", methods=["POST"]) # type: ignore # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def writechat():
|
||||
req = request.json
|
||||
uid = current_user.id
|
||||
|
||||
def stream():
|
||||
nonlocal req, uid
|
||||
try:
|
||||
for ans in write_dialog(req["question"], req["kb_ids"], uid):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route("/tts", methods=["POST"]) # type: ignore # noqa: F821
|
||||
@login_required
|
||||
def tts():
|
||||
req = request.json
|
||||
|
@ -281,9 +298,7 @@ def tts():
|
|||
for chunk in tts_mdl.tts(txt):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e)}},
|
||||
ensure_ascii=False)).encode('utf-8')
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
|
||||
|
||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
|
@ -293,7 +308,7 @@ def tts():
|
|||
return resp
|
||||
|
||||
|
||||
@manager.route('/delete_msg', methods=['POST']) # noqa: F821
|
||||
@manager.route("/delete_msg", methods=["POST"]) # type: ignore # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def delete_msg():
|
||||
|
@ -316,7 +331,7 @@ def delete_msg():
|
|||
return get_json_result(data=conv)
|
||||
|
||||
|
||||
@manager.route('/thumbup', methods=['POST']) # noqa: F821
|
||||
@manager.route("/thumbup", methods=["POST"]) # type: ignore # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def thumbup():
|
||||
|
@ -343,7 +358,7 @@ def thumbup():
|
|||
return get_json_result(data=conv)
|
||||
|
||||
|
||||
@manager.route('/ask', methods=['POST']) # noqa: F821
|
||||
@manager.route("/ask", methods=["POST"]) # type: ignore # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def ask_about():
|
||||
|
@ -356,9 +371,7 @@ def ask_about():
|
|||
for ans in ask(req["question"], req["kb_ids"], uid):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
|
@ -369,7 +382,7 @@ def ask_about():
|
|||
return resp
|
||||
|
||||
|
||||
@manager.route('/mindmap', methods=['POST']) # noqa: F821
|
||||
@manager.route("/mindmap", methods=["POST"]) # type: ignore # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def mindmap():
|
||||
|
@ -382,10 +395,7 @@ def mindmap():
|
|||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
|
||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
|
||||
question = req["question"]
|
||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12,
|
||||
0.3, 0.3, aggs=False,
|
||||
rank_feature=label_question(question, [kb])
|
||||
)
|
||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, 0.3, 0.3, aggs=False, rank_feature=label_question(question, [kb]))
|
||||
mindmap = MindMapExtractor(chat_mdl)
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
||||
mind_map = mind_map.output
|
||||
|
@ -394,7 +404,7 @@ def mindmap():
|
|||
return get_json_result(data=mind_map)
|
||||
|
||||
|
||||
@manager.route('/related_questions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/related_questions", methods=["POST"]) # type: ignore # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question")
|
||||
def related_questions():
|
||||
|
@ -425,8 +435,17 @@ Reason:
|
|||
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
||||
|
||||
"""
|
||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
|
||||
ans = chat_mdl.chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Keywords: {question}
|
||||
Related search terms:
|
||||
"""}], {"temperature": 0.9})
|
||||
""",
|
||||
}
|
||||
],
|
||||
{"temperature": 0.9},
|
||||
)
|
||||
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
|
||||
|
|
|
@ -30,8 +30,7 @@ from api import settings
|
|||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp.search import index_name
|
||||
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, \
|
||||
citation_prompt
|
||||
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, citation_prompt
|
||||
from rag.utils import rmSpace, num_tokens_from_string
|
||||
from rag.utils.tavily_conn import Tavily
|
||||
|
||||
|
@ -41,17 +40,13 @@ class DialogService(CommonService):
|
|||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, tenant_id,
|
||||
page_number, items_per_page, orderby, desc, id, name):
|
||||
def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
|
||||
chats = cls.model.select()
|
||||
if id:
|
||||
chats = chats.where(cls.model.id == id)
|
||||
if name:
|
||||
chats = chats.where(cls.model.name == name)
|
||||
chats = chats.where(
|
||||
(cls.model.tenant_id == tenant_id)
|
||||
& (cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
|
||||
if desc:
|
||||
chats = chats.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
|
@ -72,13 +67,12 @@ def chat_solo(dialog, messages, stream=True):
|
|||
tts_mdl = None
|
||||
if prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
|
||||
for m in messages if m["role"] != "system"]
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
|
||||
if stream:
|
||||
last_ans = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
delta_ans = ans[len(last_ans) :]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
|
@ -159,8 +153,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
if p["key"] not in kwargs and not p["optional"]:
|
||||
raise KeyError("Miss parameter: " + p["key"])
|
||||
if p["key"] not in kwargs:
|
||||
prompt_config["system"] = prompt_config["system"].replace(
|
||||
"{%s}" % p["key"], " ")
|
||||
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
||||
|
||||
# 不再使用多轮对话优化
|
||||
# if len(questions) > 1 and prompt_config.get("refine_multiturn"):
|
||||
|
@ -226,17 +219,24 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
# kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
# knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
||||
kbinfos = retriever.retrieval(
|
||||
" ".join(questions),
|
||||
embd_mdl,
|
||||
tenant_ids,
|
||||
dialog.kb_ids,
|
||||
1,
|
||||
dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
doc_ids=attachments,
|
||||
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(" ".join(questions), kbs)
|
||||
top=dialog.top_k,
|
||||
aggs=False,
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(" ".join(questions), kbs),
|
||||
)
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
|
||||
logging.debug(
|
||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
|
||||
retrieval_ts = timer()
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
|
@ -252,16 +252,13 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
prompt4citation = citation_prompt()
|
||||
# 过滤掉 system 角色的消息(因为前面已经单独处理了系统消息)
|
||||
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
|
||||
for m in messages if m["role"] != "system"])
|
||||
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
|
||||
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
|
||||
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
||||
prompt = msg[0]["content"]
|
||||
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_tokens"] = min(
|
||||
gen_conf["max_tokens"],
|
||||
max_tokens - used_token_count)
|
||||
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
|
||||
|
||||
def decorate_answer(answer):
|
||||
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
|
||||
|
@ -277,14 +274,14 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
|
||||
cited_chunk_indices = set() # 用于存储被引用的 chunk 索引
|
||||
if not re.search(r"##[0-9]+\$\$", answer):
|
||||
answer, idx = retriever.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
answer, idx = retriever.insert_citations(
|
||||
answer,
|
||||
[ck["content_ltks"] for ck in kbinfos["chunks"]],
|
||||
[ck["vector"] for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
vtweight=dialog.vector_similarity_weight,
|
||||
)
|
||||
cited_chunk_indices = idx # 获取 insert_citations 返回的索引
|
||||
|
||||
else:
|
||||
|
@ -315,8 +312,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
processed_image_urls.add(img_url) # 标记为已处理
|
||||
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [
|
||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs:
|
||||
recall_docs = kbinfos["doc_aggs"]
|
||||
kbinfos["doc_aggs"] = recall_docs
|
||||
|
@ -347,30 +343,30 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
|
||||
prompt += "\n\n### Query:\n%s" % " ".join(questions)
|
||||
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
|
||||
return {"answer": think+answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
|
||||
return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
|
||||
|
||||
if stream:
|
||||
last_ans = "" # 记录上一次返回的完整回答
|
||||
answer = "" # 当前累计的完整回答
|
||||
for ans in chat_mdl.chat_streamly(prompt+prompt4citation, msg[1:], gen_conf):
|
||||
for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
||||
# 如果存在思考过程(thought),移除相关标记
|
||||
if thought:
|
||||
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||
answer = ans
|
||||
# 计算新增的文本片段(delta)
|
||||
delta_ans = ans[len(last_ans):]
|
||||
delta_ans = ans[len(last_ans) :]
|
||||
# 如果新增token太少(小于16),跳过本次返回(避免频繁发送小片段)
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
# 返回当前累计回答(包含思考过程)+新增片段)
|
||||
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
delta_ans = answer[len(last_ans):]
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
delta_ans = answer[len(last_ans) :]
|
||||
if delta_ans:
|
||||
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield decorate_answer(thought+answer)
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield decorate_answer(thought + answer)
|
||||
else:
|
||||
answer = chat_mdl.chat(prompt+prompt4citation, msg[1:], gen_conf)
|
||||
answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||
res = decorate_answer(answer)
|
||||
|
@ -388,27 +384,22 @@ Table of database fields are as follows:
|
|||
Question are as follows:
|
||||
{}
|
||||
Please write the SQL, only SQL, without any other explanations or text.
|
||||
""".format(
|
||||
index_name(tenant_id),
|
||||
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
|
||||
question
|
||||
)
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
|
||||
tried_times = 0
|
||||
|
||||
def get_table():
|
||||
nonlocal sys_prompt, user_prompt, question, tried_times
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
|
||||
"temperature": 0.06})
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
|
||||
sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
|
||||
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
sql = re.sub(r".*select ", "select ", sql.lower())
|
||||
sql = re.sub(r" +", " ", sql)
|
||||
sql = re.sub(r"([;;]|```).*", "", sql)
|
||||
if sql[:len("select ")] != "select ":
|
||||
if sql[: len("select ")] != "select ":
|
||||
return None, None
|
||||
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
|
||||
if sql[:len("select *")] != "select *":
|
||||
if sql[: len("select *")] != "select *":
|
||||
sql = "select doc_id,docnm_kwd," + sql[6:]
|
||||
else:
|
||||
flds = []
|
||||
|
@ -445,11 +436,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
|||
{}
|
||||
|
||||
Please correct the error and write SQL again, only SQL, without any other explanations or text.
|
||||
""".format(
|
||||
index_name(tenant_id),
|
||||
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
|
||||
question, sql, tbl["error"]
|
||||
)
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
|
||||
tbl, sql = get_table()
|
||||
logging.debug("TRY it again: {}".format(sql))
|
||||
|
||||
|
@ -457,24 +444,18 @@ Please write the SQL, only SQL, without any other explanations or text.
|
|||
if tbl.get("error") or len(tbl["rows"]) == 0:
|
||||
return None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(
|
||||
tbl["columns"]) if c["name"] == "doc_id"])
|
||||
doc_name_idx = set([ii for ii, c in enumerate(
|
||||
tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
column_idx = [ii for ii in range(
|
||||
len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
|
||||
|
||||
# compose Markdown table
|
||||
columns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
||||
tbl["columns"][i]["name"])) for i in
|
||||
column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
columns = (
|
||||
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
)
|
||||
|
||||
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
|
||||
("|------|" if docid_idx and docid_idx else "")
|
||||
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||
|
||||
rows = ["|" +
|
||||
"|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
|
||||
"|" for r in tbl["rows"]]
|
||||
rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
|
||||
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
|
||||
if quota:
|
||||
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||
|
@ -484,11 +465,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
|||
|
||||
if not docid_idx or not doc_name_idx:
|
||||
logging.warning("SQL missing field: " + sql)
|
||||
return {
|
||||
"answer": "\n".join([columns, line, rows]),
|
||||
"reference": {"chunks": [], "doc_aggs": []},
|
||||
"prompt": sys_prompt
|
||||
}
|
||||
return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
|
||||
|
||||
docid_idx = list(docid_idx)[0]
|
||||
doc_name_idx = list(doc_name_idx)[0]
|
||||
|
@ -499,10 +476,11 @@ Please write the SQL, only SQL, without any other explanations or text.
|
|||
doc_aggs[r[docid_idx]]["count"] += 1
|
||||
return {
|
||||
"answer": "\n".join([columns, line, rows]),
|
||||
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
|
||||
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
|
||||
doc_aggs.items()]},
|
||||
"prompt": sys_prompt
|
||||
"reference": {
|
||||
"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
|
||||
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
|
||||
},
|
||||
"prompt": sys_prompt,
|
||||
}
|
||||
|
||||
|
||||
|
@ -553,26 +531,23 @@ def ask(question, kb_ids, tenant_id):
|
|||
# 获取所有知识库的租户ID并去重
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
# 调用检索器检索相关文档片段
|
||||
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
|
||||
1, 12, 0.1, 0.3, aggs=False,
|
||||
rank_feature=label_question(question, kbs)
|
||||
)
|
||||
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
|
||||
# 将检索结果格式化为提示词,并确保不超过模型最大token限制
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
prompt = """
|
||||
Role: You're a smart assistant. Your name is Miss R.
|
||||
Task: Summarize the information from knowledge bases and answer user's question.
|
||||
Requirements and restriction:
|
||||
- DO NOT make things up, especially for numbers.
|
||||
- If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
|
||||
- Answer with markdown format text.
|
||||
- Answer in language of user's question.
|
||||
- DO NOT make things up, especially for numbers.
|
||||
角色:你是一个聪明的助手。
|
||||
任务:总结知识库中的信息并回答用户的问题。
|
||||
要求与限制:
|
||||
- 绝不要捏造内容,尤其是数字。
|
||||
- 如果知识库中的信息与用户问题无关,**只需回答:对不起,未提供相关信息。
|
||||
- 使用Markdown格式进行回答。
|
||||
- 使用用户提问所用的语言作答。
|
||||
- 绝不要捏造内容,尤其是数字。
|
||||
|
||||
### Information from knowledge bases
|
||||
### 来自知识库的信息
|
||||
%s
|
||||
|
||||
The above is information from knowledge bases.
|
||||
以上是来自知识库的信息。
|
||||
|
||||
""" % "\n".join(knowledges)
|
||||
msg = [{"role": "user", "content": question}]
|
||||
|
@ -580,17 +555,9 @@ def ask(question, kb_ids, tenant_id):
|
|||
# 生成完成后添加回答中的引用标记
|
||||
def decorate_answer(answer):
|
||||
nonlocal knowledges, kbinfos, prompt
|
||||
answer, idx = retriever.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=0.7,
|
||||
vtweight=0.3)
|
||||
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [
|
||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs:
|
||||
recall_docs = kbinfos["doc_aggs"]
|
||||
kbinfos["doc_aggs"] = recall_docs
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
from api.db import LLMType, ParserType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api import settings
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts import kb_prompt
|
||||
|
||||
|
||||
def write_dialog(question, kb_ids, tenant_id):
|
||||
"""
|
||||
处理用户搜索请求,从知识库中检索相关信息并生成回答
|
||||
|
||||
参数:
|
||||
question (str): 用户的问题或查询
|
||||
kb_ids (list): 知识库ID列表,指定要搜索的知识库
|
||||
tenant_id (str): 租户ID,用于权限控制和资源隔离
|
||||
|
||||
流程:
|
||||
1. 获取指定知识库的信息
|
||||
2. 确定使用的嵌入模型
|
||||
3. 根据知识库类型选择检索器(普通检索器或知识图谱检索器)
|
||||
4. 初始化嵌入模型和聊天模型
|
||||
5. 执行检索操作获取相关文档片段
|
||||
6. 格式化知识库内容作为上下文
|
||||
7. 构建系统提示词
|
||||
8. 生成回答并添加引用标记
|
||||
9. 流式返回生成的回答
|
||||
|
||||
返回:
|
||||
generator: 生成器对象,产生包含回答和引用信息的字典
|
||||
"""
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
|
||||
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
|
||||
# 初始化嵌入模型,用于将文本转换为向量表示
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
# 初始化聊天模型,用于生成回答
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
||||
# 获取聊天模型的最大token长度,用于控制上下文长度
|
||||
max_tokens = chat_mdl.max_length
|
||||
# 获取所有知识库的租户ID并去重
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
# 调用检索器检索相关文档片段
|
||||
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
|
||||
# 将检索结果格式化为提示词,并确保不超过模型最大token限制
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
prompt = """
|
||||
角色:你是一个聪明的助手。
|
||||
任务:总结知识库中的信息并回答用户的问题。
|
||||
要求与限制:
|
||||
- 绝不要捏造内容,尤其是数字。
|
||||
- 如果知识库中的信息与用户问题无关,**只需回答:对不起,未提供相关信息。
|
||||
- 使用Markdown格式进行回答。
|
||||
- 使用用户提问所用的语言作答。
|
||||
- 绝不要捏造内容,尤其是数字。
|
||||
|
||||
### 来自知识库的信息
|
||||
%s
|
||||
|
||||
以上是来自知识库的信息。
|
||||
|
||||
""" % "\n".join(knowledges)
|
||||
msg = [{"role": "user", "content": question}]
|
||||
|
||||
# 生成完成后添加回答中的引用标记
|
||||
# def decorate_answer(answer):
|
||||
# nonlocal knowledges, kbinfos, prompt
|
||||
# answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
|
||||
# idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
# recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
# if not recall_docs:
|
||||
# recall_docs = kbinfos["doc_aggs"]
|
||||
# kbinfos["doc_aggs"] = recall_docs
|
||||
# refs = deepcopy(kbinfos)
|
||||
# for c in refs["chunks"]:
|
||||
# if c.get("vector"):
|
||||
# del c["vector"]
|
||||
|
||||
# if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||||
# answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
||||
# refs["chunks"] = chunks_format(refs)
|
||||
# return {"answer": answer, "reference": refs}
|
||||
|
||||
answer = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
|
||||
answer = ans
|
||||
yield {"answer": answer, "reference": {}}
|
||||
# yield decorate_answer(answer)
|
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg fill="#000000" width="800px" height="800px" viewBox="0 0 36 36" version="1.1" preserveAspectRatio="xMidYMid meet" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>storage-solid</title>
|
||||
<path class="clr-i-solid clr-i-solid-path-1" d="M17.91,18.28c8.08,0,14.66-1.74,15.09-3.94V8.59c-.43,2.2-7,3.94-15.09,3.94A39.4,39.4,0,0,1,6.25,11V9a39.4,39.4,0,0,0,11.66,1.51C26,10.53,32.52,8.79,33,6.61h0C32.8,3.2,23.52,2.28,18,2.28S3,3.21,3,6.71V29.29c0,3.49,9.43,4.43,15,4.43s15-.93,15-4.43V24.09C32.57,26.28,26,28,17.91,28A39.4,39.4,0,0,1,6.25,26.52v-2A39.4,39.4,0,0,0,17.91,26C26,26,32.57,24.28,33,22.09V16.34c-.43,2.2-7,3.94-15.09,3.94A39.4,39.4,0,0,1,6.25,18.77v-2A39.4,39.4,0,0,0,17.91,18.28Z"></path>
|
||||
<rect x="0" y="0" width="36" height="36" fill-opacity="0"/>
|
||||
</svg>
|
After Width: | Height: | Size: 934 B |
|
@ -0,0 +1,123 @@
|
|||
import { Authorization } from '@/constants/authorization';
|
||||
import { IAnswer } from '@/interfaces/database/chat';
|
||||
import { IKnowledge } from '@/interfaces/database/knowledge';
|
||||
import kbService from '@/services/knowledge-service';
|
||||
import api from '@/utils/api';
|
||||
import { getAuthorization } from '@/utils/authorization-util';
|
||||
import { useQuery } from '@tanstack/react-query';
|
||||
import { EventSourceParserStream } from 'eventsource-parser/stream';
|
||||
import { useCallback, useRef, useState } from 'react';
|
||||
|
||||
// 查询知识库数据
|
||||
export const useFetchKnowledgeList = (
|
||||
shouldFilterListWithoutDocument: boolean = false,
|
||||
): {
|
||||
list: IKnowledge[];
|
||||
loading: boolean;
|
||||
} => {
|
||||
const { data, isFetching: loading } = useQuery({
|
||||
queryKey: ['fetchKnowledgeList'],
|
||||
initialData: [],
|
||||
gcTime: 0,
|
||||
queryFn: async () => {
|
||||
const { data } = await kbService.getList();
|
||||
const list = data?.data?.kbs ?? [];
|
||||
return shouldFilterListWithoutDocument
|
||||
? list.filter((x: IKnowledge) => x.chunk_num > 0)
|
||||
: list;
|
||||
},
|
||||
});
|
||||
|
||||
return { list: data, loading };
|
||||
};
|
||||
|
||||
// 发送问答信息
|
||||
export const useSendMessageWithSse = (url: string = api.writeChat) => {
|
||||
const [answer, setAnswer] = useState<IAnswer>({} as IAnswer);
|
||||
const [done, setDone] = useState(true);
|
||||
const timer = useRef<any>();
|
||||
const sseRef = useRef<AbortController>();
|
||||
|
||||
const initializeSseRef = useCallback(() => {
|
||||
sseRef.current = new AbortController();
|
||||
}, []);
|
||||
|
||||
const resetAnswer = useCallback(() => {
|
||||
if (timer.current) {
|
||||
clearTimeout(timer.current);
|
||||
}
|
||||
timer.current = setTimeout(() => {
|
||||
setAnswer({} as IAnswer);
|
||||
clearTimeout(timer.current);
|
||||
}, 1000);
|
||||
}, []);
|
||||
|
||||
const send = useCallback(
|
||||
async (
|
||||
body: any,
|
||||
controller?: AbortController,
|
||||
): Promise<{ response: Response; data: ResponseType } | undefined> => {
|
||||
initializeSseRef();
|
||||
try {
|
||||
setDone(false);
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
[Authorization]: getAuthorization(),
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
signal: controller?.signal || sseRef.current?.signal,
|
||||
});
|
||||
|
||||
const res = response.clone().json();
|
||||
|
||||
const reader = response?.body
|
||||
?.pipeThrough(new TextDecoderStream())
|
||||
.pipeThrough(new EventSourceParserStream())
|
||||
.getReader();
|
||||
|
||||
while (true) {
|
||||
const x = await reader?.read();
|
||||
if (x) {
|
||||
const { done, value } = x;
|
||||
if (done) {
|
||||
console.info('done');
|
||||
resetAnswer();
|
||||
break;
|
||||
}
|
||||
try {
|
||||
const val = JSON.parse(value?.data || '');
|
||||
const d = val?.data;
|
||||
if (typeof d !== 'boolean') {
|
||||
console.info('data:', d);
|
||||
setAnswer({
|
||||
...d,
|
||||
conversationId: body?.conversation_id,
|
||||
});
|
||||
}
|
||||
} catch (e) {
|
||||
console.warn(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
console.info('done?');
|
||||
setDone(true);
|
||||
resetAnswer();
|
||||
return { data: await res, response };
|
||||
} catch (e) {
|
||||
setDone(true);
|
||||
resetAnswer();
|
||||
|
||||
console.warn(e);
|
||||
}
|
||||
},
|
||||
[initializeSseRef, url, resetAnswer],
|
||||
);
|
||||
|
||||
const stopOutputMessage = useCallback(() => {
|
||||
sseRef.current?.abort();
|
||||
}, []);
|
||||
|
||||
return { send, answer, done, setDone, resetAnswer, stopOutputMessage };
|
||||
};
|
|
@ -591,6 +591,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
|||
decisions: '决定事项',
|
||||
actionItems: '行动项',
|
||||
nextMeeting: '下次会议',
|
||||
noTemplatesAvailable: "没有可用模板",
|
||||
// 模型配置相关
|
||||
modelConfigurationTitle: "模型配置",
|
||||
knowledgeBaseLabel: "知识库",
|
||||
|
@ -601,6 +602,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
|||
fetchKnowledgeBaseFailed: "获取知识库列表失败",
|
||||
defaultKnowledgeBase: "默认知识库",
|
||||
technicalDocsKnowledgeBase: "技术文档知识库",
|
||||
aiRequestFailedError: "问答模型请求失败",
|
||||
},
|
||||
setting: {
|
||||
profile: '概要',
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import HightLightMarkdown from '@/components/highlight-markdown';
|
||||
import { useTranslate } from '@/hooks/common-hooks';
|
||||
// 假设 aiAssistantConfig 在实际项目中是正确导入的
|
||||
// import { aiAssistantConfig } from '@/pages/write/ai-assistant-config';
|
||||
const aiAssistantConfig = { api: { timeout: 30000 } }; // 模拟定义
|
||||
import {
|
||||
useFetchKnowledgeList,
|
||||
useSendMessageWithSse,
|
||||
} from '@/hooks/write-hooks';
|
||||
|
||||
import { DeleteOutlined } from '@ant-design/icons';
|
||||
import {
|
||||
|
@ -22,7 +23,6 @@ import {
|
|||
Space,
|
||||
Typography,
|
||||
} from 'antd';
|
||||
import axios from 'axios';
|
||||
import {
|
||||
AlignmentType,
|
||||
Document,
|
||||
|
@ -32,7 +32,7 @@ import {
|
|||
TextRun,
|
||||
} from 'docx';
|
||||
import { saveAs } from 'file-saver';
|
||||
import { marked, Token, Tokens } from 'marked'; // 从 marked 导入 Token 和 Tokens 类型
|
||||
import { marked, Token, Tokens } from 'marked';
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
|
||||
const { Sider, Content } = Layout;
|
||||
|
@ -53,22 +53,28 @@ interface KnowledgeBaseItem {
|
|||
name: string;
|
||||
}
|
||||
|
||||
// 使用 marked 导出的类型或更精确的自定义类型
|
||||
type MarkedHeadingToken = Tokens.Heading;
|
||||
type MarkedParagraphToken = Tokens.Paragraph;
|
||||
type MarkedListItem = Tokens.ListItem;
|
||||
type MarkedListToken = Tokens.List;
|
||||
type MarkedSpaceToken = Tokens.Space;
|
||||
|
||||
// 定义插入点标记,以便在onChange时识别并移除
|
||||
// const INSERTION_MARKER = '【AI内容将插入此处】';
|
||||
const INSERTION_MARKER = ''; // 保持为空字符串,不显示实际标记
|
||||
|
||||
const Write = () => {
|
||||
const { t } = useTranslate('write');
|
||||
const [content, setContent] = useState('');
|
||||
const [aiQuestion, setAiQuestion] = useState('');
|
||||
const [isAiLoading, setIsAiLoading] = useState(false);
|
||||
const [dialogId, setDialogId] = useState('');
|
||||
const [dialogId] = useState('');
|
||||
// cursorPosition 存储用户点击设定的插入点位置
|
||||
const [cursorPosition, setCursorPosition] = useState<number | null>(null);
|
||||
// showCursorIndicator 现在仅用于控制文档中是否显示 'INSERTION_MARKER',
|
||||
// 并且一旦设置了光标位置,就希望它保持为 true,除非内容被清空或主动重置。
|
||||
const [showCursorIndicator, setShowCursorIndicator] = useState(false);
|
||||
const textAreaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const textAreaRef = useRef<any>(null); // Ant Design Input.TextArea 的 ref 类型
|
||||
|
||||
const [templates, setTemplates] = useState<TemplateItem[]>([]);
|
||||
const [isTemplateModalVisible, setIsTemplateModalVisible] = useState(false);
|
||||
|
@ -87,6 +93,27 @@ const Write = () => {
|
|||
const [modelTemperature, setModelTemperature] = useState<number>(0.7);
|
||||
const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBaseItem[]>([]);
|
||||
const [isLoadingKbs, setIsLoadingKbs] = useState(false);
|
||||
const [isStreaming, setIsStreaming] = useState(false); // 标记AI是否正在流式输出
|
||||
|
||||
// 新增状态和 useRef,用于流式输出管理
|
||||
// currentStreamedAiOutput 现在将直接接收 useSendMessageWithSse 返回的累积内容
|
||||
const [currentStreamedAiOutput, setCurrentStreamedAiOutput] = useState('');
|
||||
// 使用 useRef 存储 AI 插入点前后的内容,以及插入点位置,避免在流式更新中出现闭包陷阱
|
||||
const contentBeforeAiInsertionRef = useRef('');
|
||||
const contentAfterAiInsertionRef = useRef('');
|
||||
const aiInsertionStartPosRef = useRef<number | null>(null);
|
||||
|
||||
// 使用 useFetchKnowledgeList hook 获取真实数据
|
||||
const { list: knowledgeList, loading: isLoadingKnowledgeList } =
|
||||
useFetchKnowledgeList(true);
|
||||
|
||||
// 使用流式消息发送钩子
|
||||
const {
|
||||
send: sendMessage,
|
||||
answer,
|
||||
done,
|
||||
stopOutputMessage,
|
||||
} = useSendMessageWithSse();
|
||||
|
||||
const getInitialDefaultTemplateDefinitions = useCallback(
|
||||
(): TemplateItem[] => [
|
||||
|
@ -171,70 +198,102 @@ const Write = () => {
|
|||
loadOrInitializeTemplates();
|
||||
}, [loadOrInitializeTemplates]);
|
||||
|
||||
// 将 knowledgeList 数据同步到 knowledgeBases 状态
|
||||
useEffect(() => {
|
||||
const fetchKbs = async () => {
|
||||
const authorization = localStorage.getItem('Authorization');
|
||||
if (!authorization) {
|
||||
setKnowledgeBases([]);
|
||||
return;
|
||||
if (knowledgeList && knowledgeList.length > 0) {
|
||||
setKnowledgeBases(
|
||||
knowledgeList.map((kb) => ({
|
||||
id: kb.id,
|
||||
name: kb.name,
|
||||
})),
|
||||
);
|
||||
setIsLoadingKbs(isLoadingKnowledgeList);
|
||||
}
|
||||
setIsLoadingKbs(true);
|
||||
try {
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, 500);
|
||||
}, [knowledgeList, isLoadingKnowledgeList]);
|
||||
|
||||
// --- 调整流式响应处理逻辑 ---
|
||||
// 阶段1: 累积 AI 输出片段,用于实时显示(包括 <think> 标签)
|
||||
// 这个 useEffect 确保 currentStreamedAiOutput 始终是实时更新的、包含 <think> 标签的完整内容
|
||||
useEffect(() => {
|
||||
if (isStreaming && answer && answer.answer) {
|
||||
setCurrentStreamedAiOutput(answer.answer);
|
||||
}
|
||||
}, [isStreaming, answer]);
|
||||
|
||||
// 阶段2: 当流式输出完成时 (done 为 true)
|
||||
// 这个 useEffect 负责在流式输出结束时执行清理和最终内容更新
|
||||
useEffect(() => {
|
||||
if (done) {
|
||||
setIsStreaming(false);
|
||||
setIsAiLoading(false);
|
||||
|
||||
// --- Process the final streamed AI output before committing ---
|
||||
// 关键修改:这里**必须**使用 currentStreamedAiOutput,因为它是在流式过程中积累的、包含 <think> 标签的内容
|
||||
// answer.answer 可能在 done 阶段已经提前被钩子内部清理过,所以不能依赖它来获取带标签的原始内容。
|
||||
let processedAiOutput = currentStreamedAiOutput;
|
||||
if (processedAiOutput) {
|
||||
// Regex to remove <think>...</think> including content
|
||||
processedAiOutput = processedAiOutput.replace(
|
||||
/<think>.*?<\/think>/gs,
|
||||
'',
|
||||
);
|
||||
}
|
||||
// --- END NEW ---
|
||||
|
||||
// 将最终累积的AI内容(已处理移除<think>标签)和初始文档内容拼接,更新到主内容状态
|
||||
setContent((prevContent) => {
|
||||
if (aiInsertionStartPosRef.current !== null) {
|
||||
// 使用 useRef 中存储的初始内容和最终处理过的 AI 输出
|
||||
const finalContent =
|
||||
contentBeforeAiInsertionRef.current +
|
||||
processedAiOutput +
|
||||
contentAfterAiInsertionRef.current;
|
||||
return finalContent;
|
||||
}
|
||||
return prevContent;
|
||||
});
|
||||
const mockKbs: KnowledgeBaseItem[] = [
|
||||
{
|
||||
id: 'kb_default',
|
||||
name: t('defaultKnowledgeBase', { defaultValue: '默认知识库' }),
|
||||
},
|
||||
{
|
||||
id: 'kb_tech',
|
||||
name: t('technicalDocsKnowledgeBase', {
|
||||
defaultValue: '技术文档知识库',
|
||||
}),
|
||||
},
|
||||
{
|
||||
id: 'kb_product',
|
||||
name: t('productInfoKnowledgeBase', {
|
||||
defaultValue: '产品信息知识库',
|
||||
}),
|
||||
},
|
||||
{
|
||||
id: 'kb_marketing',
|
||||
name: t('marketingMaterialsKB', { defaultValue: '市场营销材料库' }),
|
||||
},
|
||||
{
|
||||
id: 'kb_legal',
|
||||
name: t('legalDocumentsKB', { defaultValue: '法律文件库' }),
|
||||
},
|
||||
];
|
||||
setKnowledgeBases(mockKbs);
|
||||
} catch (error) {
|
||||
console.error('获取知识库失败:', error);
|
||||
message.error(t('fetchKnowledgeBaseFailed'));
|
||||
setKnowledgeBases([]);
|
||||
} finally {
|
||||
setIsLoadingKbs(false);
|
||||
|
||||
// AI完成回答后,将光标实际移到新内容末尾
|
||||
if (
|
||||
textAreaRef.current?.resizableTextArea?.textArea &&
|
||||
aiInsertionStartPosRef.current !== null
|
||||
) {
|
||||
const newCursorPos =
|
||||
aiInsertionStartPosRef.current + processedAiOutput.length;
|
||||
textAreaRef.current.resizableTextArea.textArea.selectionStart =
|
||||
newCursorPos;
|
||||
textAreaRef.current.resizableTextArea.textArea.selectionEnd =
|
||||
newCursorPos;
|
||||
textAreaRef.current.resizableTextArea.textArea.focus();
|
||||
setCursorPosition(newCursorPos);
|
||||
}
|
||||
};
|
||||
fetchKbs();
|
||||
}, [t]);
|
||||
|
||||
// 清理流式相关的临时状态和 useRef
|
||||
setCurrentStreamedAiOutput(''); // 清空累积内容
|
||||
contentBeforeAiInsertionRef.current = '';
|
||||
contentAfterAiInsertionRef.current = '';
|
||||
aiInsertionStartPosRef.current = null;
|
||||
setShowCursorIndicator(true);
|
||||
}
|
||||
}, [done, currentStreamedAiOutput]); // 依赖 done 和 currentStreamedAiOutput,确保在 done 时拿到最新的 currentStreamedAiOutput
|
||||
|
||||
// 监听 currentStreamedAiOutput 的变化,实时更新主 content 状态以实现流式显示
|
||||
useEffect(() => {
|
||||
if (isStreaming && aiInsertionStartPosRef.current !== null) {
|
||||
// 实时更新编辑器内容,保留 <think> 标签内容
|
||||
setContent(
|
||||
contentBeforeAiInsertionRef.current +
|
||||
currentStreamedAiOutput +
|
||||
contentAfterAiInsertionRef.current,
|
||||
);
|
||||
// 同时更新 cursorPosition,让光标跟随 AI 输出移动(基于包含 think 标签的原始长度)
|
||||
setCursorPosition(
|
||||
aiInsertionStartPosRef.current + currentStreamedAiOutput.length,
|
||||
);
|
||||
}
|
||||
}, [currentStreamedAiOutput, isStreaming, aiInsertionStartPosRef]);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchDialogs = async () => {
|
||||
try {
|
||||
const authorization = localStorage.getItem('Authorization');
|
||||
if (!authorization) return;
|
||||
const response = await axios.get('/v1/dialog', {
|
||||
headers: { authorization },
|
||||
});
|
||||
if (response.data?.data?.length > 0)
|
||||
setDialogId(response.data.data[0].id);
|
||||
} catch (error) {
|
||||
console.error('获取对话列表失败:', error);
|
||||
}
|
||||
};
|
||||
const loadDraftContent = () => {
|
||||
try {
|
||||
const draftContent = localStorage.getItem('writeDraftContent');
|
||||
|
@ -250,13 +309,13 @@ const Write = () => {
|
|||
console.error('加载暂存内容失败:', error);
|
||||
}
|
||||
};
|
||||
fetchDialogs();
|
||||
if (localStorage.getItem(LOCAL_STORAGE_INIT_FLAG_KEY) === 'true') {
|
||||
loadDraftContent();
|
||||
}
|
||||
}, [content, selectedTemplate, templates]);
|
||||
|
||||
useEffect(() => {
|
||||
// 防抖保存,防止频繁写入 localStorage
|
||||
const timer = setTimeout(
|
||||
() => localStorage.setItem('writeDraftContent', content),
|
||||
1000,
|
||||
|
@ -302,6 +361,7 @@ const Write = () => {
|
|||
}
|
||||
};
|
||||
|
||||
// 删除模板
|
||||
const handleDeleteTemplate = (templateId: string) => {
|
||||
try {
|
||||
const updatedTemplates = templates.filter((t) => t.id !== templateId);
|
||||
|
@ -326,6 +386,51 @@ const Write = () => {
|
|||
}
|
||||
};
|
||||
|
||||
// 获取上下文内容的辅助函数
|
||||
const getContextContent = (
|
||||
cursorPos: number,
|
||||
currentDocumentContent: string,
|
||||
maxLength: number = 4000,
|
||||
) => {
|
||||
// 注意: 这里的 currentDocumentContent 传入的是 AI 提问时编辑器里的总内容,
|
||||
// 而不是 contentBeforeAiInsertionRef + contentAfterAiInsertionRef,因为可能包含标记
|
||||
const beforeCursor = currentDocumentContent.substring(0, cursorPos);
|
||||
const afterCursor = currentDocumentContent.substring(cursorPos);
|
||||
|
||||
// 使用更明显的插入点标记,这个标记是给AI看的,不是给用户看的
|
||||
const insertMarker = '[AI 内容插入点]';
|
||||
const availableLength = maxLength - insertMarker.length;
|
||||
|
||||
if (currentDocumentContent.length <= availableLength) {
|
||||
return {
|
||||
beforeCursor,
|
||||
afterCursor,
|
||||
contextContent: beforeCursor + insertMarker + afterCursor,
|
||||
};
|
||||
}
|
||||
|
||||
const halfLength = Math.floor(availableLength / 2);
|
||||
let finalBefore = beforeCursor;
|
||||
let finalAfter = afterCursor;
|
||||
|
||||
// 如果前半部分太长,截断并在前面加省略号
|
||||
if (beforeCursor.length > halfLength) {
|
||||
finalBefore =
|
||||
'...' + beforeCursor.substring(beforeCursor.length - halfLength + 3);
|
||||
}
|
||||
|
||||
// 如果后半部分太长,截断并在后面加省略号
|
||||
if (afterCursor.length > halfLength) {
|
||||
finalAfter = afterCursor.substring(0, halfLength - 3) + '...';
|
||||
}
|
||||
|
||||
return {
|
||||
beforeCursor,
|
||||
afterCursor,
|
||||
contextContent: finalBefore + insertMarker + finalAfter,
|
||||
};
|
||||
};
|
||||
|
||||
const handleAiQuestionSubmit = async (
|
||||
e: React.KeyboardEvent<HTMLTextAreaElement>,
|
||||
) => {
|
||||
|
@ -335,128 +440,106 @@ const Write = () => {
|
|||
message.warning(t('enterYourQuestion'));
|
||||
return;
|
||||
}
|
||||
if (!dialogId) {
|
||||
message.error(t('noDialogFound'));
|
||||
|
||||
// 检查是否选择了知识库
|
||||
if (selectedKnowledgeBases.length === 0) {
|
||||
message.warning('请至少选择一个知识库');
|
||||
return;
|
||||
}
|
||||
setIsAiLoading(true);
|
||||
const initialCursorPos = cursorPosition;
|
||||
const originalContent = content;
|
||||
let beforeCursor = '',
|
||||
afterCursor = '';
|
||||
if (initialCursorPos !== null && showCursorIndicator) {
|
||||
beforeCursor = originalContent.substring(0, initialCursorPos);
|
||||
afterCursor = originalContent.substring(initialCursorPos);
|
||||
}
|
||||
const controller = new AbortController();
|
||||
const timeoutId = setTimeout(
|
||||
() => controller.abort(),
|
||||
aiAssistantConfig.api.timeout || 30000,
|
||||
|
||||
// 如果AI正在流式输出,停止它,并处理新问题
|
||||
if (isStreaming) {
|
||||
stopOutputMessage(); // 停止当前的流式输出
|
||||
setIsStreaming(false); // 立即设置为false,中断流
|
||||
setIsAiLoading(false); // 确保加载状态也停止
|
||||
|
||||
// 中断时立即清除流中的 <think> 标签,并更新主内容
|
||||
// 这里使用 currentStreamedAiOutput 作为基准来构建中断时的内容,
|
||||
// 因为它是屏幕上实际显示的,包含了 <think> 标签。
|
||||
const contentToCleanOnInterrupt =
|
||||
contentBeforeAiInsertionRef.current +
|
||||
currentStreamedAiOutput +
|
||||
contentAfterAiInsertionRef.current;
|
||||
const cleanedContent = contentToCleanOnInterrupt.replace(
|
||||
/<think>.*?<\/think>/gs,
|
||||
'',
|
||||
);
|
||||
setContent(cleanedContent);
|
||||
|
||||
setCurrentStreamedAiOutput(''); // 清除旧的流式内容
|
||||
contentBeforeAiInsertionRef.current = ''; // 清理 useRef
|
||||
contentAfterAiInsertionRef.current = '';
|
||||
aiInsertionStartPosRef.current = null;
|
||||
message.info('已中断上一次AI回答,正在处理新问题...');
|
||||
// 稍作延迟,确保状态更新后再处理新问题,防止竞态条件
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, 100);
|
||||
});
|
||||
}
|
||||
|
||||
// 如果当前光标位置无效,提醒用户设置
|
||||
if (cursorPosition === null) {
|
||||
message.warning('请先点击文本框以设置AI内容插入位置。');
|
||||
return;
|
||||
}
|
||||
|
||||
// 捕获 AI 插入点前后的静态内容,存储到 useRef
|
||||
const currentCursorPos = cursorPosition;
|
||||
// 此时的 content 应该是用户当前编辑器的实际内容,包括可能存在的INSERTION_MARKER
|
||||
// 但由于 INSERTION_MARKER 为空,所以就是当前的主 content
|
||||
contentBeforeAiInsertionRef.current = content.substring(
|
||||
0,
|
||||
currentCursorPos,
|
||||
);
|
||||
contentAfterAiInsertionRef.current = content.substring(currentCursorPos);
|
||||
aiInsertionStartPosRef.current = currentCursorPos; // 记录确切的开始插入位置
|
||||
|
||||
setIsAiLoading(true);
|
||||
setIsStreaming(true); // 标记AI开始流式输出
|
||||
setCurrentStreamedAiOutput(''); // 清空历史累积内容,为新的流做准备
|
||||
|
||||
try {
|
||||
const authorization = localStorage.getItem('Authorization');
|
||||
if (!authorization) {
|
||||
message.error(t('loginRequiredError'));
|
||||
setIsAiLoading(false);
|
||||
setIsStreaming(false); // 停止流式标记
|
||||
// 失败时也清理临时状态
|
||||
setCurrentStreamedAiOutput('');
|
||||
contentBeforeAiInsertionRef.current = '';
|
||||
contentAfterAiInsertionRef.current = '';
|
||||
aiInsertionStartPosRef.current = null;
|
||||
return;
|
||||
}
|
||||
const conversationId =
|
||||
Math.random().toString(36).substring(2) + Date.now().toString(36);
|
||||
await axios.post(
|
||||
'v1/conversation/set',
|
||||
{
|
||||
dialog_id: dialogId,
|
||||
name: '文档撰写对话',
|
||||
is_new: true,
|
||||
conversation_id: conversationId,
|
||||
message: [{ role: 'assistant', content: '新对话' }],
|
||||
},
|
||||
{ headers: { authorization }, signal: controller.signal },
|
||||
|
||||
// 构建请求内容,将上下文内容发送给AI
|
||||
let questionWithContext = aiQuestion;
|
||||
|
||||
// 只有当用户设置了插入位置时才包含上下文
|
||||
if (aiInsertionStartPosRef.current !== null) {
|
||||
// 传递给 getContextContent 的 content 应该是当前编辑器完整的,包含marker的
|
||||
const { contextContent } = getContextContent(
|
||||
aiInsertionStartPosRef.current,
|
||||
content,
|
||||
);
|
||||
const combinedQuestion = `${aiQuestion}\n\n${t('currentDocumentContextLabel')}:\n${originalContent}`;
|
||||
let lastReceivedContent = '';
|
||||
const response = await axios.post(
|
||||
'/v1/conversation/completion',
|
||||
{
|
||||
conversation_id: conversationId,
|
||||
messages: [{ role: 'user', content: combinedQuestion }],
|
||||
knowledge_base_ids:
|
||||
selectedKnowledgeBases.length > 0
|
||||
? selectedKnowledgeBases
|
||||
: undefined,
|
||||
questionWithContext = `${aiQuestion}\n\n上下文内容:\n${contextContent}`;
|
||||
}
|
||||
|
||||
// 发送流式请求
|
||||
await sendMessage({
|
||||
question: questionWithContext,
|
||||
kb_ids: selectedKnowledgeBases,
|
||||
dialog_id: dialogId,
|
||||
similarity_threshold: similarityThreshold,
|
||||
keyword_similarity_weight: keywordSimilarityWeight,
|
||||
temperature: modelTemperature,
|
||||
},
|
||||
{
|
||||
timeout: aiAssistantConfig.api.timeout,
|
||||
headers: { authorization },
|
||||
signal: controller.signal,
|
||||
},
|
||||
);
|
||||
if (response.data) {
|
||||
const lines = response.data
|
||||
.split('\n')
|
||||
.filter((line: string) => line.trim());
|
||||
for (let i = 0; i < lines.length; i++) {
|
||||
try {
|
||||
const jsonStr = lines[i].replace('data:', '').trim();
|
||||
const jsonData = JSON.parse(jsonStr);
|
||||
if (jsonData.code === 0 && jsonData.data?.answer) {
|
||||
const answerChunk = jsonData.data.answer;
|
||||
const cleanedAnswerChunk = answerChunk
|
||||
.replace(/<think>[\s\S]*?<\/think>/g, '')
|
||||
.trim();
|
||||
const hasUnclosedThink =
|
||||
cleanedAnswerChunk.includes('<think>') &&
|
||||
(!cleanedAnswerChunk.includes('</think>') ||
|
||||
cleanedAnswerChunk.indexOf('<think>') >
|
||||
cleanedAnswerChunk.lastIndexOf('</think>'));
|
||||
if (cleanedAnswerChunk && !hasUnclosedThink) {
|
||||
const incrementalContent = cleanedAnswerChunk.substring(
|
||||
lastReceivedContent.length,
|
||||
);
|
||||
if (incrementalContent) {
|
||||
lastReceivedContent = cleanedAnswerChunk;
|
||||
let newFullContent,
|
||||
newCursorPosAfterInsertion = cursorPosition;
|
||||
if (initialCursorPos !== null && showCursorIndicator) {
|
||||
newFullContent =
|
||||
beforeCursor + cleanedAnswerChunk + afterCursor;
|
||||
newCursorPosAfterInsertion =
|
||||
initialCursorPos + cleanedAnswerChunk.length;
|
||||
} else {
|
||||
newFullContent = originalContent + cleanedAnswerChunk;
|
||||
newCursorPosAfterInsertion = newFullContent.length;
|
||||
}
|
||||
setContent(newFullContent);
|
||||
setCursorPosition(newCursorPosAfterInsertion);
|
||||
setTimeout(() => {
|
||||
if (textAreaRef.current) {
|
||||
textAreaRef.current.focus();
|
||||
textAreaRef.current.setSelectionRange(
|
||||
newCursorPosAfterInsertion!,
|
||||
newCursorPosAfterInsertion!,
|
||||
);
|
||||
}
|
||||
}, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (parseErr) {
|
||||
console.error('解析单行数据失败:', parseErr);
|
||||
}
|
||||
if (i < lines.length - 1)
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, 10);
|
||||
});
|
||||
|
||||
setAiQuestion(''); // 清空输入框
|
||||
// 重新聚焦文本框,但不是AI问答框,而是主编辑区
|
||||
if (textAreaRef.current?.resizableTextArea?.textArea) {
|
||||
textAreaRef.current.resizableTextArea.textArea.focus();
|
||||
}
|
||||
}
|
||||
await axios.post(
|
||||
'/v1/conversation/rm',
|
||||
{ conversation_ids: [conversationId], dialog_id: dialogId },
|
||||
{ headers: { authorization } },
|
||||
);
|
||||
} catch (error: any) {
|
||||
console.error('AI助手处理失败:', error);
|
||||
if (error.code === 'ECONNABORTED' || error.name === 'AbortError') {
|
||||
|
@ -469,14 +552,13 @@ const Write = () => {
|
|||
message.error(t('aiRequestFailedError'));
|
||||
}
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
setIsAiLoading(false);
|
||||
setAiQuestion('');
|
||||
if (textAreaRef.current) textAreaRef.current.focus();
|
||||
// AI加载状态在 done 状态或错误处理中会更新,这里不主动设置为 false
|
||||
// 只有当 isStreaming 状态完全结束时,才彻底清除临时状态
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// 导出为Word
|
||||
const handleSave = () => {
|
||||
const selectedTemplateItem = templates.find(
|
||||
(item) => item.id === selectedTemplate,
|
||||
|
@ -674,7 +756,27 @@ const Write = () => {
|
|||
}
|
||||
};
|
||||
|
||||
const renderEditor = () => (
|
||||
// 修改编辑器渲染函数,添加光标标记
|
||||
const renderEditor = () => {
|
||||
let displayContent = content; // 默认显示主内容状态
|
||||
|
||||
// 如果 AI 正在流式输出,则动态拼接显示内容
|
||||
if (isStreaming && aiInsertionStartPosRef.current !== null) {
|
||||
// 实时显示时,保留 <think> 标签内容
|
||||
displayContent =
|
||||
contentBeforeAiInsertionRef.current +
|
||||
currentStreamedAiOutput +
|
||||
contentAfterAiInsertionRef.current;
|
||||
} else if (showCursorIndicator && cursorPosition !== null) {
|
||||
// 如果不处于流式输出中,但设置了光标,则显示插入标记
|
||||
// (由于 INSERTION_MARKER 为空字符串,这一步实际上不会添加可见标记)
|
||||
const beforeCursor = content.substring(0, cursorPosition);
|
||||
const afterCursor = content.substring(cursorPosition);
|
||||
displayContent = beforeCursor + INSERTION_MARKER + afterCursor;
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ position: 'relative', height: '100%', width: '100%' }}>
|
||||
<Input.TextArea
|
||||
ref={textAreaRef}
|
||||
style={{
|
||||
|
@ -685,22 +787,94 @@ const Write = () => {
|
|||
fontSize: 16,
|
||||
resize: 'none',
|
||||
}}
|
||||
value={content}
|
||||
onChange={(e) => setContent(e.target.value)}
|
||||
value={displayContent} // 使用动态的 displayContent
|
||||
onChange={(e) => {
|
||||
const currentInputValue = e.target.value; // 获取当前输入框中的完整内容
|
||||
const newCursorSelectionStart = e.target.selectionStart;
|
||||
let finalContent = currentInputValue;
|
||||
let finalCursorPosition = newCursorSelectionStart;
|
||||
|
||||
// 如果用户在 AI 流式输出时输入,则中断 AI 输出,并“固化”当前内容(清除 <think> 标签)
|
||||
if (isStreaming) {
|
||||
stopOutputMessage(); // 中断 SSE 连接
|
||||
setIsStreaming(false); // 停止流式输出
|
||||
setIsAiLoading(false); // 停止加载状态
|
||||
|
||||
// 此时 currentInputValue 已经包含了所有已流出的 AI 内容 (包括 <think> 标签)
|
||||
// 移除 <think> 标签
|
||||
const contentWithoutThinkTags = currentInputValue.replace(
|
||||
/<think>.*?<\/think>/gs,
|
||||
'',
|
||||
);
|
||||
finalContent = contentWithoutThinkTags;
|
||||
|
||||
// 重新计算光标位置,因为内容长度可能因移除 <think> 标签而改变
|
||||
const originalLength = currentInputValue.length;
|
||||
const cleanedLength = finalContent.length;
|
||||
|
||||
// 假设光标是在 AI 插入点之后,或者在用户输入后新位置,需要调整
|
||||
// 如果光标在被移除的 <think> 区域内部,或者在移除区域之后,需要回退相应长度
|
||||
if (
|
||||
newCursorSelectionStart > (aiInsertionStartPosRef.current || 0)
|
||||
) {
|
||||
// 假设 aiInsertionStartPosRef.current 是 AI 内容的起始点
|
||||
finalCursorPosition =
|
||||
newCursorSelectionStart - (originalLength - cleanedLength);
|
||||
// 确保光标不会超出新内容的末尾
|
||||
if (finalCursorPosition > cleanedLength) {
|
||||
finalCursorPosition = cleanedLength;
|
||||
}
|
||||
} else {
|
||||
finalCursorPosition = newCursorSelectionStart; // 光标在 AI 插入点之前,无需调整
|
||||
}
|
||||
|
||||
// 清理流式相关的临时状态和 useRef
|
||||
setCurrentStreamedAiOutput('');
|
||||
contentBeforeAiInsertionRef.current = '';
|
||||
contentAfterAiInsertionRef.current = '';
|
||||
aiInsertionStartPosRef.current = null;
|
||||
}
|
||||
|
||||
// 检查内容中是否包含 INSERTION_MARKER,如果包含则移除
|
||||
// 由于 INSERTION_MARKER 为空字符串,此逻辑块影响很小
|
||||
const markerIndex = finalContent.indexOf(INSERTION_MARKER); // 对已处理的 finalContent 进行检查
|
||||
if (markerIndex !== -1) {
|
||||
const contentWithoutMarker = finalContent.replace(
|
||||
INSERTION_MARKER,
|
||||
'',
|
||||
);
|
||||
finalContent = contentWithoutMarker;
|
||||
if (newCursorSelectionStart > markerIndex) {
|
||||
// 此处的 newCursorSelectionStart 仍然是原始的,需要与 markerIndex 比较
|
||||
finalCursorPosition =
|
||||
finalCursorPosition - INSERTION_MARKER.length;
|
||||
}
|
||||
}
|
||||
|
||||
setContent(finalContent); // 更新主内容状态
|
||||
setCursorPosition(finalCursorPosition); // 更新光标位置状态
|
||||
// 手动设置光标位置
|
||||
// 这里不能直接操作 DOM,因为是在 setState 之后,DOM 尚未更新
|
||||
// Ant Design Input.TextArea 会在 value 更新后自动处理光标位置
|
||||
setShowCursorIndicator(true); // 用户输入时,表明已设置光标位置,持续显示标记
|
||||
}}
|
||||
onClick={(e) => {
|
||||
const target = e.target as HTMLTextAreaElement;
|
||||
setCursorPosition(target.selectionStart);
|
||||
setShowCursorIndicator(true);
|
||||
setShowCursorIndicator(true); // 点击时设置光标位置并显示标记
|
||||
target.focus(); // 确保点击后立即聚焦
|
||||
}}
|
||||
onKeyUp={(e) => {
|
||||
const target = e.target as HTMLTextAreaElement;
|
||||
setCursorPosition(target.selectionStart);
|
||||
setShowCursorIndicator(true);
|
||||
setShowCursorIndicator(true); // 键盘抬起时设置光标位置并显示标记
|
||||
}}
|
||||
placeholder={t('writePlaceholder')}
|
||||
autoSize={false}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
const renderPreview = () => (
|
||||
<div
|
||||
style={{
|
||||
|
@ -712,6 +886,7 @@ const Write = () => {
|
|||
}}
|
||||
>
|
||||
<HightLightMarkdown>
|
||||
{/* 预览模式下,通常不显示 <think> 标签,所以这里不需要特殊处理 */}
|
||||
{content || t('previewPlaceholder')}
|
||||
</HightLightMarkdown>
|
||||
</div>
|
||||
|
@ -754,10 +929,10 @@ const Write = () => {
|
|||
return (
|
||||
<Layout
|
||||
style={{
|
||||
height: 'calc(100vh - 80px)',
|
||||
display: 'flex',
|
||||
flexDirection: 'row',
|
||||
overflow: 'hidden',
|
||||
flexGrow: 1,
|
||||
}}
|
||||
>
|
||||
<Sider
|
||||
|
@ -773,7 +948,7 @@ const Write = () => {
|
|||
<div
|
||||
style={{
|
||||
padding: '16px 16px 0 16px',
|
||||
height: '70%',
|
||||
height: '65%',
|
||||
minHeight: '250px',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
|
@ -1023,11 +1198,6 @@ const Write = () => {
|
|||
}}
|
||||
style={{ flexShrink: 0 }}
|
||||
>
|
||||
{isAiLoading && (
|
||||
<div style={{ textAlign: 'center', marginBottom: 8 }}>
|
||||
{t('aiLoadingMessage')}...
|
||||
</div>
|
||||
)}
|
||||
<Input.TextArea
|
||||
placeholder={t('askAI')}
|
||||
autoSize={{ minRows: 2, maxRows: 5 }}
|
||||
|
@ -1036,6 +1206,50 @@ const Write = () => {
|
|||
onKeyDown={handleAiQuestionSubmit}
|
||||
disabled={isAiLoading}
|
||||
/>
|
||||
|
||||
{/* 插入位置提示 或 AI正在回答时的提示 - 现已常驻显示 */}
|
||||
{isStreaming ? ( // AI正在回答时优先显示此提示
|
||||
<div
|
||||
style={{
|
||||
fontSize: '12px',
|
||||
color: '#faad14', // 警告色
|
||||
padding: '6px 10px',
|
||||
backgroundColor: '#fffbe6',
|
||||
borderRadius: '4px',
|
||||
border: '1px solid #ffe58f',
|
||||
}}
|
||||
>
|
||||
✨ AI正在生成回答,请稍候...
|
||||
</div>
|
||||
) : // AI未回答时
|
||||
cursorPosition !== null ? ( // 如果光标已设置
|
||||
<div
|
||||
style={{
|
||||
fontSize: '12px',
|
||||
color: '#666',
|
||||
padding: '6px 10px',
|
||||
backgroundColor: '#e6f7ff',
|
||||
borderRadius: '4px',
|
||||
border: '1px solid #91d5ff',
|
||||
}}
|
||||
>
|
||||
💡 AI回答将插入到文档光标位置 (第 {cursorPosition} 个字符)。
|
||||
</div>
|
||||
) : (
|
||||
// 如果光标未设置
|
||||
<div
|
||||
style={{
|
||||
fontSize: '12px',
|
||||
color: '#f5222d', // 错误色,提醒用户
|
||||
padding: '6px 10px',
|
||||
backgroundColor: '#fff1f0',
|
||||
borderRadius: '4px',
|
||||
border: '1px solid #ffccc7',
|
||||
}}
|
||||
>
|
||||
👆 请在上方文档中点击,设置AI内容插入位置。
|
||||
</div>
|
||||
)}
|
||||
</Card>
|
||||
</Flex>
|
||||
</Content>
|
||||
|
|
|
@ -100,6 +100,8 @@ export default {
|
|||
getExternalConversation: `${api_host}/api/conversation`,
|
||||
completeExternalConversation: `${api_host}/api/completion`,
|
||||
uploadAndParseExternal: `${api_host}/api/document/upload_and_parse`,
|
||||
// 文档撰写模式中的问答API
|
||||
writeChat: `${api_host}/conversation/writechat`,
|
||||
|
||||
// file manager
|
||||
listFile: `${api_host}/file/list`,
|
||||
|
|
Loading…
Reference in New Issue