feat(write): 实现 AI 流式回答功能并优化编辑器交互

- 新增流式消息发送钩子和相关状态管理
- 实时更新编辑器内容,支持 <think> 标签显示
- 优化光标位置管理和内容插入逻辑
- 增加 AI 回答中断处理和用户输入时的流式输出中断
- 调整预览模式下的内容显示
- 优化 AI 回答状态的 UI 提示
This commit is contained in:
zstar 2025-06-04 18:56:30 +08:00
parent cfab4bc7bf
commit a117d68df0
5 changed files with 527 additions and 64 deletions

View File

@ -30,13 +30,14 @@ from api.db.services.dialog_service import DialogService, chat, ask
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService
from api import settings
from api.db.services.write_service import write_dialog
from api.utils.api_utils import get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.tag import label_question
@manager.route("/set", methods=["POST"]) # noqa: F821
@manager.route("/set", methods=["POST"]) # type: ignore # noqa: F821
@login_required
def set_conversation():
req = request.json
@ -67,7 +68,7 @@ def set_conversation():
return server_error_response(e)
@manager.route("/get", methods=["GET"]) # noqa: F821
@manager.route("/get", methods=["GET"]) # type: ignore # type: ignore # noqa: F821
@login_required
def get():
conv_id = request.args["conversation_id"]
@ -132,7 +133,7 @@ def getsse(dialog_id):
return server_error_response(e)
@manager.route("/rm", methods=["POST"]) # noqa: F821
@manager.route("/rm", methods=["POST"]) # type: ignore # type: ignore # noqa: F821
@login_required
def rm():
conv_ids = request.json["conversation_ids"]
@ -153,7 +154,7 @@ def rm():
return server_error_response(e)
@manager.route("/list", methods=["GET"]) # noqa: F821
@manager.route("/list", methods=["GET"]) # type: ignore # noqa: F821
@login_required
def list_convsersation():
dialog_id = request.args["dialog_id"]
@ -168,7 +169,7 @@ def list_convsersation():
return server_error_response(e)
@manager.route("/completion", methods=["POST"]) # noqa: F821
@manager.route("/completion", methods=["POST"]) # type: ignore # noqa: F821
@login_required
@validate_request("conversation_id", "messages")
def completion():
@ -251,7 +252,7 @@ def completion():
# 用于文档撰写模式的问答调用
@manager.route("/writechat", methods=["POST"]) # noqa: F821
@manager.route("/writechat", methods=["POST"]) # type: ignore # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def writechat():
@ -261,7 +262,7 @@ def writechat():
def stream():
nonlocal req, uid
try:
for ans in ask(req["question"], req["kb_ids"], uid):
for ans in write_dialog(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
@ -275,7 +276,7 @@ def writechat():
return resp
@manager.route("/tts", methods=["POST"]) # noqa: F821
@manager.route("/tts", methods=["POST"]) # type: ignore # noqa: F821
@login_required
def tts():
req = request.json
@ -307,7 +308,7 @@ def tts():
return resp
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
@manager.route("/delete_msg", methods=["POST"]) # type: ignore # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def delete_msg():
@ -330,7 +331,7 @@ def delete_msg():
return get_json_result(data=conv)
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
@manager.route("/thumbup", methods=["POST"]) # type: ignore # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def thumbup():
@ -357,7 +358,7 @@ def thumbup():
return get_json_result(data=conv)
@manager.route("/ask", methods=["POST"]) # noqa: F821
@manager.route("/ask", methods=["POST"]) # type: ignore # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def ask_about():
@ -381,7 +382,7 @@ def ask_about():
return resp
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
@manager.route("/mindmap", methods=["POST"]) # type: ignore # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def mindmap():
@ -403,7 +404,7 @@ def mindmap():
return get_json_result(data=mind_map)
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
@manager.route("/related_questions", methods=["POST"]) # type: ignore # noqa: F821
@login_required
@validate_request("question")
def related_questions():

View File

@ -0,0 +1,91 @@
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from rag.app.tag import label_question
from rag.prompts import kb_prompt
def write_dialog(question, kb_ids, tenant_id):
"""
处理用户搜索请求从知识库中检索相关信息并生成回答
参数:
question (str): 用户的问题或查询
kb_ids (list): 知识库ID列表指定要搜索的知识库
tenant_id (str): 租户ID用于权限控制和资源隔离
流程:
1. 获取指定知识库的信息
2. 确定使用的嵌入模型
3. 根据知识库类型选择检索器(普通检索器或知识图谱检索器)
4. 初始化嵌入模型和聊天模型
5. 执行检索操作获取相关文档片段
6. 格式化知识库内容作为上下文
7. 构建系统提示词
8. 生成回答并添加引用标记
9. 流式返回生成的回答
返回:
generator: 生成器对象产生包含回答和引用信息的字典
"""
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
# 初始化嵌入模型,用于将文本转换为向量表示
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
# 初始化聊天模型,用于生成回答
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
# 获取聊天模型的最大token长度用于控制上下文长度
max_tokens = chat_mdl.max_length
# 获取所有知识库的租户ID并去重
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
# 调用检索器检索相关文档片段
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
# 将检索结果格式化为提示词并确保不超过模型最大token限制
knowledges = kb_prompt(kbinfos, max_tokens)
prompt = """
角色你是一个聪明的助手
任务总结知识库中的信息并回答用户的问题
要求与限制
- 绝不要捏造内容尤其是数字
- 如果知识库中的信息与用户问题无关**只需回答对不起未提供相关信息
- 使用Markdown格式进行回答
- 使用用户提问所用的语言作答
- 绝不要捏造内容尤其是数字
### 来自知识库的信息
%s
以上是来自知识库的信息
""" % "\n".join(knowledges)
msg = [{"role": "user", "content": question}]
# 生成完成后添加回答中的引用标记
# def decorate_answer(answer):
# nonlocal knowledges, kbinfos, prompt
# answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
# idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
# recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
# if not recall_docs:
# recall_docs = kbinfos["doc_aggs"]
# kbinfos["doc_aggs"] = recall_docs
# refs = deepcopy(kbinfos)
# for c in refs["chunks"]:
# if c.get("vector"):
# del c["vector"]
# if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
# answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
# refs["chunks"] = chunks_format(refs)
# return {"answer": answer, "reference": refs}
answer = ""
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
answer = ans
yield {"answer": answer, "reference": {}}
# yield decorate_answer(answer)

View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg fill="#000000" width="800px" height="800px" viewBox="0 0 36 36" version="1.1" preserveAspectRatio="xMidYMid meet" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>storage-solid</title>
<path class="clr-i-solid clr-i-solid-path-1" d="M17.91,18.28c8.08,0,14.66-1.74,15.09-3.94V8.59c-.43,2.2-7,3.94-15.09,3.94A39.4,39.4,0,0,1,6.25,11V9a39.4,39.4,0,0,0,11.66,1.51C26,10.53,32.52,8.79,33,6.61h0C32.8,3.2,23.52,2.28,18,2.28S3,3.21,3,6.71V29.29c0,3.49,9.43,4.43,15,4.43s15-.93,15-4.43V24.09C32.57,26.28,26,28,17.91,28A39.4,39.4,0,0,1,6.25,26.52v-2A39.4,39.4,0,0,0,17.91,26C26,26,32.57,24.28,33,22.09V16.34c-.43,2.2-7,3.94-15.09,3.94A39.4,39.4,0,0,1,6.25,18.77v-2A39.4,39.4,0,0,0,17.91,18.28Z"></path>
<rect x="0" y="0" width="36" height="36" fill-opacity="0"/>
</svg>

After

Width:  |  Height:  |  Size: 934 B

View File

@ -591,6 +591,7 @@ General实体和关系提取提示来自 GitHub - microsoft/graphrag基于
decisions: '决定事项',
actionItems: '行动项',
nextMeeting: '下次会议',
noTemplatesAvailable: "没有可用模板",
// 模型配置相关
modelConfigurationTitle: "模型配置",
knowledgeBaseLabel: "知识库",
@ -601,6 +602,7 @@ General实体和关系提取提示来自 GitHub - microsoft/graphrag基于
fetchKnowledgeBaseFailed: "获取知识库列表失败",
defaultKnowledgeBase: "默认知识库",
technicalDocsKnowledgeBase: "技术文档知识库",
aiRequestFailedError: "问答模型请求失败",
},
setting: {
profile: '概要',

View File

@ -1,6 +1,10 @@
import HightLightMarkdown from '@/components/highlight-markdown';
import { useTranslate } from '@/hooks/common-hooks';
import { useFetchKnowledgeList } from '@/hooks/write-hooks';
import {
useFetchKnowledgeList,
useSendMessageWithSse,
} from '@/hooks/write-hooks';
import { DeleteOutlined } from '@ant-design/icons';
import {
Button,
@ -33,7 +37,6 @@ import { useCallback, useEffect, useRef, useState } from 'react';
const { Sider, Content } = Layout;
const { Option } = Select;
const aiAssistantConfig = { api: { timeout: 30000 } };
const LOCAL_STORAGE_TEMPLATES_KEY = 'userWriteTemplates_v4_no_restore_final';
const LOCAL_STORAGE_INIT_FLAG_KEY =
@ -56,15 +59,22 @@ type MarkedListItem = Tokens.ListItem;
type MarkedListToken = Tokens.List;
type MarkedSpaceToken = Tokens.Space;
// 定义插入点标记以便在onChange时识别并移除
// const INSERTION_MARKER = '【AI内容将插入此处】';
const INSERTION_MARKER = ''; // 保持为空字符串,不显示实际标记
const Write = () => {
const { t } = useTranslate('write');
const [content, setContent] = useState('');
const [aiQuestion, setAiQuestion] = useState('');
const [isAiLoading, setIsAiLoading] = useState(false);
const [dialogId] = useState('');
// cursorPosition 存储用户点击设定的插入点位置
const [cursorPosition, setCursorPosition] = useState<number | null>(null);
// showCursorIndicator 现在仅用于控制文档中是否显示 'INSERTION_MARKER'
// 并且一旦设置了光标位置,就希望它保持为 true除非内容被清空或主动重置。
const [showCursorIndicator, setShowCursorIndicator] = useState(false);
const textAreaRef = useRef<HTMLTextAreaElement>(null);
const textAreaRef = useRef<any>(null); // Ant Design Input.TextArea 的 ref 类型
const [templates, setTemplates] = useState<TemplateItem[]>([]);
const [isTemplateModalVisible, setIsTemplateModalVisible] = useState(false);
@ -83,11 +93,28 @@ const Write = () => {
const [modelTemperature, setModelTemperature] = useState<number>(0.7);
const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBaseItem[]>([]);
const [isLoadingKbs, setIsLoadingKbs] = useState(false);
const [isStreaming, setIsStreaming] = useState(false); // 标记AI是否正在流式输出
// 新增状态和 useRef用于流式输出管理
// currentStreamedAiOutput 现在将直接接收 useSendMessageWithSse 返回的累积内容
const [currentStreamedAiOutput, setCurrentStreamedAiOutput] = useState('');
// 使用 useRef 存储 AI 插入点前后的内容,以及插入点位置,避免在流式更新中出现闭包陷阱
const contentBeforeAiInsertionRef = useRef('');
const contentAfterAiInsertionRef = useRef('');
const aiInsertionStartPosRef = useRef<number | null>(null);
// 使用 useFetchKnowledgeList hook 获取真实数据
const { list: knowledgeList, loading: isLoadingKnowledgeList } =
useFetchKnowledgeList(true);
// 使用流式消息发送钩子
const {
send: sendMessage,
answer,
done,
stopOutputMessage,
} = useSendMessageWithSse();
const getInitialDefaultTemplateDefinitions = useCallback(
(): TemplateItem[] => [
{
@ -184,6 +211,88 @@ const Write = () => {
}
}, [knowledgeList, isLoadingKnowledgeList]);
// --- 调整流式响应处理逻辑 ---
// 阶段1: 累积 AI 输出片段,用于实时显示(包括 <think> 标签)
// 这个 useEffect 确保 currentStreamedAiOutput 始终是实时更新的、包含 <think> 标签的完整内容
useEffect(() => {
if (isStreaming && answer && answer.answer) {
setCurrentStreamedAiOutput(answer.answer);
}
}, [isStreaming, answer]);
// 阶段2: 当流式输出完成时 (done 为 true)
// 这个 useEffect 负责在流式输出结束时执行清理和最终内容更新
useEffect(() => {
if (done) {
setIsStreaming(false);
setIsAiLoading(false);
// --- Process the final streamed AI output before committing ---
// 关键修改:这里**必须**使用 currentStreamedAiOutput因为它是在流式过程中积累的、包含 <think> 标签的内容
// answer.answer 可能在 done 阶段已经提前被钩子内部清理过,所以不能依赖它来获取带标签的原始内容。
let processedAiOutput = currentStreamedAiOutput;
if (processedAiOutput) {
// Regex to remove <think>...</think> including content
processedAiOutput = processedAiOutput.replace(
/<think>.*?<\/think>/gs,
'',
);
}
// --- END NEW ---
// 将最终累积的AI内容已处理移除<think>标签)和初始文档内容拼接,更新到主内容状态
setContent((prevContent) => {
if (aiInsertionStartPosRef.current !== null) {
// 使用 useRef 中存储的初始内容和最终处理过的 AI 输出
const finalContent =
contentBeforeAiInsertionRef.current +
processedAiOutput +
contentAfterAiInsertionRef.current;
return finalContent;
}
return prevContent;
});
// AI完成回答后将光标实际移到新内容末尾
if (
textAreaRef.current?.resizableTextArea?.textArea &&
aiInsertionStartPosRef.current !== null
) {
const newCursorPos =
aiInsertionStartPosRef.current + processedAiOutput.length;
textAreaRef.current.resizableTextArea.textArea.selectionStart =
newCursorPos;
textAreaRef.current.resizableTextArea.textArea.selectionEnd =
newCursorPos;
textAreaRef.current.resizableTextArea.textArea.focus();
setCursorPosition(newCursorPos);
}
// 清理流式相关的临时状态和 useRef
setCurrentStreamedAiOutput(''); // 清空累积内容
contentBeforeAiInsertionRef.current = '';
contentAfterAiInsertionRef.current = '';
aiInsertionStartPosRef.current = null;
setShowCursorIndicator(true);
}
}, [done, currentStreamedAiOutput]); // 依赖 done 和 currentStreamedAiOutput确保在 done 时拿到最新的 currentStreamedAiOutput
// 监听 currentStreamedAiOutput 的变化,实时更新主 content 状态以实现流式显示
useEffect(() => {
if (isStreaming && aiInsertionStartPosRef.current !== null) {
// 实时更新编辑器内容,保留 <think> 标签内容
setContent(
contentBeforeAiInsertionRef.current +
currentStreamedAiOutput +
contentAfterAiInsertionRef.current,
);
// 同时更新 cursorPosition让光标跟随 AI 输出移动(基于包含 think 标签的原始长度)
setCursorPosition(
aiInsertionStartPosRef.current + currentStreamedAiOutput.length,
);
}
}, [currentStreamedAiOutput, isStreaming, aiInsertionStartPosRef]);
useEffect(() => {
const loadDraftContent = () => {
try {
@ -206,6 +315,7 @@ const Write = () => {
}, [content, selectedTemplate, templates]);
useEffect(() => {
// 防抖保存,防止频繁写入 localStorage
const timer = setTimeout(
() => localStorage.setItem('writeDraftContent', content),
1000,
@ -276,6 +386,51 @@ const Write = () => {
}
};
// 获取上下文内容的辅助函数
const getContextContent = (
cursorPos: number,
currentDocumentContent: string,
maxLength: number = 4000,
) => {
// 注意: 这里的 currentDocumentContent 传入的是 AI 提问时编辑器里的总内容,
// 而不是 contentBeforeAiInsertionRef + contentAfterAiInsertionRef因为可能包含标记
const beforeCursor = currentDocumentContent.substring(0, cursorPos);
const afterCursor = currentDocumentContent.substring(cursorPos);
// 使用更明显的插入点标记这个标记是给AI看的不是给用户看的
const insertMarker = '[AI 内容插入点]';
const availableLength = maxLength - insertMarker.length;
if (currentDocumentContent.length <= availableLength) {
return {
beforeCursor,
afterCursor,
contextContent: beforeCursor + insertMarker + afterCursor,
};
}
const halfLength = Math.floor(availableLength / 2);
let finalBefore = beforeCursor;
let finalAfter = afterCursor;
// 如果前半部分太长,截断并在前面加省略号
if (beforeCursor.length > halfLength) {
finalBefore =
'...' + beforeCursor.substring(beforeCursor.length - halfLength + 3);
}
// 如果后半部分太长,截断并在后面加省略号
if (afterCursor.length > halfLength) {
finalAfter = afterCursor.substring(0, halfLength - 3) + '...';
}
return {
beforeCursor,
afterCursor,
contextContent: finalBefore + insertMarker + finalAfter,
};
};
const handleAiQuestionSubmit = async (
e: React.KeyboardEvent<HTMLTextAreaElement>,
) => {
@ -286,27 +441,105 @@ const Write = () => {
return;
}
setIsAiLoading(true);
const initialCursorPos = cursorPosition;
const originalContent = content;
let beforeCursor = '',
afterCursor = '';
if (initialCursorPos !== null && showCursorIndicator) {
beforeCursor = originalContent.substring(0, initialCursorPos);
afterCursor = originalContent.substring(initialCursorPos);
// 检查是否选择了知识库
if (selectedKnowledgeBases.length === 0) {
message.warning('请至少选择一个知识库');
return;
}
const controller = new AbortController();
const timeoutId = setTimeout(
() => controller.abort(),
aiAssistantConfig.api.timeout || 30000,
// 如果AI正在流式输出停止它并处理新问题
if (isStreaming) {
stopOutputMessage(); // 停止当前的流式输出
setIsStreaming(false); // 立即设置为false中断流
setIsAiLoading(false); // 确保加载状态也停止
// 中断时立即清除流中的 <think> 标签,并更新主内容
// 这里使用 currentStreamedAiOutput 作为基准来构建中断时的内容,
// 因为它是屏幕上实际显示的,包含了 <think> 标签。
const contentToCleanOnInterrupt =
contentBeforeAiInsertionRef.current +
currentStreamedAiOutput +
contentAfterAiInsertionRef.current;
const cleanedContent = contentToCleanOnInterrupt.replace(
/<think>.*?<\/think>/gs,
'',
);
setContent(cleanedContent);
setCurrentStreamedAiOutput(''); // 清除旧的流式内容
contentBeforeAiInsertionRef.current = ''; // 清理 useRef
contentAfterAiInsertionRef.current = '';
aiInsertionStartPosRef.current = null;
message.info('已中断上一次AI回答正在处理新问题...');
// 稍作延迟,确保状态更新后再处理新问题,防止竞态条件
await new Promise((resolve) => {
setTimeout(resolve, 100);
});
}
// 如果当前光标位置无效,提醒用户设置
if (cursorPosition === null) {
message.warning('请先点击文本框以设置AI内容插入位置。');
return;
}
// 捕获 AI 插入点前后的静态内容,存储到 useRef
const currentCursorPos = cursorPosition;
// 此时的 content 应该是用户当前编辑器的实际内容包括可能存在的INSERTION_MARKER
// 但由于 INSERTION_MARKER 为空,所以就是当前的主 content
contentBeforeAiInsertionRef.current = content.substring(
0,
currentCursorPos,
);
contentAfterAiInsertionRef.current = content.substring(currentCursorPos);
aiInsertionStartPosRef.current = currentCursorPos; // 记录确切的开始插入位置
setIsAiLoading(true);
setIsStreaming(true); // 标记AI开始流式输出
setCurrentStreamedAiOutput(''); // 清空历史累积内容,为新的流做准备
try {
const authorization = localStorage.getItem('Authorization');
if (!authorization) {
message.error(t('loginRequiredError'));
setIsAiLoading(false);
setIsStreaming(false); // 停止流式标记
// 失败时也清理临时状态
setCurrentStreamedAiOutput('');
contentBeforeAiInsertionRef.current = '';
contentAfterAiInsertionRef.current = '';
aiInsertionStartPosRef.current = null;
return;
}
// 构建请求内容将上下文内容发送给AI
let questionWithContext = aiQuestion;
// 只有当用户设置了插入位置时才包含上下文
if (aiInsertionStartPosRef.current !== null) {
// 传递给 getContextContent 的 content 应该是当前编辑器完整的包含marker的
const { contextContent } = getContextContent(
aiInsertionStartPosRef.current,
content,
);
questionWithContext = `${aiQuestion}\n\n上下文内容\n${contextContent}`;
}
// 发送流式请求
await sendMessage({
question: questionWithContext,
kb_ids: selectedKnowledgeBases,
dialog_id: dialogId,
similarity_threshold: similarityThreshold,
keyword_similarity_weight: keywordSimilarityWeight,
temperature: modelTemperature,
});
setAiQuestion(''); // 清空输入框
// 重新聚焦文本框但不是AI问答框而是主编辑区
if (textAreaRef.current?.resizableTextArea?.textArea) {
textAreaRef.current.resizableTextArea.textArea.focus();
}
} catch (error: any) {
console.error('AI助手处理失败:', error);
if (error.code === 'ECONNABORTED' || error.name === 'AbortError') {
@ -319,10 +552,8 @@ const Write = () => {
message.error(t('aiRequestFailedError'));
}
} finally {
clearTimeout(timeoutId);
setIsAiLoading(false);
setAiQuestion('');
if (textAreaRef.current) textAreaRef.current.focus();
// AI加载状态在 done 状态或错误处理中会更新,这里不主动设置为 false
// 只有当 isStreaming 状态完全结束时,才彻底清除临时状态
}
}
};
@ -525,33 +756,125 @@ const Write = () => {
}
};
const renderEditor = () => (
<Input.TextArea
ref={textAreaRef}
style={{
height: '100%',
width: '100%',
border: 'none',
padding: 24,
fontSize: 16,
resize: 'none',
}}
value={content}
onChange={(e) => setContent(e.target.value)}
onClick={(e) => {
const target = e.target as HTMLTextAreaElement;
setCursorPosition(target.selectionStart);
setShowCursorIndicator(true);
}}
onKeyUp={(e) => {
const target = e.target as HTMLTextAreaElement;
setCursorPosition(target.selectionStart);
setShowCursorIndicator(true);
}}
placeholder={t('writePlaceholder')}
autoSize={false}
/>
);
// 修改编辑器渲染函数,添加光标标记
const renderEditor = () => {
let displayContent = content; // 默认显示主内容状态
// 如果 AI 正在流式输出,则动态拼接显示内容
if (isStreaming && aiInsertionStartPosRef.current !== null) {
// 实时显示时,保留 <think> 标签内容
displayContent =
contentBeforeAiInsertionRef.current +
currentStreamedAiOutput +
contentAfterAiInsertionRef.current;
} else if (showCursorIndicator && cursorPosition !== null) {
// 如果不处于流式输出中,但设置了光标,则显示插入标记
// (由于 INSERTION_MARKER 为空字符串,这一步实际上不会添加可见标记)
const beforeCursor = content.substring(0, cursorPosition);
const afterCursor = content.substring(cursorPosition);
displayContent = beforeCursor + INSERTION_MARKER + afterCursor;
}
return (
<div style={{ position: 'relative', height: '100%', width: '100%' }}>
<Input.TextArea
ref={textAreaRef}
style={{
height: '100%',
width: '100%',
border: 'none',
padding: 24,
fontSize: 16,
resize: 'none',
}}
value={displayContent} // 使用动态的 displayContent
onChange={(e) => {
const currentInputValue = e.target.value; // 获取当前输入框中的完整内容
const newCursorSelectionStart = e.target.selectionStart;
let finalContent = currentInputValue;
let finalCursorPosition = newCursorSelectionStart;
// 如果用户在 AI 流式输出时输入,则中断 AI 输出,并“固化”当前内容(清除 <think> 标签)
if (isStreaming) {
stopOutputMessage(); // 中断 SSE 连接
setIsStreaming(false); // 停止流式输出
setIsAiLoading(false); // 停止加载状态
// 此时 currentInputValue 已经包含了所有已流出的 AI 内容 (包括 <think> 标签)
// 移除 <think> 标签
const contentWithoutThinkTags = currentInputValue.replace(
/<think>.*?<\/think>/gs,
'',
);
finalContent = contentWithoutThinkTags;
// 重新计算光标位置,因为内容长度可能因移除 <think> 标签而改变
const originalLength = currentInputValue.length;
const cleanedLength = finalContent.length;
// 假设光标是在 AI 插入点之后,或者在用户输入后新位置,需要调整
// 如果光标在被移除的 <think> 区域内部,或者在移除区域之后,需要回退相应长度
if (
newCursorSelectionStart > (aiInsertionStartPosRef.current || 0)
) {
// 假设 aiInsertionStartPosRef.current 是 AI 内容的起始点
finalCursorPosition =
newCursorSelectionStart - (originalLength - cleanedLength);
// 确保光标不会超出新内容的末尾
if (finalCursorPosition > cleanedLength) {
finalCursorPosition = cleanedLength;
}
} else {
finalCursorPosition = newCursorSelectionStart; // 光标在 AI 插入点之前,无需调整
}
// 清理流式相关的临时状态和 useRef
setCurrentStreamedAiOutput('');
contentBeforeAiInsertionRef.current = '';
contentAfterAiInsertionRef.current = '';
aiInsertionStartPosRef.current = null;
}
// 检查内容中是否包含 INSERTION_MARKER如果包含则移除
// 由于 INSERTION_MARKER 为空字符串,此逻辑块影响很小
const markerIndex = finalContent.indexOf(INSERTION_MARKER); // 对已处理的 finalContent 进行检查
if (markerIndex !== -1) {
const contentWithoutMarker = finalContent.replace(
INSERTION_MARKER,
'',
);
finalContent = contentWithoutMarker;
if (newCursorSelectionStart > markerIndex) {
// 此处的 newCursorSelectionStart 仍然是原始的,需要与 markerIndex 比较
finalCursorPosition =
finalCursorPosition - INSERTION_MARKER.length;
}
}
setContent(finalContent); // 更新主内容状态
setCursorPosition(finalCursorPosition); // 更新光标位置状态
// 手动设置光标位置
// 这里不能直接操作 DOM因为是在 setState 之后DOM 尚未更新
// Ant Design Input.TextArea 会在 value 更新后自动处理光标位置
setShowCursorIndicator(true); // 用户输入时,表明已设置光标位置,持续显示标记
}}
onClick={(e) => {
const target = e.target as HTMLTextAreaElement;
setCursorPosition(target.selectionStart);
setShowCursorIndicator(true); // 点击时设置光标位置并显示标记
target.focus(); // 确保点击后立即聚焦
}}
onKeyUp={(e) => {
const target = e.target as HTMLTextAreaElement;
setCursorPosition(target.selectionStart);
setShowCursorIndicator(true); // 键盘抬起时设置光标位置并显示标记
}}
placeholder={t('writePlaceholder')}
autoSize={false}
/>
</div>
);
};
const renderPreview = () => (
<div
style={{
@ -563,6 +886,7 @@ const Write = () => {
}}
>
<HightLightMarkdown>
{/* 预览模式下,通常不显示 <think> 标签,所以这里不需要特殊处理 */}
{content || t('previewPlaceholder')}
</HightLightMarkdown>
</div>
@ -874,11 +1198,6 @@ const Write = () => {
}}
style={{ flexShrink: 0 }}
>
{isAiLoading && (
<div style={{ textAlign: 'center', marginBottom: 8 }}>
{t('aiLoadingMessage')}...
</div>
)}
<Input.TextArea
placeholder={t('askAI')}
autoSize={{ minRows: 2, maxRows: 5 }}
@ -887,6 +1206,50 @@ const Write = () => {
onKeyDown={handleAiQuestionSubmit}
disabled={isAiLoading}
/>
{/* 插入位置提示 或 AI正在回答时的提示 - 现已常驻显示 */}
{isStreaming ? ( // AI正在回答时优先显示此提示
<div
style={{
fontSize: '12px',
color: '#faad14', // 警告色
padding: '6px 10px',
backgroundColor: '#fffbe6',
borderRadius: '4px',
border: '1px solid #ffe58f',
}}
>
AI正在生成回答...
</div>
) : // AI未回答时
cursorPosition !== null ? ( // 如果光标已设置
<div
style={{
fontSize: '12px',
color: '#666',
padding: '6px 10px',
backgroundColor: '#e6f7ff',
borderRadius: '4px',
border: '1px solid #91d5ff',
}}
>
💡 AI回答将插入到文档光标位置 ( {cursorPosition} )
</div>
) : (
// 如果光标未设置
<div
style={{
fontSize: '12px',
color: '#f5222d', // 错误色,提醒用户
padding: '6px 10px',
backgroundColor: '#fff1f0',
borderRadius: '4px',
border: '1px solid #ffccc7',
}}
>
👆 AI内容插入位置
</div>
)}
</Card>
</Flex>
</Content>