feat(write): 新增文档撰写模式的问答单独调用函数路由
- 添加了 /writechat 路由,用于文档撰写模式的问答调用
This commit is contained in:
parent
d1ed2019e9
commit
cfab4bc7bf
|
@ -36,7 +36,7 @@ from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def set_conversation():
|
def set_conversation():
|
||||||
req = request.json
|
req = request.json
|
||||||
|
@ -50,8 +50,7 @@ def set_conversation():
|
||||||
return get_data_error_result(message="Conversation not found!")
|
return get_data_error_result(message="Conversation not found!")
|
||||||
e, conv = ConversationService.get_by_id(conv_id)
|
e, conv = ConversationService.get_by_id(conv_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(
|
return get_data_error_result(message="Fail to update a conversation!")
|
||||||
message="Fail to update a conversation!")
|
|
||||||
conv = conv.to_dict()
|
conv = conv.to_dict()
|
||||||
return get_json_result(data=conv)
|
return get_json_result(data=conv)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -61,38 +60,30 @@ def set_conversation():
|
||||||
e, dia = DialogService.get_by_id(req["dialog_id"])
|
e, dia = DialogService.get_by_id(req["dialog_id"])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Dialog not found")
|
return get_data_error_result(message="Dialog not found")
|
||||||
conv = {
|
conv = {"id": conv_id, "dialog_id": req["dialog_id"], "name": req.get("name", "New conversation"), "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]}
|
||||||
"id": conv_id,
|
|
||||||
"dialog_id": req["dialog_id"],
|
|
||||||
"name": req.get("name", "New conversation"),
|
|
||||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
|
|
||||||
}
|
|
||||||
ConversationService.save(**conv)
|
ConversationService.save(**conv)
|
||||||
return get_json_result(data=conv)
|
return get_json_result(data=conv)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/get', methods=['GET']) # noqa: F821
|
@manager.route("/get", methods=["GET"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def get():
|
def get():
|
||||||
conv_id = request.args["conversation_id"]
|
conv_id = request.args["conversation_id"]
|
||||||
try:
|
try:
|
||||||
|
|
||||||
e, conv = ConversationService.get_by_id(conv_id)
|
e, conv = ConversationService.get_by_id(conv_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Conversation not found!")
|
return get_data_error_result(message="Conversation not found!")
|
||||||
tenants = UserTenantService.query(user_id=current_user.id)
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
avatar =None
|
avatar = None
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
|
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
|
||||||
if dialog and len(dialog)>0:
|
if dialog and len(dialog) > 0:
|
||||||
avatar = dialog[0].icon
|
avatar = dialog[0].icon
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||||
data=False, message='Only owner of conversation authorized for this operation.',
|
|
||||||
code=settings.RetCode.OPERATING_ERROR)
|
|
||||||
|
|
||||||
def get_value(d, k1, k2):
|
def get_value(d, k1, k2):
|
||||||
return d.get(k1, d.get(k2))
|
return d.get(k1, d.get(k2))
|
||||||
|
@ -100,26 +91,29 @@ def get():
|
||||||
for ref in conv.reference:
|
for ref in conv.reference:
|
||||||
if isinstance(ref, list):
|
if isinstance(ref, list):
|
||||||
continue
|
continue
|
||||||
ref["chunks"] = [{
|
ref["chunks"] = [
|
||||||
"id": get_value(ck, "chunk_id", "id"),
|
{
|
||||||
"content": get_value(ck, "content", "content_with_weight"),
|
"id": get_value(ck, "chunk_id", "id"),
|
||||||
"document_id": get_value(ck, "doc_id", "document_id"),
|
"content": get_value(ck, "content", "content_with_weight"),
|
||||||
"document_name": get_value(ck, "docnm_kwd", "document_name"),
|
"document_id": get_value(ck, "doc_id", "document_id"),
|
||||||
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
|
"document_name": get_value(ck, "docnm_kwd", "document_name"),
|
||||||
"image_id": get_value(ck, "image_id", "img_id"),
|
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
|
||||||
"positions": get_value(ck, "positions", "position_int"),
|
"image_id": get_value(ck, "image_id", "img_id"),
|
||||||
} for ck in ref.get("chunks", [])]
|
"positions": get_value(ck, "positions", "position_int"),
|
||||||
|
}
|
||||||
|
for ck in ref.get("chunks", [])
|
||||||
|
]
|
||||||
|
|
||||||
conv = conv.to_dict()
|
conv = conv.to_dict()
|
||||||
conv["avatar"]=avatar
|
conv["avatar"] = avatar
|
||||||
return get_json_result(data=conv)
|
return get_json_result(data=conv)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@manager.route('/getsse/<dialog_id>', methods=['GET']) # type: ignore # noqa: F821
|
|
||||||
|
@manager.route("/getsse/<dialog_id>", methods=["GET"]) # type: ignore # noqa: F821
|
||||||
def getsse(dialog_id):
|
def getsse(dialog_id):
|
||||||
|
token = request.headers.get("Authorization").split()
|
||||||
token = request.headers.get('Authorization').split()
|
|
||||||
if len(token) != 2:
|
if len(token) != 2:
|
||||||
return get_data_error_result(message='Authorization is not valid!"')
|
return get_data_error_result(message='Authorization is not valid!"')
|
||||||
token = token[1]
|
token = token[1]
|
||||||
|
@ -131,13 +125,14 @@ def getsse(dialog_id):
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Dialog not found!")
|
return get_data_error_result(message="Dialog not found!")
|
||||||
conv = conv.to_dict()
|
conv = conv.to_dict()
|
||||||
conv["avatar"]= conv["icon"]
|
conv["avatar"] = conv["icon"]
|
||||||
del conv["icon"]
|
del conv["icon"]
|
||||||
return get_json_result(data=conv)
|
return get_json_result(data=conv)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
|
||||||
|
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def rm():
|
def rm():
|
||||||
conv_ids = request.json["conversation_ids"]
|
conv_ids = request.json["conversation_ids"]
|
||||||
|
@ -151,28 +146,21 @@ def rm():
|
||||||
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
|
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||||
data=False, message='Only owner of conversation authorized for this operation.',
|
|
||||||
code=settings.RetCode.OPERATING_ERROR)
|
|
||||||
ConversationService.delete_by_id(cid)
|
ConversationService.delete_by_id(cid)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def list_convsersation():
|
def list_convsersation():
|
||||||
dialog_id = request.args["dialog_id"]
|
dialog_id = request.args["dialog_id"]
|
||||||
try:
|
try:
|
||||||
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
||||||
return get_json_result(
|
return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||||
data=False, message='Only owner of dialog authorized for this operation.',
|
convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
|
||||||
code=settings.RetCode.OPERATING_ERROR)
|
|
||||||
convs = ConversationService.query(
|
|
||||||
dialog_id=dialog_id,
|
|
||||||
order_by=ConversationService.model.create_time,
|
|
||||||
reverse=True)
|
|
||||||
|
|
||||||
convs = [d.to_dict() for d in convs]
|
convs = [d.to_dict() for d in convs]
|
||||||
return get_json_result(data=convs)
|
return get_json_result(data=convs)
|
||||||
|
@ -180,7 +168,7 @@ def list_convsersation():
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
@manager.route("/completion", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("conversation_id", "messages")
|
@validate_request("conversation_id", "messages")
|
||||||
def completion():
|
def completion():
|
||||||
|
@ -207,25 +195,30 @@ def completion():
|
||||||
if not conv.reference:
|
if not conv.reference:
|
||||||
conv.reference = []
|
conv.reference = []
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def get_value(d, k1, k2):
|
def get_value(d, k1, k2):
|
||||||
return d.get(k1, d.get(k2))
|
return d.get(k1, d.get(k2))
|
||||||
|
|
||||||
for ref in conv.reference:
|
for ref in conv.reference:
|
||||||
if isinstance(ref, list):
|
if isinstance(ref, list):
|
||||||
continue
|
continue
|
||||||
ref["chunks"] = [{
|
ref["chunks"] = [
|
||||||
"id": get_value(ck, "chunk_id", "id"),
|
{
|
||||||
"content": get_value(ck, "content", "content_with_weight"),
|
"id": get_value(ck, "chunk_id", "id"),
|
||||||
"document_id": get_value(ck, "doc_id", "document_id"),
|
"content": get_value(ck, "content", "content_with_weight"),
|
||||||
"document_name": get_value(ck, "docnm_kwd", "document_name"),
|
"document_id": get_value(ck, "doc_id", "document_id"),
|
||||||
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
|
"document_name": get_value(ck, "docnm_kwd", "document_name"),
|
||||||
"image_id": get_value(ck, "image_id", "img_id"),
|
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
|
||||||
"positions": get_value(ck, "positions", "position_int"),
|
"image_id": get_value(ck, "image_id", "img_id"),
|
||||||
} for ck in ref.get("chunks", [])]
|
"positions": get_value(ck, "positions", "position_int"),
|
||||||
|
}
|
||||||
|
for ck in ref.get("chunks", [])
|
||||||
|
]
|
||||||
|
|
||||||
if not conv.reference:
|
if not conv.reference:
|
||||||
conv.reference = []
|
conv.reference = []
|
||||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||||
|
|
||||||
def stream():
|
def stream():
|
||||||
nonlocal dia, msg, req, conv
|
nonlocal dia, msg, req, conv
|
||||||
try:
|
try:
|
||||||
|
@ -235,9 +228,7 @@ def completion():
|
||||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
|
||||||
ensure_ascii=False) + "\n\n"
|
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
if req.get("stream", True):
|
if req.get("stream", True):
|
||||||
|
@ -259,7 +250,32 @@ def completion():
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/tts', methods=['POST']) # noqa: F821
|
# 用于文档撰写模式的问答调用
|
||||||
|
@manager.route("/writechat", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
@validate_request("question", "kb_ids")
|
||||||
|
def writechat():
|
||||||
|
req = request.json
|
||||||
|
uid = current_user.id
|
||||||
|
|
||||||
|
def stream():
|
||||||
|
nonlocal req, uid
|
||||||
|
try:
|
||||||
|
for ans in ask(req["question"], req["kb_ids"], uid):
|
||||||
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||||
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
|
resp = Response(stream(), mimetype="text/event-stream")
|
||||||
|
resp.headers.add_header("Cache-control", "no-cache")
|
||||||
|
resp.headers.add_header("Connection", "keep-alive")
|
||||||
|
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||||
|
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def tts():
|
def tts():
|
||||||
req = request.json
|
req = request.json
|
||||||
|
@ -281,9 +297,7 @@ def tts():
|
||||||
for chunk in tts_mdl.tts(txt):
|
for chunk in tts_mdl.tts(txt):
|
||||||
yield chunk
|
yield chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield ("data:" + json.dumps({"code": 500, "message": str(e),
|
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
|
||||||
"data": {"answer": "**ERROR**: " + str(e)}},
|
|
||||||
ensure_ascii=False)).encode('utf-8')
|
|
||||||
|
|
||||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||||
resp.headers.add_header("Cache-Control", "no-cache")
|
resp.headers.add_header("Cache-Control", "no-cache")
|
||||||
|
@ -293,7 +307,7 @@ def tts():
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/delete_msg', methods=['POST']) # noqa: F821
|
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("conversation_id", "message_id")
|
@validate_request("conversation_id", "message_id")
|
||||||
def delete_msg():
|
def delete_msg():
|
||||||
|
@ -316,7 +330,7 @@ def delete_msg():
|
||||||
return get_json_result(data=conv)
|
return get_json_result(data=conv)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/thumbup', methods=['POST']) # noqa: F821
|
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("conversation_id", "message_id")
|
@validate_request("conversation_id", "message_id")
|
||||||
def thumbup():
|
def thumbup():
|
||||||
|
@ -343,7 +357,7 @@ def thumbup():
|
||||||
return get_json_result(data=conv)
|
return get_json_result(data=conv)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/ask', methods=['POST']) # noqa: F821
|
@manager.route("/ask", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("question", "kb_ids")
|
@validate_request("question", "kb_ids")
|
||||||
def ask_about():
|
def ask_about():
|
||||||
|
@ -356,9 +370,7 @@ def ask_about():
|
||||||
for ans in ask(req["question"], req["kb_ids"], uid):
|
for ans in ask(req["question"], req["kb_ids"], uid):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
|
||||||
ensure_ascii=False) + "\n\n"
|
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
resp = Response(stream(), mimetype="text/event-stream")
|
resp = Response(stream(), mimetype="text/event-stream")
|
||||||
|
@ -369,7 +381,7 @@ def ask_about():
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/mindmap', methods=['POST']) # noqa: F821
|
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("question", "kb_ids")
|
@validate_request("question", "kb_ids")
|
||||||
def mindmap():
|
def mindmap():
|
||||||
|
@ -382,10 +394,7 @@ def mindmap():
|
||||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
|
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
|
||||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
|
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
|
||||||
question = req["question"]
|
question = req["question"]
|
||||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12,
|
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, 0.3, 0.3, aggs=False, rank_feature=label_question(question, [kb]))
|
||||||
0.3, 0.3, aggs=False,
|
|
||||||
rank_feature=label_question(question, [kb])
|
|
||||||
)
|
|
||||||
mindmap = MindMapExtractor(chat_mdl)
|
mindmap = MindMapExtractor(chat_mdl)
|
||||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
||||||
mind_map = mind_map.output
|
mind_map = mind_map.output
|
||||||
|
@ -394,7 +403,7 @@ def mindmap():
|
||||||
return get_json_result(data=mind_map)
|
return get_json_result(data=mind_map)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/related_questions', methods=['POST']) # noqa: F821
|
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("question")
|
@validate_request("question")
|
||||||
def related_questions():
|
def related_questions():
|
||||||
|
@ -425,8 +434,17 @@ Reason:
|
||||||
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
|
ans = chat_mdl.chat(
|
||||||
|
prompt,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"""
|
||||||
Keywords: {question}
|
Keywords: {question}
|
||||||
Related search terms:
|
Related search terms:
|
||||||
"""}], {"temperature": 0.9})
|
""",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
{"temperature": 0.9},
|
||||||
|
)
|
||||||
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
|
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
|
||||||
|
|
|
@ -30,8 +30,7 @@ from api import settings
|
||||||
from rag.app.resume import forbidden_select_fields4resume
|
from rag.app.resume import forbidden_select_fields4resume
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp.search import index_name
|
from rag.nlp.search import index_name
|
||||||
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, \
|
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, citation_prompt
|
||||||
citation_prompt
|
|
||||||
from rag.utils import rmSpace, num_tokens_from_string
|
from rag.utils import rmSpace, num_tokens_from_string
|
||||||
from rag.utils.tavily_conn import Tavily
|
from rag.utils.tavily_conn import Tavily
|
||||||
|
|
||||||
|
@ -41,17 +40,13 @@ class DialogService(CommonService):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_list(cls, tenant_id,
|
def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
|
||||||
page_number, items_per_page, orderby, desc, id, name):
|
|
||||||
chats = cls.model.select()
|
chats = cls.model.select()
|
||||||
if id:
|
if id:
|
||||||
chats = chats.where(cls.model.id == id)
|
chats = chats.where(cls.model.id == id)
|
||||||
if name:
|
if name:
|
||||||
chats = chats.where(cls.model.name == name)
|
chats = chats.where(cls.model.name == name)
|
||||||
chats = chats.where(
|
chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
|
||||||
(cls.model.tenant_id == tenant_id)
|
|
||||||
& (cls.model.status == StatusEnum.VALID.value)
|
|
||||||
)
|
|
||||||
if desc:
|
if desc:
|
||||||
chats = chats.order_by(cls.model.getter_by(orderby).desc())
|
chats = chats.order_by(cls.model.getter_by(orderby).desc())
|
||||||
else:
|
else:
|
||||||
|
@ -72,13 +67,12 @@ def chat_solo(dialog, messages, stream=True):
|
||||||
tts_mdl = None
|
tts_mdl = None
|
||||||
if prompt_config.get("tts"):
|
if prompt_config.get("tts"):
|
||||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
|
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
|
||||||
for m in messages if m["role"] != "system"]
|
|
||||||
if stream:
|
if stream:
|
||||||
last_ans = ""
|
last_ans = ""
|
||||||
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||||
answer = ans
|
answer = ans
|
||||||
delta_ans = ans[len(last_ans):]
|
delta_ans = ans[len(last_ans) :]
|
||||||
if num_tokens_from_string(delta_ans) < 16:
|
if num_tokens_from_string(delta_ans) < 16:
|
||||||
continue
|
continue
|
||||||
last_ans = answer
|
last_ans = answer
|
||||||
|
@ -159,9 +153,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
if p["key"] not in kwargs and not p["optional"]:
|
if p["key"] not in kwargs and not p["optional"]:
|
||||||
raise KeyError("Miss parameter: " + p["key"])
|
raise KeyError("Miss parameter: " + p["key"])
|
||||||
if p["key"] not in kwargs:
|
if p["key"] not in kwargs:
|
||||||
prompt_config["system"] = prompt_config["system"].replace(
|
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
||||||
"{%s}" % p["key"], " ")
|
|
||||||
|
|
||||||
# 不再使用多轮对话优化
|
# 不再使用多轮对话优化
|
||||||
# if len(questions) > 1 and prompt_config.get("refine_multiturn"):
|
# if len(questions) > 1 and prompt_config.get("refine_multiturn"):
|
||||||
# questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
|
# questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
|
||||||
|
@ -190,7 +183,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||||
|
|
||||||
knowledges = []
|
knowledges = []
|
||||||
|
|
||||||
# 不再使用推理
|
# 不再使用推理
|
||||||
# if prompt_config.get("reasoning", False):
|
# if prompt_config.get("reasoning", False):
|
||||||
# reasoner = DeepResearcher(chat_mdl,
|
# reasoner = DeepResearcher(chat_mdl,
|
||||||
|
@ -226,17 +219,24 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
# kbinfos["chunks"].insert(0, ck)
|
# kbinfos["chunks"].insert(0, ck)
|
||||||
|
|
||||||
# knowledges = kb_prompt(kbinfos, max_tokens)
|
# knowledges = kb_prompt(kbinfos, max_tokens)
|
||||||
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
kbinfos = retriever.retrieval(
|
||||||
dialog.similarity_threshold,
|
" ".join(questions),
|
||||||
dialog.vector_similarity_weight,
|
embd_mdl,
|
||||||
doc_ids=attachments,
|
tenant_ids,
|
||||||
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
dialog.kb_ids,
|
||||||
rank_feature=label_question(" ".join(questions), kbs)
|
1,
|
||||||
)
|
dialog.top_n,
|
||||||
|
dialog.similarity_threshold,
|
||||||
|
dialog.vector_similarity_weight,
|
||||||
|
doc_ids=attachments,
|
||||||
|
top=dialog.top_k,
|
||||||
|
aggs=False,
|
||||||
|
rerank_mdl=rerank_mdl,
|
||||||
|
rank_feature=label_question(" ".join(questions), kbs),
|
||||||
|
)
|
||||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||||
|
|
||||||
logging.debug(
|
logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
|
||||||
|
|
||||||
retrieval_ts = timer()
|
retrieval_ts = timer()
|
||||||
if not knowledges and prompt_config.get("empty_response"):
|
if not knowledges and prompt_config.get("empty_response"):
|
||||||
|
@ -252,22 +252,19 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||||
prompt4citation = citation_prompt()
|
prompt4citation = citation_prompt()
|
||||||
# 过滤掉 system 角色的消息(因为前面已经单独处理了系统消息)
|
# 过滤掉 system 角色的消息(因为前面已经单独处理了系统消息)
|
||||||
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
|
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
|
||||||
for m in messages if m["role"] != "system"])
|
|
||||||
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
|
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
|
||||||
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
||||||
prompt = msg[0]["content"]
|
prompt = msg[0]["content"]
|
||||||
|
|
||||||
if "max_tokens" in gen_conf:
|
if "max_tokens" in gen_conf:
|
||||||
gen_conf["max_tokens"] = min(
|
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
|
||||||
gen_conf["max_tokens"],
|
|
||||||
max_tokens - used_token_count)
|
|
||||||
|
|
||||||
def decorate_answer(answer):
|
def decorate_answer(answer):
|
||||||
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
|
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
|
||||||
|
|
||||||
refs = []
|
refs = []
|
||||||
image_markdowns = [] # 用于存储图片的 Markdown 字符串
|
image_markdowns = [] # 用于存储图片的 Markdown 字符串
|
||||||
ans = answer.split("</think>")
|
ans = answer.split("</think>")
|
||||||
think = ""
|
think = ""
|
||||||
if len(ans) == 2:
|
if len(ans) == 2:
|
||||||
|
@ -275,29 +272,29 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
answer = ans[1]
|
answer = ans[1]
|
||||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||||
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
|
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
|
||||||
cited_chunk_indices = set() # 用于存储被引用的 chunk 索引
|
cited_chunk_indices = set() # 用于存储被引用的 chunk 索引
|
||||||
if not re.search(r"##[0-9]+\$\$", answer):
|
if not re.search(r"##[0-9]+\$\$", answer):
|
||||||
answer, idx = retriever.insert_citations(answer,
|
answer, idx = retriever.insert_citations(
|
||||||
[ck["content_ltks"]
|
answer,
|
||||||
for ck in kbinfos["chunks"]],
|
[ck["content_ltks"] for ck in kbinfos["chunks"]],
|
||||||
[ck["vector"]
|
[ck["vector"] for ck in kbinfos["chunks"]],
|
||||||
for ck in kbinfos["chunks"]],
|
embd_mdl,
|
||||||
embd_mdl,
|
tkweight=1 - dialog.vector_similarity_weight,
|
||||||
tkweight=1 - dialog.vector_similarity_weight,
|
vtweight=dialog.vector_similarity_weight,
|
||||||
vtweight=dialog.vector_similarity_weight)
|
)
|
||||||
cited_chunk_indices = idx # 获取 insert_citations 返回的索引
|
cited_chunk_indices = idx # 获取 insert_citations 返回的索引
|
||||||
|
|
||||||
else:
|
else:
|
||||||
idx = set([])
|
idx = set([])
|
||||||
for r in re.finditer(r"##([0-9]+)\$\$", answer):
|
for r in re.finditer(r"##([0-9]+)\$\$", answer):
|
||||||
i = int(r.group(1))
|
i = int(r.group(1))
|
||||||
if i < len(kbinfos["chunks"]):
|
if i < len(kbinfos["chunks"]):
|
||||||
idx.add(i)
|
idx.add(i)
|
||||||
cited_chunk_indices = idx # 获取从 ##...$$ 标记中提取的索引
|
cited_chunk_indices = idx # 获取从 ##...$$ 标记中提取的索引
|
||||||
|
|
||||||
# 根据引用的 chunk 索引提取图像信息并生成 Markdown
|
# 根据引用的 chunk 索引提取图像信息并生成 Markdown
|
||||||
cited_doc_ids = set()
|
cited_doc_ids = set()
|
||||||
processed_image_urls = set() # 避免重复添加同一张图片
|
processed_image_urls = set() # 避免重复添加同一张图片
|
||||||
print(f"DEBUG: cited_chunk_indices = {cited_chunk_indices}")
|
print(f"DEBUG: cited_chunk_indices = {cited_chunk_indices}")
|
||||||
for i in cited_chunk_indices:
|
for i in cited_chunk_indices:
|
||||||
i_int = int(i)
|
i_int = int(i)
|
||||||
|
@ -312,11 +309,10 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
# 生成 Markdown 字符串,alt text 可以简单设为 "image" 或 chunk ID
|
# 生成 Markdown 字符串,alt text 可以简单设为 "image" 或 chunk ID
|
||||||
alt_text = f"image_chunk_{chunk.get('chunk_id', i_int)}"
|
alt_text = f"image_chunk_{chunk.get('chunk_id', i_int)}"
|
||||||
image_markdowns.append(f"\n")
|
image_markdowns.append(f"\n")
|
||||||
processed_image_urls.add(img_url) # 标记为已处理
|
processed_image_urls.add(img_url) # 标记为已处理
|
||||||
|
|
||||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||||
recall_docs = [
|
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
|
||||||
if not recall_docs:
|
if not recall_docs:
|
||||||
recall_docs = kbinfos["doc_aggs"]
|
recall_docs = kbinfos["doc_aggs"]
|
||||||
kbinfos["doc_aggs"] = recall_docs
|
kbinfos["doc_aggs"] = recall_docs
|
||||||
|
@ -325,7 +321,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
for c in refs["chunks"]:
|
for c in refs["chunks"]:
|
||||||
if c.get("vector"):
|
if c.get("vector"):
|
||||||
del c["vector"]
|
del c["vector"]
|
||||||
|
|
||||||
# 将图片的 Markdown 字符串追加到回答末尾
|
# 将图片的 Markdown 字符串追加到回答末尾
|
||||||
if image_markdowns:
|
if image_markdowns:
|
||||||
answer += "".join(image_markdowns)
|
answer += "".join(image_markdowns)
|
||||||
|
@ -347,30 +343,30 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
|
|
||||||
prompt += "\n\n### Query:\n%s" % " ".join(questions)
|
prompt += "\n\n### Query:\n%s" % " ".join(questions)
|
||||||
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
|
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
|
||||||
return {"answer": think+answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
|
return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
last_ans = "" # 记录上一次返回的完整回答
|
last_ans = "" # 记录上一次返回的完整回答
|
||||||
answer = "" # 当前累计的完整回答
|
answer = "" # 当前累计的完整回答
|
||||||
for ans in chat_mdl.chat_streamly(prompt+prompt4citation, msg[1:], gen_conf):
|
for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
||||||
# 如果存在思考过程(thought),移除相关标记
|
# 如果存在思考过程(thought),移除相关标记
|
||||||
if thought:
|
if thought:
|
||||||
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||||
answer = ans
|
answer = ans
|
||||||
# 计算新增的文本片段(delta)
|
# 计算新增的文本片段(delta)
|
||||||
delta_ans = ans[len(last_ans):]
|
delta_ans = ans[len(last_ans) :]
|
||||||
# 如果新增token太少(小于16),跳过本次返回(避免频繁发送小片段)
|
# 如果新增token太少(小于16),跳过本次返回(避免频繁发送小片段)
|
||||||
if num_tokens_from_string(delta_ans) < 16:
|
if num_tokens_from_string(delta_ans) < 16:
|
||||||
continue
|
continue
|
||||||
last_ans = answer
|
last_ans = answer
|
||||||
# 返回当前累计回答(包含思考过程)+新增片段)
|
# 返回当前累计回答(包含思考过程)+新增片段)
|
||||||
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||||
delta_ans = answer[len(last_ans):]
|
delta_ans = answer[len(last_ans) :]
|
||||||
if delta_ans:
|
if delta_ans:
|
||||||
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||||
yield decorate_answer(thought+answer)
|
yield decorate_answer(thought + answer)
|
||||||
else:
|
else:
|
||||||
answer = chat_mdl.chat(prompt+prompt4citation, msg[1:], gen_conf)
|
answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||||
user_content = msg[-1].get("content", "[content not available]")
|
user_content = msg[-1].get("content", "[content not available]")
|
||||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||||
res = decorate_answer(answer)
|
res = decorate_answer(answer)
|
||||||
|
@ -388,27 +384,22 @@ Table of database fields are as follows:
|
||||||
Question are as follows:
|
Question are as follows:
|
||||||
{}
|
{}
|
||||||
Please write the SQL, only SQL, without any other explanations or text.
|
Please write the SQL, only SQL, without any other explanations or text.
|
||||||
""".format(
|
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
|
||||||
index_name(tenant_id),
|
|
||||||
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
|
|
||||||
question
|
|
||||||
)
|
|
||||||
tried_times = 0
|
tried_times = 0
|
||||||
|
|
||||||
def get_table():
|
def get_table():
|
||||||
nonlocal sys_prompt, user_prompt, question, tried_times
|
nonlocal sys_prompt, user_prompt, question, tried_times
|
||||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
|
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
|
||||||
"temperature": 0.06})
|
|
||||||
sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
|
sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
|
||||||
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
|
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
|
||||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||||
sql = re.sub(r".*select ", "select ", sql.lower())
|
sql = re.sub(r".*select ", "select ", sql.lower())
|
||||||
sql = re.sub(r" +", " ", sql)
|
sql = re.sub(r" +", " ", sql)
|
||||||
sql = re.sub(r"([;;]|```).*", "", sql)
|
sql = re.sub(r"([;;]|```).*", "", sql)
|
||||||
if sql[:len("select ")] != "select ":
|
if sql[: len("select ")] != "select ":
|
||||||
return None, None
|
return None, None
|
||||||
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
|
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
|
||||||
if sql[:len("select *")] != "select *":
|
if sql[: len("select *")] != "select *":
|
||||||
sql = "select doc_id,docnm_kwd," + sql[6:]
|
sql = "select doc_id,docnm_kwd," + sql[6:]
|
||||||
else:
|
else:
|
||||||
flds = []
|
flds = []
|
||||||
|
@ -445,11 +436,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||||
{}
|
{}
|
||||||
|
|
||||||
Please correct the error and write SQL again, only SQL, without any other explanations or text.
|
Please correct the error and write SQL again, only SQL, without any other explanations or text.
|
||||||
""".format(
|
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
|
||||||
index_name(tenant_id),
|
|
||||||
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
|
|
||||||
question, sql, tbl["error"]
|
|
||||||
)
|
|
||||||
tbl, sql = get_table()
|
tbl, sql = get_table()
|
||||||
logging.debug("TRY it again: {}".format(sql))
|
logging.debug("TRY it again: {}".format(sql))
|
||||||
|
|
||||||
|
@ -457,24 +444,18 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||||
if tbl.get("error") or len(tbl["rows"]) == 0:
|
if tbl.get("error") or len(tbl["rows"]) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
docid_idx = set([ii for ii, c in enumerate(
|
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||||
tbl["columns"]) if c["name"] == "doc_id"])
|
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||||
doc_name_idx = set([ii for ii, c in enumerate(
|
column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
|
||||||
tbl["columns"]) if c["name"] == "docnm_kwd"])
|
|
||||||
column_idx = [ii for ii in range(
|
|
||||||
len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
|
|
||||||
|
|
||||||
# compose Markdown table
|
# compose Markdown table
|
||||||
columns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
columns = (
|
||||||
tbl["columns"][i]["name"])) for i in
|
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||||
column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
)
|
||||||
|
|
||||||
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
|
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||||
("|------|" if docid_idx and docid_idx else "")
|
|
||||||
|
|
||||||
rows = ["|" +
|
rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
|
||||||
"|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
|
|
||||||
"|" for r in tbl["rows"]]
|
|
||||||
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
|
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
|
||||||
if quota:
|
if quota:
|
||||||
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||||
|
@ -484,11 +465,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||||
|
|
||||||
if not docid_idx or not doc_name_idx:
|
if not docid_idx or not doc_name_idx:
|
||||||
logging.warning("SQL missing field: " + sql)
|
logging.warning("SQL missing field: " + sql)
|
||||||
return {
|
return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
|
||||||
"answer": "\n".join([columns, line, rows]),
|
|
||||||
"reference": {"chunks": [], "doc_aggs": []},
|
|
||||||
"prompt": sys_prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
docid_idx = list(docid_idx)[0]
|
docid_idx = list(docid_idx)[0]
|
||||||
doc_name_idx = list(doc_name_idx)[0]
|
doc_name_idx = list(doc_name_idx)[0]
|
||||||
|
@ -499,10 +476,11 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||||
doc_aggs[r[docid_idx]]["count"] += 1
|
doc_aggs[r[docid_idx]]["count"] += 1
|
||||||
return {
|
return {
|
||||||
"answer": "\n".join([columns, line, rows]),
|
"answer": "\n".join([columns, line, rows]),
|
||||||
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
|
"reference": {
|
||||||
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
|
"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
|
||||||
doc_aggs.items()]},
|
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
|
||||||
"prompt": sys_prompt
|
},
|
||||||
|
"prompt": sys_prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -518,12 +496,12 @@ def tts(tts_mdl, text):
|
||||||
def ask(question, kb_ids, tenant_id):
|
def ask(question, kb_ids, tenant_id):
|
||||||
"""
|
"""
|
||||||
处理用户搜索请求,从知识库中检索相关信息并生成回答
|
处理用户搜索请求,从知识库中检索相关信息并生成回答
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
question (str): 用户的问题或查询
|
question (str): 用户的问题或查询
|
||||||
kb_ids (list): 知识库ID列表,指定要搜索的知识库
|
kb_ids (list): 知识库ID列表,指定要搜索的知识库
|
||||||
tenant_id (str): 租户ID,用于权限控制和资源隔离
|
tenant_id (str): 租户ID,用于权限控制和资源隔离
|
||||||
|
|
||||||
流程:
|
流程:
|
||||||
1. 获取指定知识库的信息
|
1. 获取指定知识库的信息
|
||||||
2. 确定使用的嵌入模型
|
2. 确定使用的嵌入模型
|
||||||
|
@ -534,11 +512,11 @@ def ask(question, kb_ids, tenant_id):
|
||||||
7. 构建系统提示词
|
7. 构建系统提示词
|
||||||
8. 生成回答并添加引用标记
|
8. 生成回答并添加引用标记
|
||||||
9. 流式返回生成的回答
|
9. 流式返回生成的回答
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
generator: 生成器对象,产生包含回答和引用信息的字典
|
generator: 生成器对象,产生包含回答和引用信息的字典
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||||
|
|
||||||
|
@ -552,27 +530,24 @@ def ask(question, kb_ids, tenant_id):
|
||||||
max_tokens = chat_mdl.max_length
|
max_tokens = chat_mdl.max_length
|
||||||
# 获取所有知识库的租户ID并去重
|
# 获取所有知识库的租户ID并去重
|
||||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||||
# 调用检索器检索相关文档片段
|
# 调用检索器检索相关文档片段
|
||||||
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
|
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
|
||||||
1, 12, 0.1, 0.3, aggs=False,
|
# 将检索结果格式化为提示词,并确保不超过模型最大token限制
|
||||||
rank_feature=label_question(question, kbs)
|
|
||||||
)
|
|
||||||
# 将检索结果格式化为提示词,并确保不超过模型最大token限制
|
|
||||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||||
prompt = """
|
prompt = """
|
||||||
Role: You're a smart assistant. Your name is Miss R.
|
角色:你是一个聪明的助手。
|
||||||
Task: Summarize the information from knowledge bases and answer user's question.
|
任务:总结知识库中的信息并回答用户的问题。
|
||||||
Requirements and restriction:
|
要求与限制:
|
||||||
- DO NOT make things up, especially for numbers.
|
- 绝不要捏造内容,尤其是数字。
|
||||||
- If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
|
- 如果知识库中的信息与用户问题无关,**只需回答:对不起,未提供相关信息。
|
||||||
- Answer with markdown format text.
|
- 使用Markdown格式进行回答。
|
||||||
- Answer in language of user's question.
|
- 使用用户提问所用的语言作答。
|
||||||
- DO NOT make things up, especially for numbers.
|
- 绝不要捏造内容,尤其是数字。
|
||||||
|
|
||||||
### Information from knowledge bases
|
### 来自知识库的信息
|
||||||
%s
|
%s
|
||||||
|
|
||||||
The above is information from knowledge bases.
|
以上是来自知识库的信息。
|
||||||
|
|
||||||
""" % "\n".join(knowledges)
|
""" % "\n".join(knowledges)
|
||||||
msg = [{"role": "user", "content": question}]
|
msg = [{"role": "user", "content": question}]
|
||||||
|
@ -580,17 +555,9 @@ def ask(question, kb_ids, tenant_id):
|
||||||
# 生成完成后添加回答中的引用标记
|
# 生成完成后添加回答中的引用标记
|
||||||
def decorate_answer(answer):
|
def decorate_answer(answer):
|
||||||
nonlocal knowledges, kbinfos, prompt
|
nonlocal knowledges, kbinfos, prompt
|
||||||
answer, idx = retriever.insert_citations(answer,
|
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
|
||||||
[ck["content_ltks"]
|
|
||||||
for ck in kbinfos["chunks"]],
|
|
||||||
[ck["vector"]
|
|
||||||
for ck in kbinfos["chunks"]],
|
|
||||||
embd_mdl,
|
|
||||||
tkweight=0.7,
|
|
||||||
vtweight=0.3)
|
|
||||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||||
recall_docs = [
|
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
|
||||||
if not recall_docs:
|
if not recall_docs:
|
||||||
recall_docs = kbinfos["doc_aggs"]
|
recall_docs = kbinfos["doc_aggs"]
|
||||||
kbinfos["doc_aggs"] = recall_docs
|
kbinfos["doc_aggs"] = recall_docs
|
||||||
|
@ -608,4 +575,4 @@ def ask(question, kb_ids, tenant_id):
|
||||||
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
|
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
|
||||||
answer = ans
|
answer = ans
|
||||||
yield {"answer": answer, "reference": {}}
|
yield {"answer": answer, "reference": {}}
|
||||||
yield decorate_answer(answer)
|
yield decorate_answer(answer)
|
||||||
|
|
|
@ -0,0 +1,123 @@
|
||||||
|
import { Authorization } from '@/constants/authorization';
|
||||||
|
import { IAnswer } from '@/interfaces/database/chat';
|
||||||
|
import { IKnowledge } from '@/interfaces/database/knowledge';
|
||||||
|
import kbService from '@/services/knowledge-service';
|
||||||
|
import api from '@/utils/api';
|
||||||
|
import { getAuthorization } from '@/utils/authorization-util';
|
||||||
|
import { useQuery } from '@tanstack/react-query';
|
||||||
|
import { EventSourceParserStream } from 'eventsource-parser/stream';
|
||||||
|
import { useCallback, useRef, useState } from 'react';
|
||||||
|
|
||||||
|
// 查询知识库数据
|
||||||
|
export const useFetchKnowledgeList = (
|
||||||
|
shouldFilterListWithoutDocument: boolean = false,
|
||||||
|
): {
|
||||||
|
list: IKnowledge[];
|
||||||
|
loading: boolean;
|
||||||
|
} => {
|
||||||
|
const { data, isFetching: loading } = useQuery({
|
||||||
|
queryKey: ['fetchKnowledgeList'],
|
||||||
|
initialData: [],
|
||||||
|
gcTime: 0,
|
||||||
|
queryFn: async () => {
|
||||||
|
const { data } = await kbService.getList();
|
||||||
|
const list = data?.data?.kbs ?? [];
|
||||||
|
return shouldFilterListWithoutDocument
|
||||||
|
? list.filter((x: IKnowledge) => x.chunk_num > 0)
|
||||||
|
: list;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return { list: data, loading };
|
||||||
|
};
|
||||||
|
|
||||||
|
// 发送问答信息
|
||||||
|
export const useSendMessageWithSse = (url: string = api.writeChat) => {
|
||||||
|
const [answer, setAnswer] = useState<IAnswer>({} as IAnswer);
|
||||||
|
const [done, setDone] = useState(true);
|
||||||
|
const timer = useRef<any>();
|
||||||
|
const sseRef = useRef<AbortController>();
|
||||||
|
|
||||||
|
const initializeSseRef = useCallback(() => {
|
||||||
|
sseRef.current = new AbortController();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const resetAnswer = useCallback(() => {
|
||||||
|
if (timer.current) {
|
||||||
|
clearTimeout(timer.current);
|
||||||
|
}
|
||||||
|
timer.current = setTimeout(() => {
|
||||||
|
setAnswer({} as IAnswer);
|
||||||
|
clearTimeout(timer.current);
|
||||||
|
}, 1000);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const send = useCallback(
|
||||||
|
async (
|
||||||
|
body: any,
|
||||||
|
controller?: AbortController,
|
||||||
|
): Promise<{ response: Response; data: ResponseType } | undefined> => {
|
||||||
|
initializeSseRef();
|
||||||
|
try {
|
||||||
|
setDone(false);
|
||||||
|
const response = await fetch(url, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
[Authorization]: getAuthorization(),
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify(body),
|
||||||
|
signal: controller?.signal || sseRef.current?.signal,
|
||||||
|
});
|
||||||
|
|
||||||
|
const res = response.clone().json();
|
||||||
|
|
||||||
|
const reader = response?.body
|
||||||
|
?.pipeThrough(new TextDecoderStream())
|
||||||
|
.pipeThrough(new EventSourceParserStream())
|
||||||
|
.getReader();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
const x = await reader?.read();
|
||||||
|
if (x) {
|
||||||
|
const { done, value } = x;
|
||||||
|
if (done) {
|
||||||
|
console.info('done');
|
||||||
|
resetAnswer();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const val = JSON.parse(value?.data || '');
|
||||||
|
const d = val?.data;
|
||||||
|
if (typeof d !== 'boolean') {
|
||||||
|
console.info('data:', d);
|
||||||
|
setAnswer({
|
||||||
|
...d,
|
||||||
|
conversationId: body?.conversation_id,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.warn(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
console.info('done?');
|
||||||
|
setDone(true);
|
||||||
|
resetAnswer();
|
||||||
|
return { data: await res, response };
|
||||||
|
} catch (e) {
|
||||||
|
setDone(true);
|
||||||
|
resetAnswer();
|
||||||
|
|
||||||
|
console.warn(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[initializeSseRef, url, resetAnswer],
|
||||||
|
);
|
||||||
|
|
||||||
|
const stopOutputMessage = useCallback(() => {
|
||||||
|
sseRef.current?.abort();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return { send, answer, done, setDone, resetAnswer, stopOutputMessage };
|
||||||
|
};
|
|
@ -1,5 +1,6 @@
|
||||||
import HightLightMarkdown from '@/components/highlight-markdown';
|
import HightLightMarkdown from '@/components/highlight-markdown';
|
||||||
import { useTranslate } from '@/hooks/common-hooks';
|
import { useTranslate } from '@/hooks/common-hooks';
|
||||||
|
import { useFetchKnowledgeList } from '@/hooks/write-hooks';
|
||||||
import { DeleteOutlined } from '@ant-design/icons';
|
import { DeleteOutlined } from '@ant-design/icons';
|
||||||
import {
|
import {
|
||||||
Button,
|
Button,
|
||||||
|
@ -18,7 +19,6 @@ import {
|
||||||
Space,
|
Space,
|
||||||
Typography,
|
Typography,
|
||||||
} from 'antd';
|
} from 'antd';
|
||||||
import axios from 'axios';
|
|
||||||
import {
|
import {
|
||||||
AlignmentType,
|
AlignmentType,
|
||||||
Document,
|
Document,
|
||||||
|
@ -84,6 +84,10 @@ const Write = () => {
|
||||||
const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBaseItem[]>([]);
|
const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBaseItem[]>([]);
|
||||||
const [isLoadingKbs, setIsLoadingKbs] = useState(false);
|
const [isLoadingKbs, setIsLoadingKbs] = useState(false);
|
||||||
|
|
||||||
|
// 使用 useFetchKnowledgeList hook 获取真实数据
|
||||||
|
const { list: knowledgeList, loading: isLoadingKnowledgeList } =
|
||||||
|
useFetchKnowledgeList(true);
|
||||||
|
|
||||||
const getInitialDefaultTemplateDefinitions = useCallback(
|
const getInitialDefaultTemplateDefinitions = useCallback(
|
||||||
(): TemplateItem[] => [
|
(): TemplateItem[] => [
|
||||||
{
|
{
|
||||||
|
@ -167,55 +171,18 @@ const Write = () => {
|
||||||
loadOrInitializeTemplates();
|
loadOrInitializeTemplates();
|
||||||
}, [loadOrInitializeTemplates]);
|
}, [loadOrInitializeTemplates]);
|
||||||
|
|
||||||
|
// 将 knowledgeList 数据同步到 knowledgeBases 状态
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const fetchKbs = async () => {
|
if (knowledgeList && knowledgeList.length > 0) {
|
||||||
const authorization = localStorage.getItem('Authorization');
|
setKnowledgeBases(
|
||||||
if (!authorization) {
|
knowledgeList.map((kb) => ({
|
||||||
setKnowledgeBases([]);
|
id: kb.id,
|
||||||
return;
|
name: kb.name,
|
||||||
}
|
})),
|
||||||
setIsLoadingKbs(true);
|
);
|
||||||
try {
|
setIsLoadingKbs(isLoadingKnowledgeList);
|
||||||
await new Promise((resolve) => {
|
}
|
||||||
setTimeout(resolve, 500);
|
}, [knowledgeList, isLoadingKnowledgeList]);
|
||||||
});
|
|
||||||
const mockKbs: KnowledgeBaseItem[] = [
|
|
||||||
{
|
|
||||||
id: 'kb_default',
|
|
||||||
name: t('defaultKnowledgeBase', { defaultValue: '默认知识库' }),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'kb_tech',
|
|
||||||
name: t('technicalDocsKnowledgeBase', {
|
|
||||||
defaultValue: '技术文档知识库',
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'kb_product',
|
|
||||||
name: t('productInfoKnowledgeBase', {
|
|
||||||
defaultValue: '产品信息知识库',
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'kb_marketing',
|
|
||||||
name: t('marketingMaterialsKB', { defaultValue: '市场营销材料库' }),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'kb_legal',
|
|
||||||
name: t('legalDocumentsKB', { defaultValue: '法律文件库' }),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
setKnowledgeBases(mockKbs);
|
|
||||||
} catch (error) {
|
|
||||||
console.error('获取知识库失败:', error);
|
|
||||||
message.error(t('fetchKnowledgeBaseFailed'));
|
|
||||||
setKnowledgeBases([]);
|
|
||||||
} finally {
|
|
||||||
setIsLoadingKbs(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
fetchKbs();
|
|
||||||
}, [t]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const loadDraftContent = () => {
|
const loadDraftContent = () => {
|
||||||
|
@ -284,6 +251,7 @@ const Write = () => {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// 删除模板
|
||||||
const handleDeleteTemplate = (templateId: string) => {
|
const handleDeleteTemplate = (templateId: string) => {
|
||||||
try {
|
try {
|
||||||
const updatedTemplates = templates.filter((t) => t.id !== templateId);
|
const updatedTemplates = templates.filter((t) => t.id !== templateId);
|
||||||
|
@ -339,102 +307,6 @@ const Write = () => {
|
||||||
setIsAiLoading(false);
|
setIsAiLoading(false);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const conversationId =
|
|
||||||
Math.random().toString(36).substring(2) + Date.now().toString(36);
|
|
||||||
await axios.post(
|
|
||||||
'v1/conversation/set',
|
|
||||||
{
|
|
||||||
name: '文档撰写对话',
|
|
||||||
is_new: true,
|
|
||||||
conversation_id: conversationId,
|
|
||||||
message: [{ role: 'assistant', content: '新对话' }],
|
|
||||||
},
|
|
||||||
{ headers: { authorization }, signal: controller.signal },
|
|
||||||
);
|
|
||||||
const combinedQuestion = `${aiQuestion}\n\n${t('currentDocumentContextLabel')}:\n${originalContent}`;
|
|
||||||
let lastReceivedContent = '';
|
|
||||||
const response = await axios.post(
|
|
||||||
'/v1/conversation/completion',
|
|
||||||
{
|
|
||||||
conversation_id: conversationId,
|
|
||||||
messages: [{ role: 'user', content: combinedQuestion }],
|
|
||||||
knowledge_base_ids:
|
|
||||||
selectedKnowledgeBases.length > 0
|
|
||||||
? selectedKnowledgeBases
|
|
||||||
: undefined,
|
|
||||||
similarity_threshold: similarityThreshold,
|
|
||||||
keyword_similarity_weight: keywordSimilarityWeight,
|
|
||||||
temperature: modelTemperature,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
timeout: aiAssistantConfig.api.timeout,
|
|
||||||
headers: { authorization },
|
|
||||||
signal: controller.signal,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
if (response.data) {
|
|
||||||
const lines = response.data
|
|
||||||
.split('\n')
|
|
||||||
.filter((line: string) => line.trim());
|
|
||||||
for (let i = 0; i < lines.length; i++) {
|
|
||||||
try {
|
|
||||||
const jsonStr = lines[i].replace('data:', '').trim();
|
|
||||||
const jsonData = JSON.parse(jsonStr);
|
|
||||||
if (jsonData.code === 0 && jsonData.data?.answer) {
|
|
||||||
const answerChunk = jsonData.data.answer;
|
|
||||||
const cleanedAnswerChunk = answerChunk
|
|
||||||
.replace(/<think>[\s\S]*?<\/think>/g, '')
|
|
||||||
.trim();
|
|
||||||
const hasUnclosedThink =
|
|
||||||
cleanedAnswerChunk.includes('<think>') &&
|
|
||||||
(!cleanedAnswerChunk.includes('</think>') ||
|
|
||||||
cleanedAnswerChunk.indexOf('<think>') >
|
|
||||||
cleanedAnswerChunk.lastIndexOf('</think>'));
|
|
||||||
if (cleanedAnswerChunk && !hasUnclosedThink) {
|
|
||||||
const incrementalContent = cleanedAnswerChunk.substring(
|
|
||||||
lastReceivedContent.length,
|
|
||||||
);
|
|
||||||
if (incrementalContent) {
|
|
||||||
lastReceivedContent = cleanedAnswerChunk;
|
|
||||||
let newFullContent,
|
|
||||||
newCursorPosAfterInsertion = cursorPosition;
|
|
||||||
if (initialCursorPos !== null && showCursorIndicator) {
|
|
||||||
newFullContent =
|
|
||||||
beforeCursor + cleanedAnswerChunk + afterCursor;
|
|
||||||
newCursorPosAfterInsertion =
|
|
||||||
initialCursorPos + cleanedAnswerChunk.length;
|
|
||||||
} else {
|
|
||||||
newFullContent = originalContent + cleanedAnswerChunk;
|
|
||||||
newCursorPosAfterInsertion = newFullContent.length;
|
|
||||||
}
|
|
||||||
setContent(newFullContent);
|
|
||||||
setCursorPosition(newCursorPosAfterInsertion);
|
|
||||||
setTimeout(() => {
|
|
||||||
if (textAreaRef.current) {
|
|
||||||
textAreaRef.current.focus();
|
|
||||||
textAreaRef.current.setSelectionRange(
|
|
||||||
newCursorPosAfterInsertion!,
|
|
||||||
newCursorPosAfterInsertion!,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}, 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (parseErr) {
|
|
||||||
console.error('解析单行数据失败:', parseErr);
|
|
||||||
}
|
|
||||||
if (i < lines.length - 1)
|
|
||||||
await new Promise((resolve) => {
|
|
||||||
setTimeout(resolve, 10);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
await axios.post(
|
|
||||||
'/v1/conversation/rm',
|
|
||||||
{ conversation_ids: [conversationId], dialog_id: dialogId },
|
|
||||||
{ headers: { authorization } },
|
|
||||||
);
|
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
console.error('AI助手处理失败:', error);
|
console.error('AI助手处理失败:', error);
|
||||||
if (error.code === 'ECONNABORTED' || error.name === 'AbortError') {
|
if (error.code === 'ECONNABORTED' || error.name === 'AbortError') {
|
||||||
|
@ -455,6 +327,7 @@ const Write = () => {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// 导出为Word
|
||||||
const handleSave = () => {
|
const handleSave = () => {
|
||||||
const selectedTemplateItem = templates.find(
|
const selectedTemplateItem = templates.find(
|
||||||
(item) => item.id === selectedTemplate,
|
(item) => item.id === selectedTemplate,
|
||||||
|
|
|
@ -100,6 +100,8 @@ export default {
|
||||||
getExternalConversation: `${api_host}/api/conversation`,
|
getExternalConversation: `${api_host}/api/conversation`,
|
||||||
completeExternalConversation: `${api_host}/api/completion`,
|
completeExternalConversation: `${api_host}/api/completion`,
|
||||||
uploadAndParseExternal: `${api_host}/api/document/upload_and_parse`,
|
uploadAndParseExternal: `${api_host}/api/document/upload_and_parse`,
|
||||||
|
// 文档撰写模式中的问答API
|
||||||
|
writeChat: `${api_host}/conversation/writechat`,
|
||||||
|
|
||||||
// file manager
|
// file manager
|
||||||
listFile: `${api_host}/file/list`,
|
listFile: `${api_host}/file/list`,
|
||||||
|
|
Loading…
Reference in New Issue