diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 43ad1df..416a738 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -30,13 +30,14 @@ from api.db.services.dialog_service import DialogService, chat, ask
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle, TenantService
 from api import settings
+from api.db.services.write_service import write_dialog
 from api.utils.api_utils import get_json_result
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from graphrag.general.mind_map_extractor import MindMapExtractor
 from rag.app.tag import label_question
 
 
-@manager.route('/set', methods=['POST'])  # noqa: F821
+@manager.route("/set", methods=["POST"])  # type: ignore # noqa: F821
 @login_required
 def set_conversation():
     req = request.json
@@ -50,8 +51,7 @@ def set_conversation():
                 return get_data_error_result(message="Conversation not found!")
             e, conv = ConversationService.get_by_id(conv_id)
             if not e:
-                return get_data_error_result(
-                    message="Fail to update a conversation!")
+                return get_data_error_result(message="Fail to update a conversation!")
             conv = conv.to_dict()
             return get_json_result(data=conv)
         except Exception as e:
@@ -61,38 +61,30 @@ def set_conversation():
         e, dia = DialogService.get_by_id(req["dialog_id"])
         if not e:
             return get_data_error_result(message="Dialog not found")
-        conv = {
-            "id": conv_id,
-            "dialog_id": req["dialog_id"],
-            "name": req.get("name", "New conversation"),
-            "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
-        }
+        conv = {"id": conv_id, "dialog_id": req["dialog_id"], "name": req.get("name", "New conversation"), "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]}
         ConversationService.save(**conv)
         return get_json_result(data=conv)
     except Exception as e:
         return server_error_response(e)
 
 
-@manager.route('/get', methods=['GET'])  # noqa: F821
+@manager.route("/get", methods=["GET"])  # type: ignore # noqa: F821
 @login_required
 def get():
     conv_id = request.args["conversation_id"]
     try:
-
         e, conv = ConversationService.get_by_id(conv_id)
         if not e:
             return get_data_error_result(message="Conversation not found!")
         tenants = UserTenantService.query(user_id=current_user.id)
-        avatar =None
+        avatar = None
         for tenant in tenants:
             dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
-            if dialog and len(dialog)>0:
+            if dialog and len(dialog) > 0:
                 avatar = dialog[0].icon
                 break
         else:
-            return get_json_result(
-                data=False, message='Only owner of conversation authorized for this operation.',
-                code=settings.RetCode.OPERATING_ERROR)
+            return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
 
         def get_value(d, k1, k2):
             return d.get(k1, d.get(k2))
@@ -100,26 +92,29 @@ def get():
         for ref in conv.reference:
             if isinstance(ref, list):
                 continue
-            ref["chunks"] = [{
-                "id": get_value(ck, "chunk_id", "id"),
-                "content": get_value(ck, "content", "content_with_weight"),
-                "document_id": get_value(ck, "doc_id", "document_id"),
-                "document_name": get_value(ck, "docnm_kwd", "document_name"),
-                "dataset_id": get_value(ck, "kb_id", "dataset_id"),
-                "image_id": get_value(ck, "image_id", "img_id"),
-                "positions": get_value(ck, "positions", "position_int"),
-            } for ck in ref.get("chunks", [])]
+            ref["chunks"] = [
+                {
+                    "id": get_value(ck, "chunk_id", "id"),
+                    "content": get_value(ck, "content", "content_with_weight"),
+                    "document_id": get_value(ck, "doc_id", "document_id"),
+                    "document_name": get_value(ck, "docnm_kwd", "document_name"),
+                    "dataset_id": get_value(ck, "kb_id", "dataset_id"),
+                    "image_id": get_value(ck, "image_id", "img_id"),
+                    "positions": get_value(ck, "positions", "position_int"),
+                }
+                for ck in ref.get("chunks", [])
+            ]
 
         conv = conv.to_dict()
-        conv["avatar"]=avatar
+        conv["avatar"] = avatar
         return get_json_result(data=conv)
     except Exception as e:
         return server_error_response(e)
 
-@manager.route('/getsse/<dialog_id>', methods=['GET'])  # type: ignore # noqa: F821
+
+@manager.route("/getsse/<dialog_id>", methods=["GET"])  # type: ignore # noqa: F821
 def getsse(dialog_id):
-
-    token = request.headers.get('Authorization').split()
+    token = request.headers.get("Authorization").split()
     if len(token) != 2:
         return get_data_error_result(message='Authorization is not valid!"')
     token = token[1]
@@ -131,13 +126,14 @@ def getsse(dialog_id):
     if not e:
         return get_data_error_result(message="Dialog not found!")
     conv = conv.to_dict()
-    conv["avatar"]= conv["icon"]
+    conv["avatar"] = conv["icon"]
     del conv["icon"]
     return get_json_result(data=conv)
     except Exception as e:
         return server_error_response(e)
 
-@manager.route('/rm', methods=['POST'])  # noqa: F821
+
+@manager.route("/rm", methods=["POST"])  # type: ignore # noqa: F821
 @login_required
 def rm():
     conv_ids = request.json["conversation_ids"]
@@ -151,28 +147,21 @@ def rm():
             if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
                 break
         else:
-            return get_json_result(
-                data=False, message='Only owner of conversation authorized for this operation.',
-                code=settings.RetCode.OPERATING_ERROR)
+            return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
         ConversationService.delete_by_id(cid)
     return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
 
 
-@manager.route('/list', methods=['GET'])  # noqa: F821
+@manager.route("/list", methods=["GET"])  # type: ignore # noqa: F821
 @login_required
 def list_convsersation():
     dialog_id = request.args["dialog_id"]
     try:
         if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
-            return get_json_result(
-                data=False, message='Only owner of dialog authorized for this operation.',
-                code=settings.RetCode.OPERATING_ERROR)
-        convs = ConversationService.query(
-            dialog_id=dialog_id,
-            order_by=ConversationService.model.create_time,
-            reverse=True)
+            return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
+        convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
 
         convs = [d.to_dict() for d in convs]
         return get_json_result(data=convs)
@@ -180,7 +169,7 @@ def list_convsersation():
         return server_error_response(e)
 
 
-@manager.route('/completion', methods=['POST'])  # noqa: F821
+@manager.route("/completion", methods=["POST"])  # type: ignore # noqa: F821
 @login_required
 @validate_request("conversation_id", "messages")
 def completion():
@@ -207,25 +196,30 @@ def completion():
         if not conv.reference:
             conv.reference = []
         else:
+
             def get_value(d, k1, k2):
                 return d.get(k1, d.get(k2))
 
             for ref in conv.reference:
                 if isinstance(ref, list):
                     continue
-                ref["chunks"] = [{
-                    "id": get_value(ck, "chunk_id", "id"),
-                    "content": get_value(ck, "content", "content_with_weight"),
-                    "document_id": get_value(ck, "doc_id", "document_id"),
"document_name": get_value(ck, "docnm_kwd", "document_name"), - "dataset_id": get_value(ck, "kb_id", "dataset_id"), - "image_id": get_value(ck, "image_id", "img_id"), - "positions": get_value(ck, "positions", "position_int"), - } for ck in ref.get("chunks", [])] + ref["chunks"] = [ + { + "id": get_value(ck, "chunk_id", "id"), + "content": get_value(ck, "content", "content_with_weight"), + "document_id": get_value(ck, "doc_id", "document_id"), + "document_name": get_value(ck, "docnm_kwd", "document_name"), + "dataset_id": get_value(ck, "kb_id", "dataset_id"), + "image_id": get_value(ck, "image_id", "img_id"), + "positions": get_value(ck, "positions", "position_int"), + } + for ck in ref.get("chunks", []) + ] if not conv.reference: conv.reference = [] conv.reference.append({"chunks": [], "doc_aggs": []}) + def stream(): nonlocal dia, msg, req, conv try: @@ -235,9 +229,7 @@ def completion(): ConversationService.update_by_id(conv.id, conv.to_dict()) except Exception as e: traceback.print_exc() - yield "data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" if req.get("stream", True): @@ -259,7 +251,32 @@ def completion(): return server_error_response(e) -@manager.route('/tts', methods=['POST']) # noqa: F821 +# 用于文档撰写模式的问答调用 +@manager.route("/writechat", methods=["POST"]) # type: ignore # noqa: F821 +@login_required +@validate_request("question", "kb_ids") +def writechat(): + req = request.json + uid = current_user.id + + def stream(): + nonlocal req, uid + try: + for ans in write_dialog(req["question"], req["kb_ids"], uid): + yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" + except Exception as e: + yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" + + resp = Response(stream(), mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp + + +@manager.route("/tts", methods=["POST"]) # type: ignore # noqa: F821 @login_required def tts(): req = request.json @@ -281,9 +298,7 @@ def tts(): for chunk in tts_mdl.tts(txt): yield chunk except Exception as e: - yield ("data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e)}}, - ensure_ascii=False)).encode('utf-8') + yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8") resp = Response(stream_audio(), mimetype="audio/mpeg") resp.headers.add_header("Cache-Control", "no-cache") @@ -293,7 +308,7 @@ def tts(): return resp -@manager.route('/delete_msg', methods=['POST']) # noqa: F821 +@manager.route("/delete_msg", methods=["POST"]) # type: ignore # noqa: F821 @login_required @validate_request("conversation_id", "message_id") def delete_msg(): @@ -316,7 +331,7 @@ def delete_msg(): return get_json_result(data=conv) 
-@manager.route('/thumbup', methods=['POST'])  # noqa: F821
+@manager.route("/thumbup", methods=["POST"])  # type: ignore # noqa: F821
 @login_required
 @validate_request("conversation_id", "message_id")
 def thumbup():
@@ -343,7 +358,7 @@ def thumbup():
     return get_json_result(data=conv)
 
 
-@manager.route('/ask', methods=['POST'])  # noqa: F821
+@manager.route("/ask", methods=["POST"])  # type: ignore # noqa: F821
 @login_required
 @validate_request("question", "kb_ids")
 def ask_about():
@@ -356,9 +371,7 @@ def ask_about():
             for ans in ask(req["question"], req["kb_ids"], uid):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
-            yield "data:" + json.dumps({"code": 500, "message": str(e),
-                                        "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
-                                       ensure_ascii=False) + "\n\n"
+            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
         yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
 
     resp = Response(stream(), mimetype="text/event-stream")
@@ -369,7 +382,7 @@ def ask_about():
     return resp
 
 
-@manager.route('/mindmap', methods=['POST'])  # noqa: F821
+@manager.route("/mindmap", methods=["POST"])  # type: ignore # noqa: F821
 @login_required
 @validate_request("question", "kb_ids")
 def mindmap():
@@ -382,10 +395,7 @@ def mindmap():
     embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
     chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
     question = req["question"]
-    ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12,
-                                           0.3, 0.3, aggs=False,
-                                           rank_feature=label_question(question, [kb])
-                                           )
+    ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, 0.3, 0.3, aggs=False, rank_feature=label_question(question, [kb]))
     mindmap = MindMapExtractor(chat_mdl)
     mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
     mind_map = mind_map.output
@@ -394,7 +404,7 @@ def mindmap():
     return get_json_result(data=mind_map)
 
 
-@manager.route('/related_questions', methods=['POST'])  # noqa: F821
+@manager.route("/related_questions", methods=["POST"])  # type: ignore # noqa: F821
 @login_required
 @validate_request("question")
 def related_questions():
@@ -425,8 +435,17 @@ Reason:
 - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
 
 """
-    ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
+    ans = chat_mdl.chat(
+        prompt,
+        [
+            {
+                "role": "user",
+                "content": f"""
 Keywords: {question}
 Related search terms:
-    """}], {"temperature": 0.9})
+    """,
+            }
+        ],
+        {"temperature": 0.9},
+    )
     return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
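related_questions relies on the model returning a numbered list; the route's last line keeps only lines matching `^[0-9]. ` and strips the numbering. The same parsing in isolation (a sketch):

    import re

    def parse_numbered_list(ans: str) -> list[str]:
        """Keep lines that look like '1. term' and strip the numbering."""
        return [re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]

    # parse_numbered_list("1. alpha\n2. beta\nnot a term") == ["alpha", "beta"]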
", a)]) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 59d4b4f..307b9ff 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -30,8 +30,7 @@ from api import settings from rag.app.resume import forbidden_select_fields4resume from rag.app.tag import label_question from rag.nlp.search import index_name -from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, \ - citation_prompt +from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, citation_prompt from rag.utils import rmSpace, num_tokens_from_string from rag.utils.tavily_conn import Tavily @@ -41,17 +40,13 @@ class DialogService(CommonService): @classmethod @DB.connection_context() - def get_list(cls, tenant_id, - page_number, items_per_page, orderby, desc, id, name): + def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name): chats = cls.model.select() if id: chats = chats.where(cls.model.id == id) if name: chats = chats.where(cls.model.name == name) - chats = chats.where( - (cls.model.tenant_id == tenant_id) - & (cls.model.status == StatusEnum.VALID.value) - ) + chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value)) if desc: chats = chats.order_by(cls.model.getter_by(orderby).desc()) else: @@ -72,13 +67,12 @@ def chat_solo(dialog, messages, stream=True): tts_mdl = None if prompt_config.get("tts"): tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS) - msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} - for m in messages if m["role"] != "system"] + msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"] if stream: last_ans = "" for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting): answer = ans - delta_ans = ans[len(last_ans):] + delta_ans = ans[len(last_ans) :] if num_tokens_from_string(delta_ans) < 16: continue last_ans = answer @@ -159,9 +153,8 @@ def chat(dialog, messages, stream=True, **kwargs): if p["key"] not in kwargs and not p["optional"]: raise KeyError("Miss parameter: " + p["key"]) if p["key"] not in kwargs: - prompt_config["system"] = prompt_config["system"].replace( - "{%s}" % p["key"], " ") - + prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ") + # 不再使用多轮对话优化 # if len(questions) > 1 and prompt_config.get("refine_multiturn"): # questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] @@ -190,7 +183,7 @@ def chat(dialog, messages, stream=True, **kwargs): tenant_ids = list(set([kb.tenant_id for kb in kbs])) knowledges = [] - + # 不再使用推理 # if prompt_config.get("reasoning", False): # reasoner = DeepResearcher(chat_mdl, @@ -226,17 +219,24 @@ def chat(dialog, messages, stream=True, **kwargs): # kbinfos["chunks"].insert(0, ck) # knowledges = kb_prompt(kbinfos, max_tokens) - kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, - dialog.similarity_threshold, - dialog.vector_similarity_weight, - doc_ids=attachments, - top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl, - rank_feature=label_question(" ".join(questions), kbs) - ) + kbinfos = retriever.retrieval( + " ".join(questions), + embd_mdl, + tenant_ids, + dialog.kb_ids, + 1, + dialog.top_n, + dialog.similarity_threshold, + dialog.vector_similarity_weight, + doc_ids=attachments, + 
top=dialog.top_k, + aggs=False, + rerank_mdl=rerank_mdl, + rank_feature=label_question(" ".join(questions), kbs), + ) knowledges = kb_prompt(kbinfos, max_tokens) - logging.debug( - "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) + logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges))) retrieval_ts = timer() if not knowledges and prompt_config.get("empty_response"): @@ -252,22 +252,19 @@ def chat(dialog, messages, stream=True, **kwargs): if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): prompt4citation = citation_prompt() # 过滤掉 system 角色的消息(因为前面已经单独处理了系统消息) - msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} - for m in messages if m["role"] != "system"]) + msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]) used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95)) assert len(msg) >= 2, f"message_fit_in has bug: {msg}" prompt = msg[0]["content"] if "max_tokens" in gen_conf: - gen_conf["max_tokens"] = min( - gen_conf["max_tokens"], - max_tokens - used_token_count) + gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count) def decorate_answer(answer): nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions refs = [] - image_markdowns = [] # 用于存储图片的 Markdown 字符串 + image_markdowns = [] # 用于存储图片的 Markdown 字符串 ans = answer.split("") think = "" if len(ans) == 2: @@ -275,29 +272,29 @@ def chat(dialog, messages, stream=True, **kwargs): answer = ans[1] if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL) - cited_chunk_indices = set() # 用于存储被引用的 chunk 索引 + cited_chunk_indices = set() # 用于存储被引用的 chunk 索引 if not re.search(r"##[0-9]+\$\$", answer): - answer, idx = retriever.insert_citations(answer, - [ck["content_ltks"] - for ck in kbinfos["chunks"]], - [ck["vector"] - for ck in kbinfos["chunks"]], - embd_mdl, - tkweight=1 - dialog.vector_similarity_weight, - vtweight=dialog.vector_similarity_weight) - cited_chunk_indices = idx # 获取 insert_citations 返回的索引 - + answer, idx = retriever.insert_citations( + answer, + [ck["content_ltks"] for ck in kbinfos["chunks"]], + [ck["vector"] for ck in kbinfos["chunks"]], + embd_mdl, + tkweight=1 - dialog.vector_similarity_weight, + vtweight=dialog.vector_similarity_weight, + ) + cited_chunk_indices = idx # 获取 insert_citations 返回的索引 + else: idx = set([]) for r in re.finditer(r"##([0-9]+)\$\$", answer): i = int(r.group(1)) if i < len(kbinfos["chunks"]): idx.add(i) - cited_chunk_indices = idx # 获取从 ##...$$ 标记中提取的索引 + cited_chunk_indices = idx # 获取从 ##...$$ 标记中提取的索引 # 根据引用的 chunk 索引提取图像信息并生成 Markdown cited_doc_ids = set() - processed_image_urls = set() # 避免重复添加同一张图片 + processed_image_urls = set() # 避免重复添加同一张图片 print(f"DEBUG: cited_chunk_indices = {cited_chunk_indices}") for i in cited_chunk_indices: i_int = int(i) @@ -312,11 +309,10 @@ def chat(dialog, messages, stream=True, **kwargs): # 生成 Markdown 字符串,alt text 可以简单设为 "image" 或 chunk ID alt_text = f"image_chunk_{chunk.get('chunk_id', i_int)}" image_markdowns.append(f"\n![{alt_text}]({img_url})") - processed_image_urls.add(img_url) # 标记为已处理 + processed_image_urls.add(img_url) # 标记为已处理 idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) - recall_docs = [ - d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] + recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] if 
            if not recall_docs:
                 recall_docs = kbinfos["doc_aggs"]
             kbinfos["doc_aggs"] = recall_docs
@@ -325,7 +321,7 @@ def chat(dialog, messages, stream=True, **kwargs):
             for c in refs["chunks"]:
                 if c.get("vector"):
                     del c["vector"]
-
+
             # Append the image Markdown strings to the end of the answer
             if image_markdowns:
                 answer += "".join(image_markdowns)
@@ -347,30 +343,30 @@ def chat(dialog, messages, stream=True, **kwargs):
             prompt += "\n\n### Query:\n%s" % " ".join(questions)
             prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n  - Check LLM: {check_llm_time_cost:.1f}ms\n  - Create retriever: {create_retriever_time_cost:.1f}ms\n  - Bind embedding: {bind_embedding_time_cost:.1f}ms\n  - Bind LLM: {bind_llm_time_cost:.1f}ms\n  - Tune question: {refine_question_time_cost:.1f}ms\n  - Bind reranker: {bind_reranker_time_cost:.1f}ms\n  - Generate keyword: {generate_keyword_time_cost:.1f}ms\n  - Retrieval: {retrieval_time_cost:.1f}ms\n  - Generate answer: {generate_result_time_cost:.1f}ms"
-        return {"answer": think+answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt), "created_at": time.time()}
+        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt), "created_at": time.time()}
 
     if stream:
-        last_ans = "" # the full answer returned last time
-        answer = "" # the full answer accumulated so far
-        for ans in chat_mdl.chat_streamly(prompt+prompt4citation, msg[1:], gen_conf):
+        last_ans = ""  # the full answer returned last time
+        answer = ""  # the full answer accumulated so far
+        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
             # If there is a thought process, strip the <think> markers
             if thought:
                 ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
             answer = ans
             # Compute the newly added text fragment (the delta)
-            delta_ans = ans[len(last_ans):]
+            delta_ans = ans[len(last_ans) :]
             # If the new tokens are too few (fewer than 16), skip this round (avoids sending tiny fragments)
             if num_tokens_from_string(delta_ans) < 16:
                 continue
             last_ans = answer
             # Yield the answer accumulated so far (including the thought process) plus the new fragment
-            yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        delta_ans = answer[len(last_ans):]
+            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+        delta_ans = answer[len(last_ans) :]
         if delta_ans:
-            yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        yield decorate_answer(thought+answer)
+            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+        yield decorate_answer(thought + answer)
     else:
-        answer = chat_mdl.chat(prompt+prompt4citation, msg[1:], gen_conf)
+        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
         user_content = msg[-1].get("content", "[content not available]")
         logging.debug("User: {}|Assistant: {}".format(user_content, answer))
         res = decorate_answer(answer)
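The streaming branch above emits cumulative answers but throttles the frames: a frame goes out only once the new tail has grown past roughly 16 tokens, and a final flush catches whatever remains. The same pattern in isolation (a sketch; `num_tokens_from_string` is the counter already imported by this module):

    from rag.utils import num_tokens_from_string

    def throttle_stream(chunks, min_tokens=16):
        """Re-yield cumulative answer strings, skipping frames whose new tail is tiny."""
        last = ""
        answer = ""
        for answer in chunks:               # each item is the full answer so far
            delta = answer[len(last):]
            if num_tokens_from_string(delta) < min_tokens:
                continue                    # too small — wait for more tokens
            last = answer
            yield answer
        if answer[len(last):]:              # final flush of the leftover tail
            yield answer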
-""".format( - index_name(tenant_id), - "\n".join([f"{k}: {v}" for k, v in field_map.items()]), - question - ) +""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question) tried_times = 0 def get_table(): nonlocal sys_prompt, user_prompt, question, tried_times - sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], { - "temperature": 0.06}) + sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) sql = re.sub(r".*", "", sql, flags=re.DOTALL) logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}") sql = re.sub(r"[\r\n]+", " ", sql.lower()) sql = re.sub(r".*select ", "select ", sql.lower()) sql = re.sub(r" +", " ", sql) sql = re.sub(r"([;;]|```).*", "", sql) - if sql[:len("select ")] != "select ": + if sql[: len("select ")] != "select ": return None, None if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()): - if sql[:len("select *")] != "select *": + if sql[: len("select *")] != "select *": sql = "select doc_id,docnm_kwd," + sql[6:] else: flds = [] @@ -445,11 +436,7 @@ Please write the SQL, only SQL, without any other explanations or text. {} Please correct the error and write SQL again, only SQL, without any other explanations or text. - """.format( - index_name(tenant_id), - "\n".join([f"{k}: {v}" for k, v in field_map.items()]), - question, sql, tbl["error"] - ) + """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"]) tbl, sql = get_table() logging.debug("TRY it again: {}".format(sql)) @@ -457,24 +444,18 @@ Please write the SQL, only SQL, without any other explanations or text. if tbl.get("error") or len(tbl["rows"]) == 0: return None - docid_idx = set([ii for ii, c in enumerate( - tbl["columns"]) if c["name"] == "doc_id"]) - doc_name_idx = set([ii for ii, c in enumerate( - tbl["columns"]) if c["name"] == "docnm_kwd"]) - column_idx = [ii for ii in range( - len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)] + docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"]) + doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"]) + column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)] # compose Markdown table - columns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], - tbl["columns"][i]["name"])) for i in - column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") + columns = ( + "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") + ) - line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \ - ("|------|" if docid_idx and docid_idx else "") + line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") - rows = ["|" + - "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + - "|" for r in tbl["rows"]] + rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]] rows = [r for r in rows if re.sub(r"[ |]+", "", r)] if quota: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) @@ -484,11 +465,7 @@ Please write the SQL, only SQL, without any other explanations or text. 
@@ -484,11 +465,7 @@ Please write the SQL, only SQL, without any other explanations or text.
 
     if not docid_idx or not doc_name_idx:
         logging.warning("SQL missing field: " + sql)
-        return {
-            "answer": "\n".join([columns, line, rows]),
-            "reference": {"chunks": [], "doc_aggs": []},
-            "prompt": sys_prompt
-        }
+        return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
 
     docid_idx = list(docid_idx)[0]
     doc_name_idx = list(doc_name_idx)[0]
@@ -499,10 +476,11 @@ Please write the SQL, only SQL, without any other explanations or text.
         doc_aggs[r[docid_idx]]["count"] += 1
     return {
         "answer": "\n".join([columns, line, rows]),
-        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
-                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
-                                   doc_aggs.items()]},
-        "prompt": sys_prompt
+        "reference": {
+            "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
+            "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
+        },
+        "prompt": sys_prompt,
     }
 
 
@@ -518,12 +496,12 @@ def tts(tts_mdl, text):
 def ask(question, kb_ids, tenant_id):
     """
     Handle a user search request: retrieve relevant information from the knowledge bases and generate an answer.
-    
+
     Args:
         question (str): the user's question or query
         kb_ids (list): IDs of the knowledge bases to search
         tenant_id (str): tenant ID, used for permission control and resource isolation
-    
+
     Flow:
        1. Fetch the specified knowledge bases
        2. Determine which embedding model to use
        3. Choose a retriever based on the knowledge-base type (plain or knowledge-graph retriever)
        4. Initialize the embedding model and the chat model
        5. Run retrieval to get relevant document chunks
        6. Format the knowledge-base content as context
        7. Build the system prompt
        8. Generate the answer and add citation markers
        9. Stream the generated answer back
-    
+
     Returns:
         generator: a generator yielding dicts that contain the answer and citation info
     """
-    
+
     kbs = KnowledgebaseService.get_by_ids(kb_ids)
     embedding_list = list(set([kb.embd_id for kb in kbs]))
 
@@ -552,27 +530,24 @@ def ask(question, kb_ids, tenant_id):
     max_tokens = chat_mdl.max_length
     # Collect and deduplicate the tenant IDs of all knowledge bases
     tenant_ids = list(set([kb.tenant_id for kb in kbs]))
-    # Call the retriever to fetch relevant document chunks
-    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
-                                  1, 12, 0.1, 0.3, aggs=False,
-                                  rank_feature=label_question(question, kbs)
-                                  )
-    # Format the retrieval results into the prompt, staying within the model's max token limit
+    # Call the retriever to fetch relevant document chunks
+    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
+    # Format the retrieval results into the prompt, staying within the model's max token limit
     knowledges = kb_prompt(kbinfos, max_tokens)
     prompt = """
-    Role: You're a smart assistant. Your name is Miss R.
-    Task: Summarize the information from knowledge bases and answer user's question.
-    Requirements and restriction:
-    - DO NOT make things up, especially for numbers.
-    - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
-    - Answer with markdown format text.
-    - Answer in language of user's question.
-    - DO NOT make things up, especially for numbers.
+    角色:你是一个聪明的助手。
+    任务:总结知识库中的信息并回答用户的问题。
+    要求与限制:
+    - 绝不要捏造内容,尤其是数字。
+    - 如果知识库中的信息与用户问题无关,只需回答:对不起,未提供相关信息。
+    - 使用Markdown格式进行回答。
+    - 使用用户提问所用的语言作答。
+    - 绝不要捏造内容,尤其是数字。
 
-    ### Information from knowledge bases
+    ### 来自知识库的信息
     %s
 
-    The above is information from knowledge bases.
+    以上是来自知识库的信息。
 
     """ % "\n".join(knowledges)
     msg = [{"role": "user", "content": question}]
 
     # After generation completes, add citation markers to the answer
     def decorate_answer(answer):
         nonlocal knowledges, kbinfos, prompt
-        answer, idx = retriever.insert_citations(answer,
-                                                 [ck["content_ltks"]
-                                                  for ck in kbinfos["chunks"]],
-                                                 [ck["vector"]
-                                                  for ck in kbinfos["chunks"]],
-                                                 embd_mdl,
-                                                 tkweight=0.7,
-                                                 vtweight=0.3)
+        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
         idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
-        recall_docs = [
-            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+        recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
         if not recall_docs:
             recall_docs = kbinfos["doc_aggs"]
         kbinfos["doc_aggs"] = recall_docs
@@ -608,4 +575,4 @@ def ask(question, kb_ids, tenant_id):
     for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
         answer = ans
         yield {"answer": answer, "reference": {}}
-    yield decorate_answer(answer)
\ No newline at end of file
+    yield decorate_answer(answer)
diff --git a/api/db/services/write_service.py b/api/db/services/write_service.py
new file mode 100644
index 0000000..4e06b6a
--- /dev/null
+++ b/api/db/services/write_service.py
@@ -0,0 +1,91 @@
+from api.db import LLMType, ParserType
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.llm_service import LLMBundle
+from api import settings
+from rag.app.tag import label_question
+from rag.prompts import kb_prompt
+
+
+def write_dialog(question, kb_ids, tenant_id):
+    """
+    Handle a user request in writing mode: retrieve relevant information from the knowledge bases and generate an answer.
+
+    Args:
+        question (str): the user's question or query
+        kb_ids (list): IDs of the knowledge bases to search
+        tenant_id (str): tenant ID, used for permission control and resource isolation
+
+    Flow:
+        1. Fetch the specified knowledge bases
+        2. Determine which embedding model to use
+        3. Choose a retriever based on the knowledge-base type (plain or knowledge-graph retriever)
+        4. Initialize the embedding model and the chat model
+        5. Run retrieval to get relevant document chunks
+        6. Format the knowledge-base content as context
+        7. Build the system prompt
+        8. Generate the answer and add citation markers
+        9. Stream the generated answer back
+
+    Returns:
+        generator: a generator yielding dicts that contain the answer and citation info
+    """
+
+    kbs = KnowledgebaseService.get_by_ids(kb_ids)
+    embedding_list = list(set([kb.embd_id for kb in kbs]))
+
+    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
+    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
+    # Initialize the embedding model, which turns text into vector representations
+    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
+    # Initialize the chat model that generates the answer
+    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
+    # The chat model's maximum token length, used to bound the context size
+    max_tokens = chat_mdl.max_length
+    # Collect and deduplicate the tenant IDs of all knowledge bases
+    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
+    # Call the retriever to fetch relevant document chunks
+    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
+    # Format the retrieval results into the prompt, staying within the model's max token limit
+    knowledges = kb_prompt(kbinfos, max_tokens)
+    prompt = """
+    角色:你是一个聪明的助手。
+    任务:总结知识库中的信息并回答用户的问题。
+    要求与限制:
+    - 绝不要捏造内容,尤其是数字。
+    - 如果知识库中的信息与用户问题无关,只需回答:对不起,未提供相关信息。
+    - 使用Markdown格式进行回答。
+    - 使用用户提问所用的语言作答。
+    - 绝不要捏造内容,尤其是数字。
+
+    ### 来自知识库的信息
+    %s
+
+    以上是来自知识库的信息。
+
+    """ % "\n".join(knowledges)
+    msg = [{"role": "user", "content": question}]
+
+    # After generation completes, add citation markers to the answer
+    # def decorate_answer(answer):
+    #     nonlocal knowledges, kbinfos, prompt
+    #     answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
+    #     idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
+    #     recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+    #     if not recall_docs:
+    #         recall_docs = kbinfos["doc_aggs"]
+    #     kbinfos["doc_aggs"] = recall_docs
+    #     refs = deepcopy(kbinfos)
+    #     for c in refs["chunks"]:
+    #         if c.get("vector"):
+    #             del c["vector"]
+
+    #     if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
+    #         answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+    #     refs["chunks"] = chunks_format(refs)
+    #     return {"answer": answer, "reference": refs}
+
+    answer = ""
+    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
+        answer = ans
+        yield {"answer": answer, "reference": {}}
+    # yield decorate_answer(answer)
diff --git a/management/web/src/common/assets/icons/storage.svg b/management/web/src/common/assets/icons/storage.svg
new file mode 100644
index 0000000..0cd231b
--- /dev/null
+++ b/management/web/src/common/assets/icons/storage.svg
@@ -0,0 +1,6 @@
+<svg>
+  <title>storage-solid</title>
+</svg>
\ No newline at end of file
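Like ask(), write_dialog is a generator whose every yield carries the complete answer so far rather than a delta, which is why the /writechat route can simply serialize each item. A usage sketch (the knowledge-base and tenant IDs are placeholders):

    from api.db.services.write_service import write_dialog

    def collect_answer(question, kb_ids, tenant_id):
        """Drain the generator; the last item holds the full answer."""
        final = {"answer": "", "reference": {}}
        for item in write_dialog(question, kb_ids, tenant_id):
            final = item  # each item supersedes the previous one
        return final["answer"]

    # answer = collect_answer("总结一下年度报告", ["kb_default"], "tenant-0")  # hypothetical IDs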
diff --git a/web/src/hooks/write-hooks.ts b/web/src/hooks/write-hooks.ts
new file mode 100644
index 0000000..3027e7e
--- /dev/null
+++ b/web/src/hooks/write-hooks.ts
@@ -0,0 +1,123 @@
+import { Authorization } from '@/constants/authorization';
+import { IAnswer } from '@/interfaces/database/chat';
+import { IKnowledge } from '@/interfaces/database/knowledge';
+import kbService from '@/services/knowledge-service';
+import api from '@/utils/api';
+import { getAuthorization } from '@/utils/authorization-util';
+import { useQuery } from '@tanstack/react-query';
+import { EventSourceParserStream } from 'eventsource-parser/stream';
+import { useCallback, useRef, useState } from 'react';
+
+// Fetch the knowledge-base list
+export const useFetchKnowledgeList = (
+  shouldFilterListWithoutDocument: boolean = false,
+): {
+  list: IKnowledge[];
+  loading: boolean;
+} => {
+  const { data, isFetching: loading } = useQuery({
+    queryKey: ['fetchKnowledgeList'],
+    initialData: [],
+    gcTime: 0,
+    queryFn: async () => {
+      const { data } = await kbService.getList();
+      const list = data?.data?.kbs ?? [];
+      return shouldFilterListWithoutDocument
+        ? list.filter((x: IKnowledge) => x.chunk_num > 0)
+        : list;
+    },
+  });
+
+  return { list: data, loading };
+};
+
+// Send a question over SSE
+export const useSendMessageWithSse = (url: string = api.writeChat) => {
+  const [answer, setAnswer] = useState<IAnswer>({} as IAnswer);
+  const [done, setDone] = useState(true);
+  const timer = useRef<any>();
+  const sseRef = useRef<AbortController>();
+
+  const initializeSseRef = useCallback(() => {
+    sseRef.current = new AbortController();
+  }, []);
+
+  const resetAnswer = useCallback(() => {
+    if (timer.current) {
+      clearTimeout(timer.current);
+    }
+    timer.current = setTimeout(() => {
+      setAnswer({} as IAnswer);
+      clearTimeout(timer.current);
+    }, 1000);
+  }, []);
+
+  const send = useCallback(
+    async (
+      body: any,
+      controller?: AbortController,
+    ): Promise<{ response: Response; data: ResponseType } | undefined> => {
+      initializeSseRef();
+      try {
+        setDone(false);
+        const response = await fetch(url, {
+          method: 'POST',
+          headers: {
+            [Authorization]: getAuthorization(),
+            'Content-Type': 'application/json',
+          },
+          body: JSON.stringify(body),
+          signal: controller?.signal || sseRef.current?.signal,
+        });
+
+        const res = response.clone().json();
+
+        const reader = response?.body
+          ?.pipeThrough(new TextDecoderStream())
+          .pipeThrough(new EventSourceParserStream())
+          .getReader();
+
+        while (true) {
+          const x = await reader?.read();
+          if (x) {
+            const { done, value } = x;
+            if (done) {
+              console.info('done');
+              resetAnswer();
+              break;
+            }
+            try {
+              const val = JSON.parse(value?.data || '');
+              const d = val?.data;
+              if (typeof d !== 'boolean') {
+                console.info('data:', d);
+                setAnswer({
+                  ...d,
+                  conversationId: body?.conversation_id,
+                });
+              }
+            } catch (e) {
+              console.warn(e);
+            }
+          }
+        }
+        console.info('done?');
+        setDone(true);
+        resetAnswer();
+        return { data: await res, response };
+      } catch (e) {
+        setDone(true);
+        resetAnswer();
+
+        console.warn(e);
+      }
+    },
+    [initializeSseRef, url, resetAnswer],
+  );
+
+  const stopOutputMessage = useCallback(() => {
+    sseRef.current?.abort();
+  }, []);
+
+  return { send, answer, done, setDone, resetAnswer, stopOutputMessage };
+};
diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts
index a4e9854..418ed3d 100644
--- a/web/src/locales/zh.ts
+++ b/web/src/locales/zh.ts
@@ -591,6 +591,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
     decisions: '决定事项',
     actionItems: '行动项',
     nextMeeting: '下次会议',
+    noTemplatesAvailable: "没有可用模板",
     // Model configuration
     modelConfigurationTitle: "模型配置",
     knowledgeBaseLabel: "知识库",
@@ -601,6 +602,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
     fetchKnowledgeBaseFailed: "获取知识库列表失败",
     defaultKnowledgeBase: "默认知识库",
     technicalDocsKnowledgeBase: "技术文档知识库",
+    aiRequestFailedError: "问答模型请求失败",
   },
   setting: {
     profile: '概要',
diff --git a/web/src/pages/write/index.tsx b/web/src/pages/write/index.tsx
index a8347c2..f618ea7 100644
--- a/web/src/pages/write/index.tsx
+++ b/web/src/pages/write/index.tsx
@@ -1,8 +1,9 @@
 import HightLightMarkdown from '@/components/highlight-markdown';
 import { useTranslate } from '@/hooks/common-hooks';
-// Assume aiAssistantConfig is imported correctly in the real project
-// import { aiAssistantConfig } from '@/pages/write/ai-assistant-config';
-const aiAssistantConfig = { api: { timeout: 30000 } }; // mock definition
+import {
+  useFetchKnowledgeList,
+  useSendMessageWithSse,
+} from '@/hooks/write-hooks';
 
 import { DeleteOutlined } from '@ant-design/icons';
 import {
@@ -22,7 +23,6 @@ import {
   Space,
   Typography,
 } from 'antd';
-import axios from 'axios';
 import {
   AlignmentType,
   Document,
@@ -32,7 +32,7 @@ import {
   TextRun,
 } from 'docx';
 import { saveAs } from 'file-saver';
-import { marked, Token, Tokens } from 'marked'; // import the Token and Tokens types from marked
+import { marked, Token, Tokens } from 'marked';
 import { useCallback, useEffect, useRef, useState } from 'react';
 
 const { Sider, Content } = Layout;
@@ -53,22 +53,28 @@ interface KnowledgeBaseItem {
   name: string;
 }
 
-// Use the types exported by marked, or more precise custom types
 type MarkedHeadingToken = Tokens.Heading;
 type MarkedParagraphToken = Tokens.Paragraph;
 type MarkedListItem = Tokens.ListItem;
 type MarkedListToken = Tokens.List;
 type MarkedSpaceToken = Tokens.Space;
 
+// Define the insertion-point marker so onChange can recognize and remove it
+// const INSERTION_MARKER = '【AI内容将插入此处】';
+const INSERTION_MARKER = ''; // kept as an empty string so no visible marker is shown
+
 const Write = () => {
   const { t } = useTranslate('write');
   const [content, setContent] = useState('');
   const [aiQuestion, setAiQuestion] = useState('');
   const [isAiLoading, setIsAiLoading] = useState(false);
-  const [dialogId, setDialogId] = useState('');
+  const [dialogId] = useState('');
+  // cursorPosition stores the insertion point the user set by clicking
   const [cursorPosition, setCursorPosition] = useState<number | null>(null);
+  // showCursorIndicator now only controls whether INSERTION_MARKER is shown in the document;
+  // once a cursor position is set it should stay true unless the content is cleared or actively reset.
   const [showCursorIndicator, setShowCursorIndicator] = useState(false);
-  const textAreaRef = useRef(null);
+  const textAreaRef = useRef<TextAreaRef>(null); // ref type for Ant Design Input.TextArea
 
   const [templates, setTemplates] = useState<TemplateItem[]>([]);
   const [isTemplateModalVisible, setIsTemplateModalVisible] = useState(false);
@@ -87,6 +93,27 @@ const Write = () => {
   const [modelTemperature, setModelTemperature] = useState(0.7);
   const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBaseItem[]>([]);
   const [isLoadingKbs, setIsLoadingKbs] = useState(false);
+  const [isStreaming, setIsStreaming] = useState(false); // whether the AI is currently streaming output
+
+  // New state and refs for managing the streamed output.
+  // currentStreamedAiOutput now directly receives the accumulated content returned by useSendMessageWithSse
+  const [currentStreamedAiOutput, setCurrentStreamedAiOutput] = useState('');
+  // Keep the content before/after the AI insertion point, and the point itself, in refs
+  // to avoid closure traps during streamed updates
+  const contentBeforeAiInsertionRef = useRef('');
+  const contentAfterAiInsertionRef = useRef('');
+  const aiInsertionStartPosRef = useRef<number | null>(null);
+
+  // Use the useFetchKnowledgeList hook to fetch real data
+  const { list: knowledgeList, loading: isLoadingKnowledgeList } =
+    useFetchKnowledgeList(true);
+
+  // Streaming message hook
+  const {
+    send: sendMessage,
+    answer,
+    done,
+    stopOutputMessage,
+  } = useSendMessageWithSse();
 
   const getInitialDefaultTemplateDefinitions = useCallback(
     (): TemplateItem[] => [
@@ -171,70 +198,102 @@ const Write = () => {
     loadOrInitializeTemplates();
   }, [loadOrInitializeTemplates]);
 
+  // Sync the knowledgeList data into the knowledgeBases state
   useEffect(() => {
-    const fetchKbs = async () => {
-      const authorization = localStorage.getItem('Authorization');
-      if (!authorization) {
-        setKnowledgeBases([]);
-        return;
+    if (knowledgeList && knowledgeList.length > 0) {
+      setKnowledgeBases(
+        knowledgeList.map((kb) => ({
+          id: kb.id,
+          name: kb.name,
+        })),
+      );
+      setIsLoadingKbs(isLoadingKnowledgeList);
+    }
+  }, [knowledgeList, isLoadingKnowledgeList]);
+
+  // --- Adjusted streamed-response handling ---
+  // Phase 1: accumulate the AI output fragments for live display (including the <think> tags).
+  // This useEffect keeps currentStreamedAiOutput as the live, <think>-tag-inclusive full content.
+  useEffect(() => {
+    if (isStreaming && answer && answer.answer) {
+      setCurrentStreamedAiOutput(answer.answer);
+    }
+  }, [isStreaming, answer]);
+
+  // Phase 2: when streaming finishes (done becomes true).
+  // This useEffect performs the cleanup and the final content update at the end of the stream.
+  useEffect(() => {
+    if (done) {
+      setIsStreaming(false);
+      setIsAiLoading(false);
+
+      // --- Process the final streamed AI output before committing ---
+      // Key point: this MUST use currentStreamedAiOutput, since it is the content accumulated
+      // during streaming and still contains the <think> tags.
+      // answer.answer may already have been cleaned up inside the hook by the time done fires,
+      // so it cannot be relied on for the tagged raw content.
+      let processedAiOutput = currentStreamedAiOutput;
+      if (processedAiOutput) {
+        // Regex to remove <think>...</think> including the content in between
+        processedAiOutput = processedAiOutput.replace(
+          /<think>.*?<\/think>/gs,
+          '',
+        );
+      }
+      // --- END NEW ---
+
+      // Commit the final accumulated AI content (think tags removed), joined with the
+      // initial document content, into the main content state
+      setContent((prevContent) => {
+        if (aiInsertionStartPosRef.current !== null) {
+          // Use the initial content stored in the refs plus the final processed AI output
+          const finalContent =
+            contentBeforeAiInsertionRef.current +
+            processedAiOutput +
+            contentAfterAiInsertionRef.current;
+          return finalContent;
+        }
+        return prevContent;
+      });
+
+      // After the AI finishes answering, move the caret to the end of the new content
+      if (
+        textAreaRef.current?.resizableTextArea?.textArea &&
+        aiInsertionStartPosRef.current !== null
+      ) {
+        const newCursorPos =
+          aiInsertionStartPosRef.current + processedAiOutput.length;
+        textAreaRef.current.resizableTextArea.textArea.selectionStart =
+          newCursorPos;
+        textAreaRef.current.resizableTextArea.textArea.selectionEnd =
+          newCursorPos;
+        textAreaRef.current.resizableTextArea.textArea.focus();
+        setCursorPosition(newCursorPos);
+      }
+
+      // Clean up the streaming-related temporary state and refs
+      setCurrentStreamedAiOutput(''); // clear the accumulated content
+      contentBeforeAiInsertionRef.current = '';
+      contentAfterAiInsertionRef.current = '';
+      aiInsertionStartPosRef.current = null;
+      setShowCursorIndicator(true);
+    }
+  }, [done, currentStreamedAiOutput]); // depend on done and currentStreamedAiOutput so the latest value is available when done fires
+
+  // Watch currentStreamedAiOutput and update the main content state live for streamed display
+  useEffect(() => {
+    if (isStreaming && aiInsertionStartPosRef.current !== null) {
+      // Update the editor content live, keeping the <think> tag content
+      setContent(
+        contentBeforeAiInsertionRef.current +
+          currentStreamedAiOutput +
+          contentAfterAiInsertionRef.current,
+      );
+      // Also update cursorPosition so the caret follows the AI output
+      // (based on the raw length, think tags included)
+      setCursorPosition(
+        aiInsertionStartPosRef.current + currentStreamedAiOutput.length,
+      );
+    }
+  }, [currentStreamedAiOutput, isStreaming, aiInsertionStartPosRef]);
 
   useEffect(() => {
-    const fetchDialogs = async () => {
-      try {
-        const authorization = localStorage.getItem('Authorization');
-        if (!authorization) return;
-        const response = await axios.get('/v1/dialog', {
-          headers: { authorization },
-        });
-        if (response.data?.data?.length > 0)
-          setDialogId(response.data.data[0].id);
-      } catch (error) {
-        console.error('获取对话列表失败:', error);
-      }
-    };
     const loadDraftContent = () => {
       try {
         const draftContent = localStorage.getItem('writeDraftContent');
@@ -250,13 +309,13 @@ const Write = () => {
        console.error('加载暂存内容失败:', error);
       }
     };
-    fetchDialogs();
     if (localStorage.getItem(LOCAL_STORAGE_INIT_FLAG_KEY) === 'true') {
       loadDraftContent();
     }
   }, [content, selectedTemplate, templates]);
 
   useEffect(() => {
+    // Debounced save, to avoid writing to localStorage too often
     const timer = setTimeout(
       () => localStorage.setItem('writeDraftContent', content),
       1000,
@@ -302,6 +361,7 @@ const Write = () => {
     }
   };
 
+  // Delete a template
   const handleDeleteTemplate = (templateId: string) => {
     try {
       const updatedTemplates = templates.filter((t) => t.id !== templateId);
@@ -326,6 +386,51 @@ const Write = () => {
     }
   };
 
+  // Helper that extracts the context around the insertion point
+  const getContextContent = (
+    cursorPos: number,
+    currentDocumentContent: string,
+    maxLength: number = 4000,
+  ) => {
+    // Note: currentDocumentContent is the full editor content at the moment the AI is asked,
+    // not contentBeforeAiInsertionRef + contentAfterAiInsertionRef, because it may contain the marker
+    const beforeCursor = currentDocumentContent.substring(0, cursorPos);
+    const afterCursor = currentDocumentContent.substring(cursorPos);
+
+    // A more explicit insertion-point marker; this one is for the AI to see, not the user
+    const insertMarker = '[AI 内容插入点]';
+    const availableLength = maxLength - insertMarker.length;
+
+    if (currentDocumentContent.length <= availableLength) {
+      return {
+        beforeCursor,
+        afterCursor,
+        contextContent: beforeCursor + insertMarker + afterCursor,
+      };
+    }
+
+    const halfLength = Math.floor(availableLength / 2);
+    let finalBefore = beforeCursor;
+    let finalAfter = afterCursor;
+
+    // If the first half is too long, truncate it and prepend an ellipsis
+    if (beforeCursor.length > halfLength) {
+      finalBefore =
+        '...' + beforeCursor.substring(beforeCursor.length - halfLength + 3);
+    }
+
+    // If the second half is too long, truncate it and append an ellipsis
+    if (afterCursor.length > halfLength) {
+      finalAfter = afterCursor.substring(0, halfLength - 3) + '...';
+    }
+
+    return {
+      beforeCursor,
+      afterCursor,
+      contextContent: finalBefore + insertMarker + finalAfter,
+    };
+  };
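getContextContent keeps the AI prompt bounded: it marks the insertion point, then trims the text on each side to roughly half the budget, placing ellipses on the cut edges. The same windowing logic, sketched in Python for clarity (the marker string mirrors the TS constant):

    def context_window(text, pos, max_len=4000, marker="[AI 内容插入点]"):
        """Return the text around pos with the insertion marker, capped at max_len."""
        before, after = text[:pos], text[pos:]
        budget = max_len - len(marker)
        if len(text) <= budget:
            return before + marker + after
        half = budget // 2
        if len(before) > half:               # keep the tail of the prefix
            before = "..." + before[-(half - 3):]
        if len(after) > half:                # keep the head of the suffix
            after = after[:half - 3] + "..."
        return before + marker + after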
  const handleAiQuestionSubmit = async (
     e: React.KeyboardEvent,
   ) => {
     if (e.key === 'Enter' && !e.shiftKey) {
       e.preventDefault();
       if (!aiQuestion.trim()) {
         message.warning(t('enterYourQuestion'));
         return;
       }
-      if (!dialogId) {
-        message.error(t('noDialogFound'));
+
+      // Check that at least one knowledge base is selected
+      if (selectedKnowledgeBases.length === 0) {
+        message.warning('请至少选择一个知识库');
         return;
       }
-      setIsAiLoading(true);
-      const initialCursorPos = cursorPosition;
-      const originalContent = content;
-      let beforeCursor = '',
-        afterCursor = '';
-      if (initialCursorPos !== null && showCursorIndicator) {
-        beforeCursor = originalContent.substring(0, initialCursorPos);
-        afterCursor = originalContent.substring(initialCursorPos);
+
+      // If the AI is still streaming, stop it and handle the new question
+      if (isStreaming) {
+        stopOutputMessage(); // stop the current streamed output
+        setIsStreaming(false); // set to false immediately to interrupt the stream
+        setIsAiLoading(false); // make sure the loading state stops too
+
+        // On interruption, immediately remove the <think> tags from the stream and update the main content.
+        // currentStreamedAiOutput is used as the basis for the interrupted content,
+        // because it is what is actually on screen and still contains the <think> tags.
+        const contentToCleanOnInterrupt =
+          contentBeforeAiInsertionRef.current +
+          currentStreamedAiOutput +
+          contentAfterAiInsertionRef.current;
+        const cleanedContent = contentToCleanOnInterrupt.replace(
+          /<think>.*?<\/think>/gs,
+          '',
+        );
+        setContent(cleanedContent);
+
+        setCurrentStreamedAiOutput(''); // clear the old streamed content
+        contentBeforeAiInsertionRef.current = ''; // reset the refs
+        contentAfterAiInsertionRef.current = '';
+        aiInsertionStartPosRef.current = null;
+        message.info('已中断上一次AI回答,正在处理新问题...');
+        // Wait briefly so the state updates land before handling the new question (avoids a race)
+        await new Promise((resolve) => {
+          setTimeout(resolve, 100);
+        });
       }
-      const controller = new AbortController();
-      const timeoutId = setTimeout(
-        () => controller.abort(),
-        aiAssistantConfig.api.timeout || 30000,
+
+      // If the current caret position is invalid, remind the user to set one
+      if (cursorPosition === null) {
+        message.warning('请先点击文本框以设置AI内容插入位置。');
+        return;
+      }
+
+      // Capture the static content before/after the AI insertion point into the refs
+      const currentCursorPos = cursorPosition;
+      // content here should be the user's actual editor content, possibly including INSERTION_MARKER;
+      // since INSERTION_MARKER is empty, that is just the current main content
+      contentBeforeAiInsertionRef.current = content.substring(
+        0,
+        currentCursorPos,
       );
+      contentAfterAiInsertionRef.current = content.substring(currentCursorPos);
+      aiInsertionStartPosRef.current = currentCursorPos; // record the exact insertion start position
+
+      setIsAiLoading(true);
+      setIsStreaming(true); // mark that the AI has started streaming
+      setCurrentStreamedAiOutput(''); // clear the old accumulation, ready for the new stream
+
       try {
         const authorization = localStorage.getItem('Authorization');
         if (!authorization) {
           message.error(t('loginRequiredError'));
           setIsAiLoading(false);
+          setIsStreaming(false); // stop the streaming flag
+          // clean up the temporary state on failure as well
+          setCurrentStreamedAiOutput('');
+          contentBeforeAiInsertionRef.current = '';
+          contentAfterAiInsertionRef.current = '';
+          aiInsertionStartPosRef.current = null;
           return;
         }
-        const conversationId =
-          Math.random().toString(36).substring(2) + Date.now().toString(36);
-        await axios.post(
-          'v1/conversation/set',
-          {
-            dialog_id: dialogId,
-            name: '文档撰写对话',
-            is_new: true,
-            conversation_id: conversationId,
-            message: [{ role: 'assistant', content: '新对话' }],
-          },
-          { headers: { authorization }, signal: controller.signal },
-        );
-        const combinedQuestion = `${aiQuestion}\n\n${t('currentDocumentContextLabel')}:\n${originalContent}`;
-        let lastReceivedContent = '';
-        const response = await axios.post(
-          '/v1/conversation/completion',
-          {
-            conversation_id: conversationId,
-            messages: [{ role: 'user', content: combinedQuestion }],
-            knowledge_base_ids:
-              selectedKnowledgeBases.length > 0
-                ? selectedKnowledgeBases
-                : undefined,
-            similarity_threshold: similarityThreshold,
-            keyword_similarity_weight: keywordSimilarityWeight,
-            temperature: modelTemperature,
-          },
-          {
-            timeout: aiAssistantConfig.api.timeout,
-            headers: { authorization },
-            signal: controller.signal,
-          },
-        );
-        if (response.data) {
-          const lines = response.data
-            .split('\n')
-            .filter((line: string) => line.trim());
-          for (let i = 0; i < lines.length; i++) {
-            try {
-              const jsonStr = lines[i].replace('data:', '').trim();
-              const jsonData = JSON.parse(jsonStr);
-              if (jsonData.code === 0 && jsonData.data?.answer) {
-                const answerChunk = jsonData.data.answer;
-                const cleanedAnswerChunk = answerChunk
-                  .replace(/<think>[\s\S]*?<\/think>/g, '')
-                  .trim();
-                const hasUnclosedThink =
-                  cleanedAnswerChunk.includes('<think>') &&
-                  (!cleanedAnswerChunk.includes('</think>') ||
-                    cleanedAnswerChunk.indexOf('<think>') >
-                      cleanedAnswerChunk.lastIndexOf('</think>'));
-                if (cleanedAnswerChunk && !hasUnclosedThink) {
-                  const incrementalContent = cleanedAnswerChunk.substring(
-                    lastReceivedContent.length,
-                  );
-                  if (incrementalContent) {
-                    lastReceivedContent = cleanedAnswerChunk;
-                    let newFullContent,
-                      newCursorPosAfterInsertion = cursorPosition;
-                    if (initialCursorPos !== null && showCursorIndicator) {
-                      newFullContent =
-                        beforeCursor + cleanedAnswerChunk + afterCursor;
-                      newCursorPosAfterInsertion =
-                        initialCursorPos + cleanedAnswerChunk.length;
-                    } else {
-                      newFullContent = originalContent + cleanedAnswerChunk;
-                      newCursorPosAfterInsertion = newFullContent.length;
-                    }
-                    setContent(newFullContent);
-                    setCursorPosition(newCursorPosAfterInsertion);
-                    setTimeout(() => {
-                      if (textAreaRef.current) {
-                        textAreaRef.current.focus();
-                        textAreaRef.current.setSelectionRange(
-                          newCursorPosAfterInsertion!,
-                          newCursorPosAfterInsertion!,
-                        );
-                      }
-                    }, 0);
-                  }
-                }
-              }
-            } catch (parseErr) {
-              console.error('解析单行数据失败:', parseErr);
-            }
-            if (i < lines.length - 1)
-              await new Promise((resolve) => {
-                setTimeout(resolve, 10);
-              });
-          }
-        }
-        await axios.post(
-          '/v1/conversation/rm',
-          { conversation_ids: [conversationId], dialog_id: dialogId },
-          { headers: { authorization } },
-        );
+
+        // Build the request content, sending the surrounding context to the AI
+        let questionWithContext = aiQuestion;
+
+        // Include the context only when the user has set an insertion point
+        if (aiInsertionStartPosRef.current !== null) {
+          // The content passed to getContextContent should be the full editor content, marker included
+          const { contextContent } = getContextContent(
+            aiInsertionStartPosRef.current,
+            content,
+          );
+          questionWithContext = `${aiQuestion}\n\n上下文内容:\n${contextContent}`;
+        }
+
+        // Send the streaming request
+        await sendMessage({
+          question: questionWithContext,
+          kb_ids: selectedKnowledgeBases,
+          dialog_id: dialogId,
+          similarity_threshold: similarityThreshold,
+          keyword_similarity_weight: keywordSimilarityWeight,
+          temperature: modelTemperature,
+        });
+
+        setAiQuestion(''); // clear the input box
+        // Re-focus the main editor area rather than the AI question box
+        if (textAreaRef.current?.resizableTextArea?.textArea) {
+          textAreaRef.current.resizableTextArea.textArea.focus();
+        }
       } catch (error: any) {
         console.error('AI助手处理失败:', error);
         if (error.code === 'ECONNABORTED' || error.name === 'AbortError') {
@@ -469,14 +552,13 @@ const Write = () => {
           message.error(t('aiRequestFailedError'));
         }
       } finally {
-        clearTimeout(timeoutId);
-        setIsAiLoading(false);
-        setAiQuestion('');
-        if (textAreaRef.current) textAreaRef.current.focus();
+        // isAiLoading is updated when done fires or in the error handler; it is not forced to false here.
+        // The temporary state is fully cleared only once isStreaming has completely ended.
       }
     }
   };
 
+  // Export to Word
   const handleSave = () => {
     const selectedTemplateItem = templates.find(
       (item) => item.id === selectedTemplate,
@@ -674,33 +756,125 @@ const Write = () => {
     }
   };
 
-  const renderEditor = () => (
-    <Input.TextArea
-      ref={textAreaRef}
-      value={content}
-      onChange={(e) => setContent(e.target.value)}
-      onClick={(e) => {
-        const target = e.target as HTMLTextAreaElement;
-        setCursorPosition(target.selectionStart);
-        setShowCursorIndicator(true);
-      }}
-      onKeyUp={(e) => {
-        const target = e.target as HTMLTextAreaElement;
-        setCursorPosition(target.selectionStart);
-        setShowCursorIndicator(true);
-      }}
-      placeholder={t('writePlaceholder')}
-      autoSize={false}
-    />
-  );
+  // Updated editor renderer: adds the caret marker
+  const renderEditor = () => {
+    let displayContent = content; // by default, show the main content state
+
+    // While the AI is streaming, compose the displayed content on the fly
+    if (isStreaming && aiInsertionStartPosRef.current !== null) {
+      // The live display keeps the <think> tag content
+      displayContent =
+        contentBeforeAiInsertionRef.current +
+        currentStreamedAiOutput +
+        contentAfterAiInsertionRef.current;
+    } else if (showCursorIndicator && cursorPosition !== null) {
+      // Not streaming, but a caret was set: show the insertion marker
+      // (since INSERTION_MARKER is an empty string, this adds no visible marker)
+      const beforeCursor = content.substring(0, cursorPosition);
+      const afterCursor = content.substring(cursorPosition);
+      displayContent = beforeCursor + INSERTION_MARKER + afterCursor;
+    }
+
+    return (
+      <Input.TextArea
+        ref={textAreaRef}
+        value={displayContent}
+        onChange={(e) => {
+          const currentInputValue = e.target.value; // the full content of the input box
+          const newCursorSelectionStart = e.target.selectionStart;
+          let finalContent = currentInputValue;
+          let finalCursorPosition = newCursorSelectionStart;
+
+          // If the user types while the AI is streaming, interrupt the AI output
+          // and "freeze" the current content (removing the <think> tags)
+          if (isStreaming) {
+            stopOutputMessage(); // abort the SSE connection
+            setIsStreaming(false); // stop streaming
+            setIsAiLoading(false); // stop the loading state
+
+            // At this point currentInputValue already contains all streamed AI content (think tags included).
+            // Remove the <think> tags
+            const contentWithoutThinkTags = currentInputValue.replace(
+              /<think>.*?<\/think>/gs,
+              '',
+            );
+            finalContent = contentWithoutThinkTags;
+
+            // Recompute the caret position, since removing the <think> tags may change the length
            const originalLength = currentInputValue.length;
+            const cleanedLength = finalContent.length;
+
+            // If the caret sits after the AI insertion point (inside or past the removed region),
+            // walk it back by the removed length
+            if (
+              newCursorSelectionStart > (aiInsertionStartPosRef.current || 0)
+            ) {
+              // aiInsertionStartPosRef.current is taken to be the start of the AI content
+              finalCursorPosition =
+                newCursorSelectionStart - (originalLength - cleanedLength);
+              // Make sure the caret does not run past the end of the new content
+              if (finalCursorPosition > cleanedLength) {
+                finalCursorPosition = cleanedLength;
+              }
+            } else {
+              finalCursorPosition = newCursorSelectionStart; // caret is before the AI insertion point; no adjustment
+            }
+
+            // Clean up the streaming-related temporary state and refs
+            setCurrentStreamedAiOutput('');
+            contentBeforeAiInsertionRef.current = '';
+            contentAfterAiInsertionRef.current = '';
+            aiInsertionStartPosRef.current = null;
+          }
+
+          // If the content still contains INSERTION_MARKER, remove it
+          // (with INSERTION_MARKER being an empty string this block has little effect)
+          const markerIndex = finalContent.indexOf(INSERTION_MARKER); // check the already-processed finalContent
+          if (markerIndex !== -1) {
+            const contentWithoutMarker = finalContent.replace(
+              INSERTION_MARKER,
+              '',
+            );
+            finalContent = contentWithoutMarker;
+            if (newCursorSelectionStart > markerIndex) {
+              // newCursorSelectionStart here is still the original value being compared with markerIndex
+              finalCursorPosition =
+                finalCursorPosition - INSERTION_MARKER.length;
+            }
+          }
+
+          setContent(finalContent); // update the main content state
+          setCursorPosition(finalCursorPosition); // update the caret position state
+          // The caret cannot be set on the DOM here: this runs after setState, before the DOM updates.
+          // Ant Design's Input.TextArea restores the caret itself once value updates.
+          setShowCursorIndicator(true); // typing means a caret position is set; keep showing the marker
+        }}
+        onClick={(e) => {
+          const target = e.target as HTMLTextAreaElement;
+          setCursorPosition(target.selectionStart);
+          setShowCursorIndicator(true); // clicking sets the caret position and shows the marker
+          target.focus(); // make sure the textarea is focused right after the click
+        }}
+        onKeyUp={(e) => {
+          const target = e.target as HTMLTextAreaElement;
+          setCursorPosition(target.selectionStart);
+          setShowCursorIndicator(true); // key-up sets the caret position and shows the marker
+        }}
+        placeholder={t('writePlaceholder')}
+        autoSize={false}
+      />
+    );
+  };
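When the user types mid-stream, the onChange handler above first strips any `<think>...</think>` spans and then walks the caret back by the number of characters removed. The two steps in isolation (a simplified sketch that ignores the insertion-point check):

    import re

    THINK = re.compile(r"<think>.*?</think>", flags=re.S)

    def strip_think(text, caret):
        """Remove think spans; shift the caret left by the removed length, clamped to the text."""
        cleaned = THINK.sub("", text)
        removed = len(text) - len(cleaned)
        return cleaned, min(max(caret - removed, 0), len(cleaned))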
  const renderPreview = () => (
    <div
      style={{
      }}
    >
+      {/* In preview mode the <think> tags are normally not shown, so no special handling is needed here */}
      <HightLightMarkdown>
        {content || t('previewPlaceholder')}
      </HightLightMarkdown>
    </div>
  );
@@ -754,10 +929,10 @@ const Write = () => {
   return (
     <Sider
       style={{ flexShrink: 0 }}
     >
-      {isAiLoading && (
-        <div>
-          {t('aiLoadingMessage')}...
-        </div>
-      )}
       <Input
         value={aiQuestion}
         onChange={(e) => setAiQuestion(e.target.value)}
         onKeyDown={handleAiQuestionSubmit}
         disabled={isAiLoading}
       />
+
+      {/* Insertion-point hint, or a notice while the AI is answering — now always visible */}
+      {isStreaming ? ( // while the AI is answering, this notice takes priority
+        <div>
+          ✨ AI正在生成回答,请稍候...
+        </div>
+      ) : // the AI is not answering
+      cursorPosition !== null ? ( // a caret position has been set
+        <div>
+          💡 AI回答将插入到文档光标位置 (第 {cursorPosition} 个字符)。
+        </div>
+      ) : (
+        // no caret position has been set yet
+        <div>
+          👆 请在上方文档中点击,设置AI内容插入位置。
+        </div>
+      )}
diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts
index 61f4343..a845b8d 100644
--- a/web/src/utils/api.ts
+++ b/web/src/utils/api.ts
@@ -100,6 +100,8 @@ export default {
   getExternalConversation: `${api_host}/api/conversation`,
   completeExternalConversation: `${api_host}/api/completion`,
   uploadAndParseExternal: `${api_host}/api/document/upload_and_parse`,
+  // Q&A API for the document-writing mode
+  writeChat: `${api_host}/conversation/writechat`,
 
   // file manager
   listFile: `${api_host}/file/list`,