Merge pull request #100 from zstar1003/dev

Commit fd7f1140cd by zstar, 2025-05-17 15:30:12 +08:00, committed via GitHub.
11 changed files with 349 additions and 464 deletions

View File

@@ -37,7 +37,8 @@ from api.db.services.file_service import FileService
 from flask import jsonify, request, Response

-@manager.route('/chats/<chat_id>/sessions', methods=['POST'])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions", methods=["POST"])  # noqa: F821
 @token_required
 def create(tenant_id, chat_id):
     req = request.json
@@ -50,7 +51,7 @@ def create(tenant_id, chat_id):
         "dialog_id": req["dialog_id"],
         "name": req.get("name", "New session"),
         "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
-        "user_id": req.get("user_id", "")
+        "user_id": req.get("user_id", ""),
     }
     if not conv.get("name"):
         return get_error_data_result(message="`name` can not be empty.")
@@ -59,20 +60,20 @@ def create(tenant_id, chat_id):
     if not e:
         return get_error_data_result(message="Fail to create a session!")
     conv = conv.to_dict()
-    conv['messages'] = conv.pop("message")
+    conv["messages"] = conv.pop("message")
     conv["chat_id"] = conv.pop("dialog_id")
     del conv["reference"]
     return get_result(data=conv)

-@manager.route('/agents/<agent_id>/sessions', methods=['POST'])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["POST"])  # noqa: F821
 @token_required
 def create_agent_session(tenant_id, agent_id):
     req = request.json
     if not request.is_json:
         req = request.form
     files = request.files
-    user_id = request.args.get('user_id', '')
+    user_id = request.args.get("user_id", "")
     e, cvs = UserCanvasService.get_by_id(agent_id)
     if not e:
@@ -113,7 +114,7 @@ def create_agent_session(tenant_id, agent_id):
                             ele.pop("value")
                 else:
                     if req is not None and req.get(ele["key"]):
-                        ele["value"] = req[ele['key']]
+                        ele["value"] = req[ele["key"]]
                     else:
                         if "value" in ele:
                             ele.pop("value")
@@ -121,20 +122,13 @@ def create_agent_session(tenant_id, agent_id):
     for ans in canvas.run(stream=False):
         pass
     cvs.dsl = json.loads(str(canvas))
-    conv = {
-        "id": get_uuid(),
-        "dialog_id": cvs.id,
-        "user_id": user_id,
-        "message": [{"role": "assistant", "content": canvas.get_prologue()}],
-        "source": "agent",
-        "dsl": cvs.dsl
-    }
+    conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
     API4ConversationService.save(**conv)
     conv["agent_id"] = conv.pop("dialog_id")
     return get_result(data=conv)

-@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"])  # noqa: F821
 @token_required
 def update(tenant_id, chat_id, session_id):
     req = request.json
@@ -156,14 +150,14 @@ def update(tenant_id, chat_id, session_id):
     return get_result()

-@manager.route('/chats/<chat_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/chats/<chat_id>/completions", methods=["POST"])  # noqa: F821
 @token_required
 def chat_completion(tenant_id, chat_id):
     req = request.json
     if not req:
         req = {"question": ""}
     if not req.get("session_id"):
-        req["question"]=""
+        req["question"] = ""
     if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
         return get_error_data_result(f"You don't own the chat {chat_id}")
     if req.get("session_id"):
@@ -185,7 +179,7 @@ def chat_completion(tenant_id, chat_id):
             return get_result(data=answer)

-@manager.route('chats_openai/<chat_id>/chat/completions', methods=['POST'])  # noqa: F821
+@manager.route("chats_openai/<chat_id>/chat/completions", methods=["POST"])  # noqa: F821
 @validate_request("model", "messages")  # noqa: F821
 @token_required
 def chat_completion_openai_like(tenant_id, chat_id):
@@ -259,39 +253,23 @@ def chat_completion_openai_like(tenant_id, chat_id):
         # The choices field on the last chunk will always be an empty array [].
         def streamed_response_generator(chat_id, dia, msg):
             token_used = 0
-            should_split_index = 0
+            answer_cache = ""
             response = {
                 "id": f"chatcmpl-{chat_id}",
-                "choices": [
-                    {
-                        "delta": {
-                            "content": "",
-                            "role": "assistant",
-                            "function_call": None,
-                            "tool_calls": None
-                        },
-                        "finish_reason": None,
-                        "index": 0,
-                        "logprobs": None
-                    }
-                ],
+                "choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None}, "finish_reason": None, "index": 0, "logprobs": None}],
                 "created": int(time.time()),
                 "model": "model",
                 "object": "chat.completion.chunk",
                 "system_fingerprint": "",
-                "usage": None
+                "usage": None,
             }
             try:
                 for ans in chat(dia, msg, True):
                     answer = ans["answer"]
-                    incremental = answer[should_split_index:]
+                    incremental = answer.replace(answer_cache, "", 1)
+                    answer_cache = answer.rstrip("</think>")
                     token_used += len(incremental)
-                    if incremental.endswith("</think>"):
-                        response_data_len = len(incremental.rstrip("</think>"))
-                    else:
-                        response_data_len = len(incremental)
-                    should_split_index += response_data_len
                     response["choices"][0]["delta"]["content"] = incremental
                     yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
             except Exception as e:
@@ -301,15 +279,10 @@ def chat_completion_openai_like(tenant_id, chat_id):
             # The last chunk
             response["choices"][0]["delta"]["content"] = None
             response["choices"][0]["finish_reason"] = "stop"
-            response["usage"] = {
-                "prompt_tokens": len(prompt),
-                "completion_tokens": token_used,
-                "total_tokens": len(prompt) + token_used
-            }
+            response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
             yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
             yield "data:[DONE]\n\n"

         resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream")
         resp.headers.add_header("Cache-control", "no-cache")
         resp.headers.add_header("Connection", "keep-alive")
@@ -324,7 +297,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
                 break
             content = answer["answer"]

         response = {
             "id": f"chatcmpl-{chat_id}",
             "object": "chat.completion",
             "created": int(time.time()),
@@ -336,25 +309,15 @@ def chat_completion_openai_like(tenant_id, chat_id):
                 "completion_tokens_details": {
                     "reasoning_tokens": context_token_used,
                     "accepted_prediction_tokens": len(content),
-                    "rejected_prediction_tokens": 0  # 0 for simplicity
-                }
+                    "rejected_prediction_tokens": 0,  # 0 for simplicity
+                },
             },
-            "choices": [
-                {
-                    "message": {
-                        "role": "assistant",
-                        "content": content
-                    },
-                    "logprobs": None,
-                    "finish_reason": "stop",
-                    "index": 0
-                }
-            ]
+            "choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
         }
         return jsonify(response)
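The streaming hunk above swaps the index-based split for an answer_cache: every chunk from chat() carries the full answer so far, and only the unseen suffix is emitted as the SSE delta. A minimal, self-contained sketch of that cumulative-answer-to-delta idea; the helper name and the sample list are illustrative, not part of the commit:

def iter_deltas(cumulative_answers):
    """Yield only the newly generated suffix of each cumulative answer."""
    answer_cache = ""
    for answer in cumulative_answers:
        # Strip the previously seen prefix once; whatever remains is the new delta.
        incremental = answer.replace(answer_cache, "", 1)
        answer_cache = answer
        yield incremental

assert list(iter_deltas(["Hel", "Hello", "Hello, wor", "Hello, world!"])) == ["Hel", "lo", ", wor", "ld!"]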
-@manager.route('/agents/<agent_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/agents/<agent_id>/completions", methods=["POST"])  # noqa: F821
 @token_required
 def agent_completions(tenant_id, agent_id):
     req = request.json
@@ -365,8 +328,8 @@ def agent_completions(tenant_id, agent_id):
         dsl = cvs[0].dsl
         if not isinstance(dsl, str):
             dsl = json.dumps(dsl)
-        #canvas = Canvas(dsl, tenant_id)
-        #if canvas.get_preset_param():
+        # canvas = Canvas(dsl, tenant_id)
+        # if canvas.get_preset_param():
         #     req["question"] = ""
         conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
         if not conv:
@@ -380,9 +343,7 @@ def agent_completions(tenant_id, agent_id):
             states = {field: current_dsl.get(field, []) for field in state_fields}
             current_dsl.update(new_dsl)
             current_dsl.update(states)
-            API4ConversationService.update_by_id(req["session_id"], {
-                "dsl": current_dsl
-            })
+            API4ConversationService.update_by_id(req["session_id"], {"dsl": current_dsl})
     else:
         req["question"] = ""
     if req.get("stream", True):
@@ -399,7 +360,7 @@ def agent_completions(tenant_id, agent_id):
         return get_error_data_result(str(e))

-@manager.route('/chats/<chat_id>/sessions', methods=['GET'])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions", methods=["GET"])  # noqa: F821
 @token_required
 def list_session(tenant_id, chat_id):
     if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
@@ -418,7 +379,7 @@ def list_session(tenant_id, chat_id):
     if not convs:
         return get_result(data=[])
     for conv in convs:
-        conv['messages'] = conv.pop("message")
+        conv["messages"] = conv.pop("message")
         infos = conv["messages"]
         for info in infos:
             if "prompt" in info:
@@ -452,7 +413,7 @@ def list_session(tenant_id, chat_id):
     return get_result(data=convs)

-@manager.route('/agents/<agent_id>/sessions', methods=['GET'])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["GET"])  # noqa: F821
 @token_required
 def list_agent_session(tenant_id, agent_id):
     if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
@@ -468,12 +429,11 @@ def list_agent_session(tenant_id, agent_id):
         desc = True
     # dsl defaults to True in all cases except for False and false
     include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
-    convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
-                                             user_id, include_dsl)
+    convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
     if not convs:
         return get_result(data=[])
     for conv in convs:
-        conv['messages'] = conv.pop("message")
+        conv["messages"] = conv.pop("message")
         infos = conv["messages"]
         for info in infos:
             if "prompt" in info:
@@ -506,7 +466,7 @@ def list_agent_session(tenant_id, agent_id):
     return get_result(data=convs)

-@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"])  # noqa: F821
 @token_required
 def delete(tenant_id, chat_id):
     if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
@@ -532,14 +492,14 @@ def delete(tenant_id, chat_id):
     return get_result()

-@manager.route('/agents/<agent_id>/sessions', methods=["DELETE"])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"])  # noqa: F821
 @token_required
 def delete_agent_session(tenant_id, agent_id):
     req = request.json
     cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
     if not cvs:
         return get_error_data_result(f"You don't own the agent {agent_id}")
     convs = API4ConversationService.query(dialog_id=agent_id)
     if not convs:
         return get_error_data_result(f"Agent {agent_id} has no sessions")
@@ -555,16 +515,16 @@ def delete_agent_session(tenant_id, agent_id):
             conv_list.append(conv.id)
     else:
         conv_list = ids
     for session_id in conv_list:
         conv = API4ConversationService.query(id=session_id, dialog_id=agent_id)
         if not conv:
             return get_error_data_result(f"The agent doesn't own the session ${session_id}")
         API4ConversationService.delete_by_id(session_id)
     return get_result()

-@manager.route('/sessions/ask', methods=['POST'])  # noqa: F821
+@manager.route("/sessions/ask", methods=["POST"])  # noqa: F821
 @token_required
 def ask_about(tenant_id):
     req = request.json
@@ -590,9 +550,7 @@ def ask_about(tenant_id):
             for ans in ask(req["question"], req["kb_ids"], uid):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
-            yield "data:" + json.dumps({"code": 500, "message": str(e),
-                                        "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
-                                       ensure_ascii=False) + "\n\n"
+            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
         yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

     resp = Response(stream(), mimetype="text/event-stream")
@@ -603,7 +561,7 @@ def ask_about(tenant_id):
     return resp

-@manager.route('/sessions/related_questions', methods=['POST'])  # noqa: F821
+@manager.route("/sessions/related_questions", methods=["POST"])  # noqa: F821
 @token_required
 def related_questions(tenant_id):
     req = request.json
@@ -635,18 +593,27 @@ Reason:
 - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
 """
-    ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
+    ans = chat_mdl.chat(
+        prompt,
+        [
+            {
+                "role": "user",
+                "content": f"""
 Keywords: {question}
 Related search terms:
-"""}], {"temperature": 0.9})
+""",
+            }
+        ],
+        {"temperature": 0.9},
+    )
     return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])

-@manager.route('/chatbots/<dialog_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"])  # noqa: F821
 def chatbot_completions(dialog_id):
     req = request.json
-    token = request.headers.get('Authorization').split()
+    token = request.headers.get("Authorization").split()
     if len(token) != 2:
         return get_error_data_result(message='Authorization is not valid!"')
     token = token[1]
@@ -669,11 +636,11 @@ def chatbot_completions(dialog_id):
         return get_result(data=answer)

-@manager.route('/agentbots/<agent_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/agentbots/<agent_id>/completions", methods=["POST"])  # noqa: F821
 def agent_bot_completions(agent_id):
     req = request.json
-    token = request.headers.get('Authorization').split()
+    token = request.headers.get("Authorization").split()
     if len(token) != 2:
         return get_error_data_result(message='Authorization is not valid!"')
     token = token[1]

View File

@@ -47,20 +47,14 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
-    def get_list(cls, kb_id, page_number, items_per_page,
-                 orderby, desc, keywords, id, name):
+    def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name):
         docs = cls.model.select().where(cls.model.kb_id == kb_id)
         if id:
-            docs = docs.where(
-                cls.model.id == id)
+            docs = docs.where(cls.model.id == id)
         if name:
-            docs = docs.where(
-                cls.model.name == name
-            )
+            docs = docs.where(cls.model.name == name)
         if keywords:
-            docs = docs.where(
-                fn.LOWER(cls.model.name).contains(keywords.lower())
-            )
+            docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
         if desc:
             docs = docs.order_by(cls.model.getter_by(orderby).desc())
         else:
@@ -72,13 +66,9 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
-    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
-                     orderby, desc, keywords):
+    def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords):
         if keywords:
-            docs = cls.model.select().where(
-                (cls.model.kb_id == kb_id),
-                (fn.LOWER(cls.model.name).contains(keywords.lower()))
-            )
+            docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
         else:
             docs = cls.model.select().where(cls.model.kb_id == kb_id)
         count = docs.count()
@@ -97,8 +87,7 @@ class DocumentService(CommonService):
         if not cls.save(**doc):
             raise RuntimeError("Database error (Document)!")
         e, kb = KnowledgebaseService.get_by_id(doc["kb_id"])
-        if not KnowledgebaseService.update_by_id(
-                kb.id, {"doc_num": kb.doc_num + 1}):
+        if not KnowledgebaseService.update_by_id(kb.id, {"doc_num": kb.doc_num + 1}):
             raise RuntimeError("Database error (Knowledgebase)!")
         return Document(**doc)
@@ -108,14 +97,16 @@ class DocumentService(CommonService):
         cls.clear_chunk_num(doc.id)
         try:
             settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
-            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
-                                         {"remove": {"source_id": doc.id}},
-                                         search.index_name(tenant_id), doc.kb_id)
-            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
-                                         {"removed_kwd": "Y"},
-                                         search.index_name(tenant_id), doc.kb_id)
-            settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
-                                         search.index_name(tenant_id), doc.kb_id)
+            settings.docStoreConn.update(
+                {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
+                {"remove": {"source_id": doc.id}},
+                search.index_name(tenant_id),
+                doc.kb_id,
+            )
+            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id)
+            settings.docStoreConn.delete(
+                {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}}, search.index_name(tenant_id), doc.kb_id
+            )
         except Exception:
             pass
         return cls.delete_by_id(doc.id)
@@ -136,67 +127,54 @@ class DocumentService(CommonService):
             Tenant.embd_id,
             Tenant.img2txt_id,
             Tenant.asr_id,
-            cls.model.update_time]
-        docs = cls.model.select(*fields) \
-            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
-            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
+            cls.model.update_time,
+        ]
+        docs = (
+            cls.model.select(*fields)
+            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
+            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
             .where(
                 cls.model.status == StatusEnum.VALID.value,
                 ~(cls.model.type == FileType.VIRTUAL.value),
                 cls.model.progress == 0,
                 cls.model.update_time >= current_timestamp() - 1000 * 600,
-                cls.model.run == TaskStatus.RUNNING.value) \
+                cls.model.run == TaskStatus.RUNNING.value,
+            )
             .order_by(cls.model.update_time.asc())
+        )
         return list(docs.dicts())

     @classmethod
     @DB.connection_context()
     def get_unfinished_docs(cls):
-        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
-                  cls.model.run, cls.model.parser_id]
-        docs = cls.model.select(*fields) \
-            .where(
-                cls.model.status == StatusEnum.VALID.value,
-                ~(cls.model.type == FileType.VIRTUAL.value),
-                cls.model.progress < 1,
-                cls.model.progress > 0)
+        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id]
+        docs = cls.model.select(*fields).where(cls.model.status == StatusEnum.VALID.value, ~(cls.model.type == FileType.VIRTUAL.value), cls.model.progress < 1, cls.model.progress > 0)
         return list(docs.dicts())

     @classmethod
     @DB.connection_context()
     def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
-        num = cls.model.update(token_num=cls.model.token_num + token_num,
-                               chunk_num=cls.model.chunk_num + chunk_num,
-                               process_duation=cls.model.process_duation + duation).where(
-            cls.model.id == doc_id).execute()
+        num = (
+            cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duation=cls.model.process_duation + duation)
+            .where(cls.model.id == doc_id)
+            .execute()
+        )
         if num == 0:
-            raise LookupError(
-                "Document not found which is supposed to be there")
-        num = Knowledgebase.update(
-            token_num=Knowledgebase.token_num +
-            token_num,
-            chunk_num=Knowledgebase.chunk_num +
-            chunk_num).where(
-            Knowledgebase.id == kb_id).execute()
+            raise LookupError("Document not found which is supposed to be there")
+        num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute()
         return num

     @classmethod
     @DB.connection_context()
     def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
-        num = cls.model.update(token_num=cls.model.token_num - token_num,
-                               chunk_num=cls.model.chunk_num - chunk_num,
-                               process_duation=cls.model.process_duation + duation).where(
-            cls.model.id == doc_id).execute()
+        num = (
+            cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duation=cls.model.process_duation + duation)
+            .where(cls.model.id == doc_id)
+            .execute()
+        )
         if num == 0:
-            raise LookupError(
-                "Document not found which is supposed to be there")
-        num = Knowledgebase.update(
-            token_num=Knowledgebase.token_num -
-            token_num,
-            chunk_num=Knowledgebase.chunk_num -
-            chunk_num
-        ).where(
-            Knowledgebase.id == kb_id).execute()
+            raise LookupError("Document not found which is supposed to be there")
+        num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute()
         return num

     @classmethod
@@ -205,24 +183,17 @@ class DocumentService(CommonService):
         doc = cls.model.get_by_id(doc_id)
         assert doc, "Can't fine document in database."
-        num = Knowledgebase.update(
-            token_num=Knowledgebase.token_num -
-            doc.token_num,
-            chunk_num=Knowledgebase.chunk_num -
-            doc.chunk_num,
-            doc_num=Knowledgebase.doc_num - 1
-        ).where(
-            Knowledgebase.id == doc.kb_id).execute()
+        num = (
+            Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1)
+            .where(Knowledgebase.id == doc.kb_id)
+            .execute()
+        )
         return num

     @classmethod
     @DB.connection_context()
     def get_tenant_id(cls, doc_id):
-        docs = cls.model.select(
-            Knowledgebase.tenant_id).join(
-            Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)).where(
-            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
+        docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
         docs = docs.dicts()
         if not docs:
             return
@@ -240,11 +211,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_tenant_id_by_name(cls, name):
-        docs = cls.model.select(
-            Knowledgebase.tenant_id).join(
-            Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)).where(
-            cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
+        docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
         docs = docs.dicts()
         if not docs:
             return
@@ -253,12 +220,13 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def accessible(cls, doc_id, user_id):
-        docs = cls.model.select(
-            cls.model.id).join(
-            Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)
-        ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
-        ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
+        docs = (
+            cls.model.select(cls.model.id)
+            .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
+            .join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id))
+            .where(cls.model.id == doc_id, UserTenant.user_id == user_id)
+            .paginate(0, 1)
+        )
         docs = docs.dicts()
         if not docs:
             return False
@@ -267,11 +235,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def accessible4deletion(cls, doc_id, user_id):
-        docs = cls.model.select(
-            cls.model.id).join(
-            Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)
-        ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
+        docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
         docs = docs.dicts()
         if not docs:
             return False
@@ -280,11 +244,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_embd_id(cls, doc_id):
-        docs = cls.model.select(
-            Knowledgebase.embd_id).join(
-            Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)).where(
-            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
+        docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
         docs = docs.dicts()
         if not docs:
             return
@@ -306,9 +266,9 @@ class DocumentService(CommonService):
                 Tenant.asr_id,
                 Tenant.llm_id,
             )
             .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
             .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
             .where(cls.model.id == doc_id)
         )
         configs = configs.dicts()
         if not configs:
@@ -319,8 +279,7 @@ class DocumentService(CommonService):
     @DB.connection_context()
     def get_doc_id_by_doc_name(cls, doc_name):
         fields = [cls.model.id]
-        doc_id = cls.model.select(*fields) \
-            .where(cls.model.name == doc_name)
+        doc_id = cls.model.select(*fields).where(cls.model.name == doc_name)
         doc_id = doc_id.dicts()
         if not doc_id:
             return
@@ -330,8 +289,7 @@ class DocumentService(CommonService):
     @DB.connection_context()
     def get_thumbnails(cls, docids):
         fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
-        return list(cls.model.select(
-            *fields).where(cls.model.id.in_(docids)).dicts())
+        return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())

     @classmethod
     @DB.connection_context()
@@ -359,19 +317,14 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_doc_count(cls, tenant_id):
-        docs = cls.model.select(cls.model.id).join(Knowledgebase,
-                                                   on=(Knowledgebase.id == cls.model.kb_id)).where(
-            Knowledgebase.tenant_id == tenant_id)
+        docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id)
         return len(docs)

     @classmethod
     @DB.connection_context()
     def begin2parse(cls, docid):
-        cls.update_by_id(
-            docid, {"progress": random.random() * 1 / 100.,
-                    "progress_msg": "Task is queued...",
-                    "process_begin_at": get_format_time()
-                    })
+        cls.update_by_id(docid, {"progress": random.random() * 1 / 100.0, "progress_msg": "Task is queued...", "process_begin_at": get_format_time()})

     @classmethod
     @DB.connection_context()
     def update_meta_fields(cls, doc_id, meta_fields):
@@ -420,11 +373,7 @@ class DocumentService(CommonService):
                 status = TaskStatus.DONE.value
             msg = "\n".join(sorted(msg))
-            info = {
-                "process_duation": datetime.timestamp(
-                    datetime.now()) -
-                d["process_begin_at"].timestamp(),
-                "run": status}
+            info = {"process_duation": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(), "run": status}
             if prg != 0:
                 info["progress"] = prg
             if msg:
@@ -437,8 +386,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_kb_doc_count(cls, kb_id):
-        return len(cls.model.select(cls.model.id).where(
-            cls.model.kb_id == kb_id).dicts())
+        return len(cls.model.select(cls.model.id).where(cls.model.kb_id == kb_id).dicts())

     @classmethod
     @DB.connection_context()
@@ -459,14 +407,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
     def new_task():
         nonlocal doc
-        return {
-            "id": get_uuid(),
-            "doc_id": doc["id"],
-            "from_page": 100000000,
-            "to_page": 100000000,
-            "task_type": ty,
-            "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
-        }
+        return {"id": get_uuid(), "doc_id": doc["id"], "from_page": 100000000, "to_page": 100000000, "task_type": ty, "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty}

     task = new_task()
     for field in ["doc_id", "from_page", "to_page"]:
@@ -478,6 +419,25 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
 def doc_upload_and_parse(conversation_id, file_objs, user_id):
+    """
+    上传并解析文档,将内容存入知识库
+
+    参数:
+        conversation_id: 会话ID
+        file_objs: 文件对象列表
+        user_id: 用户ID
+
+    返回:
+        处理成功的文档ID列表
+
+    处理流程:
+        1. 验证会话和知识库
+        2. 初始化嵌入模型
+        3. 上传文件到存储
+        4. 多线程解析文件内容
+        5. 生成内容嵌入向量
+        6. 存入文档存储系统
+    """
     from rag.app import presentation, picture, naive, audio, email
     from api.db.services.dialog_service import DialogService
     from api.db.services.file_service import FileService
@@ -493,8 +453,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
     e, dia = DialogService.get_by_id(conv.dialog_id)
     if not dia.kb_ids:
-        raise LookupError("No knowledge base associated with this conversation. "
-                          "Please add a knowledge base before uploading documents")
+        raise LookupError("No knowledge base associated with this conversation. Please add a knowledge base before uploading documents")
     kb_id = dia.kb_ids[0]
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
@@ -508,12 +467,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
     def dummy(prog=None, msg=""):
         pass

-    FACTORY = {
-        ParserType.PRESENTATION.value: presentation,
-        ParserType.PICTURE.value: picture,
-        ParserType.AUDIO.value: audio,
-        ParserType.EMAIL.value: email
-    }
+    FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
     parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
     # 使用线程池执行解析任务
     exe = ThreadPoolExecutor(max_workers=12)
@@ -522,22 +476,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
     for d, blob in files:
         doc_nm[d["id"]] = d["name"]
     for d, blob in files:
-        kwargs = {
-            "callback": dummy,
-            "parser_config": parser_config,
-            "from_page": 0,
-            "to_page": 100000,
-            "tenant_id": kb.tenant_id,
-            "lang": kb.language
-        }
+        kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language}
         threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))

     for (docinfo, _), th in zip(files, threads):
         docs = []
-        doc = {
-            "doc_id": docinfo["id"],
-            "kb_id": [kb.id]
-        }
+        doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]}
         for ck in th.result():
             d = deepcopy(doc)
             d.update(ck)
@@ -552,7 +496,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
             if isinstance(d["image"], bytes):
                 output_buffer = BytesIO(d["image"])
             else:
-                d["image"].save(output_buffer, format='JPEG')
+                d["image"].save(output_buffer, format="JPEG")
             STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())

             d["img_id"] = "{}-{}".format(kb.id, d["id"])
@@ -569,9 +513,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
         nonlocal embd_mdl, chunk_counts, token_counts
         vects = []
         for i in range(0, len(cnts), batch_size):
-            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
+            vts, c = embd_mdl.encode(cnts[i : i + batch_size])
             vects.extend(vts.tolist())
-            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
+            chunk_counts[doc_id] += len(cnts[i : i + batch_size])
             token_counts[doc_id] += c
         return vects
@@ -585,22 +529,25 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
         if parser_ids[doc_id] != ParserType.PICTURE.value:
             from graphrag.general.mind_map_extractor import MindMapExtractor

             mindmap = MindMapExtractor(llm_bdl)
             try:
                 mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
                 mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
                 if len(mind_map) < 32:
                     raise Exception("Few content: " + mind_map)
-                cks.append({
-                    "id": get_uuid(),
-                    "doc_id": doc_id,
-                    "kb_id": [kb.id],
-                    "docnm_kwd": doc_nm[doc_id],
-                    "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
-                    "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
-                    "content_with_weight": mind_map,
-                    "knowledge_graph_kwd": "mind_map"
-                })
+                cks.append(
+                    {
+                        "id": get_uuid(),
+                        "doc_id": doc_id,
+                        "kb_id": [kb.id],
+                        "docnm_kwd": doc_nm[doc_id],
+                        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
+                        "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
+                        "content_with_weight": mind_map,
+                        "knowledge_graph_kwd": "mind_map",
+                    }
+                )
             except Exception as e:
                 logging.exception("Mind map generation error")
@@ -614,9 +561,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
         if not settings.docStoreConn.indexExist(idxnm, kb_id):
             settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
             try_create_idx = False
-        settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
-        DocumentService.increment_chunk_num(
-            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
+        settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id)
+        DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
     return [d["id"] for d, _ in files]
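The docstring added to doc_upload_and_parse above outlines the flow: validate the conversation and knowledge base, parse each uploaded file in a thread pool, then embed the resulting chunks in batches before indexing. A rough, self-contained sketch of that fan-out/fan-in pattern; parse_file, embed_texts and the sample inputs are placeholders, not RAGFlow APIs:

from concurrent.futures import ThreadPoolExecutor

def parse_file(name, blob):
    # Placeholder parser: one "chunk" per non-empty line of text.
    return [line for line in blob.decode("utf-8", "ignore").splitlines() if line.strip()]

def embed_texts(texts):
    # Placeholder embedder: fixed-size dummy vectors sized by text length.
    return [[float(len(t))] * 4 for t in texts]

def upload_and_parse(files, batch_size=16):
    with ThreadPoolExecutor(max_workers=12) as exe:  # same worker count as the diff
        futures = [exe.submit(parse_file, name, blob) for name, blob in files]
        chunks = [ck for fut in futures for ck in fut.result()]
    vects = []
    for i in range(0, len(chunks), batch_size):  # embed in batches, mirroring the batched encode above
        vects.extend(embed_texts(chunks[i : i + batch_size]))
    return chunks, vects

chunks, vects = upload_and_parse([("a.txt", b"hello\nworld"), ("b.txt", b"foo")])
assert len(chunks) == len(vects) == 3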

View File

@@ -1,5 +1,5 @@
 from flask import jsonify, request
-from services.conversation.service import get_conversations_by_user_id, get_messages_by_conversation_id, get_conversation_detail
+from services.conversation.service import get_conversations_by_user_id, get_messages_by_conversation_id

 from .. import conversation_bp
@@ -44,20 +44,3 @@ def get_messages(conversation_id):
     except Exception as e:
         # 错误处理
         return jsonify({"code": 500, "message": f"获取消息列表失败: {str(e)}"}), 500
-
-
-@conversation_bp.route("/<conversation_id>", methods=["GET"])
-def get_conversation(conversation_id):
-    """获取特定对话的详细信息"""
-    try:
-        # 调用服务函数获取对话详情
-        conversation = get_conversation_detail(conversation_id)
-        if not conversation:
-            return jsonify({"code": 404, "message": "对话不存在"}), 404
-        # 返回符合前端期望格式的数据
-        return jsonify({"code": 0, "data": conversation, "message": "获取对话详情成功"})
-    except Exception as e:
-        # 错误处理
-        return jsonify({"code": 500, "message": f"获取对话详情失败: {str(e)}"}), 500

View File

@@ -23,8 +23,6 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time"
     # 直接使用user_id作为tenant_id
     tenant_id = user_id

-    print(f"查询用户ID: {user_id}, 租户ID: {tenant_id}")
-
     # 查询总记录数
     count_sql = """
         SELECT COUNT(*) as total
@@ -34,7 +32,7 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time"
         cursor.execute(count_sql, (tenant_id,))
         total = cursor.fetchone()["total"]
-        print(f"查询到总记录数: {total}")
+        # print(f"查询到总记录数: {total}")

         # 计算分页偏移量
         offset = (page - 1) * size
@@ -59,8 +57,8 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time"
             LIMIT %s OFFSET %s
         """
-        print(f"执行查询: {query}")
-        print(f"参数: tenant_id={tenant_id}, size={size}, offset={offset}")
+        # print(f"执行查询: {query}")
+        # print(f"参数: tenant_id={tenant_id}, size={size}, offset={offset}")
         cursor.execute(query, (tenant_id, size, offset))
         results = cursor.fetchall()
@@ -200,68 +198,3 @@ def get_messages_by_conversation_id(conversation_id, page=1, size=30):
         traceback.print_exc()
         return None, 0
-
-
-def get_conversation_detail(conversation_id):
-    """
-    获取特定对话的详细信息
-
-    参数:
-        conversation_id (str): 对话ID
-
-    返回:
-        dict: 对话详情
-    """
-    try:
-        conn = mysql.connector.connect(**DB_CONFIG)
-        cursor = conn.cursor(dictionary=True)
-
-        # 查询对话信息
-        query = """
-            SELECT c.*, d.name as dialog_name, d.icon as dialog_icon
-            FROM conversation c
-            LEFT JOIN dialog d ON c.dialog_id = d.id
-            WHERE c.id = %s
-        """
-        cursor.execute(query, (conversation_id,))
-        result = cursor.fetchone()
-
-        if not result:
-            print(f"未找到对话ID: {conversation_id}")
-            return None
-
-        # 格式化对话详情
-        conversation = {
-            "id": result["id"],
-            "name": result.get("name", ""),
-            "dialogId": result.get("dialog_id", ""),
-            "dialogName": result.get("dialog_name", ""),
-            "dialogIcon": result.get("dialog_icon", ""),
-            "createTime": result["create_date"].strftime("%Y-%m-%d %H:%M:%S") if result.get("create_date") else "",
-            "updateTime": result["update_date"].strftime("%Y-%m-%d %H:%M:%S") if result.get("update_date") else "",
-            "messages": result.get("message", []),
-        }
-
-        # 打印调试信息
-        print(f"获取到对话详情: ID={conversation_id}")
-        print(f"消息数量: {len(conversation['messages']) if conversation['messages'] else 0}")
-
-        # 关闭连接
-        cursor.close()
-        conn.close()
-
-        return conversation
-    except mysql.connector.Error as err:
-        print(f"数据库错误: {err}")
-        # 更详细的错误日志
-        import traceback
-        traceback.print_exc()
-        return None
-    except Exception as e:
-        print(f"未知错误: {e}")
-        import traceback
-        traceback.print_exc()
-        return None

View File

@@ -1,11 +1,8 @@
 <script lang="ts" setup>
-import type { CreateOrUpdateTableRequestData, TableData } from "@@/apis/tables/type"
-import type { FormInstance, FormRules } from "element-plus"
-import { createTableDataApi, deleteTableDataApi, getTableDataApi, resetPasswordApi, updateTableDataApi } from "@@/apis/tables"
-import { usePagination } from "@@/composables/usePagination"
-import { ChatDotRound, CirclePlus, Delete, Edit, Key, Refresh, RefreshRight, Search, User } from "@element-plus/icons-vue"
+import type { TableData } from "@@/apis/tables/type"
+import { getTableDataApi } from "@@/apis/tables"
+import { ChatDotRound, User } from "@element-plus/icons-vue"
 import axios from "axios"
-import { cloneDeep } from "lodash-es"

 defineOptions({
   //
@@ -79,10 +76,16 @@ function loadMoreUsers() {
 /**
  * 监听用户列表滚动事件
- * @param event 滚动事件
+ * @param event DOM滚动事件对象
  */
-function handleUserListScroll(event) {
-  const { scrollTop, scrollHeight, clientHeight } = event.target
+function handleUserListScroll(event: Event) {
+  // event.target HTMLElement
+  const target = event.target as HTMLElement
+  if (!target) return
+  //
+  const { scrollTop, scrollHeight, clientHeight } = target
   // 100px
   if (scrollHeight - scrollTop - clientHeight < 100 && userHasMore.value && !userLoading.value) {
     loadMoreUsers()
@@ -203,10 +206,16 @@ function loadMoreConversations() {
 /**
  * 监听对话列表滚动事件
- * @param event 滚动事件
+ * @param event DOM滚动事件对象
  */
-function handleConversationListScroll(event) {
-  const { scrollTop, scrollHeight, clientHeight } = event.target
+function handleConversationListScroll(event: Event) {
+  // event.target HTMLElement
+  const target = event.target as HTMLElement
+  if (!target) return
+  //
+  const { scrollTop, scrollHeight, clientHeight } = target
   // 100px
   if (scrollHeight - scrollTop - clientHeight < 100 && conversationHasMore.value && !conversationLoading.value) {
     loadMoreConversations()
@@ -247,7 +256,7 @@ function getMessagesByConversationId(conversationId: string, isLoadMore = false)
       const parsedMessages = JSON.parse(conversation.messages)
       //
-      processedMessages = parsedMessages.map((msg, index) => {
+      processedMessages = parsedMessages.map((msg: { id: any, role: any, content: any, created_at: number }, index: any) => {
         return {
           id: msg.id || `msg-${index}`,
           conversation_id: conversationId,
@@ -268,7 +277,7 @@ function getMessagesByConversationId(conversationId: string, isLoadMore = false)
       if (isLoadMore) {
         //
         const existingIds = new Set(messageList.value.map(msg => msg.id))
-        const uniqueNewMessages = processedMessages.filter(msg => !existingIds.has(msg.id))
+        const uniqueNewMessages = processedMessages.filter((msg: { id: number }) => !existingIds.has(msg.id))
         //
         messageList.value = [...messageList.value, ...uniqueNewMessages]
@@ -307,7 +316,7 @@ function getMessagesByConversationId(conversationId: string, isLoadMore = false)
  * @param content 消息内容
  * @returns 处理后的HTML内容
  */
-function renderMessageContent(content) {
+function renderMessageContent(content: string) {
   if (!content) return ""
   // Markdown
@@ -331,10 +340,16 @@ function loadMoreMessages() {
 /**
  * 监听消息列表滚动事件
- * @param event 滚动事件
+ * @param event DOM滚动事件对象
  */
-function handleMessageListScroll(event) {
-  const { scrollTop, scrollHeight, clientHeight } = event.target
+function handleMessageListScroll(event: Event) {
+  // event.target HTMLElement
+  const target = event.target as HTMLElement
+  if (!target) return
+  //
+  const { scrollTop, scrollHeight, clientHeight } = target
   // 100px
   if (scrollHeight - scrollTop - clientHeight < 100 && messageHasMore.value && !messageLoading.value) {
     loadMoreMessages()

View File

@@ -1,10 +1,9 @@
 <script lang="ts" setup>
-import type { CreateOrUpdateTableRequestData, TableData } from "@@/apis/configs/type"
-import type { FormInstance, FormRules } from "element-plus"
-import { getTableDataApi, updateTableDataApi } from "@@/apis/configs"
+import type { TableData } from "@@/apis/configs/type"
+import type { FormInstance } from "element-plus"
+import { getTableDataApi } from "@@/apis/configs"
 import { usePagination } from "@@/composables/usePagination"
-import { CirclePlus, Delete, Refresh, RefreshRight, Search } from "@element-plus/icons-vue"
-import { cloneDeep } from "lodash-es"
+import { Refresh, Search } from "@element-plus/icons-vue"

 defineOptions({
   //
@@ -15,47 +14,41 @@ const loading = ref<boolean>(false)
 const { paginationData, handleCurrentChange, handleSizeChange } = usePagination()

 // #region
-const DEFAULT_FORM_DATA: CreateOrUpdateTableRequestData = {
-  id: undefined,
-  username: "",
-  chatModel: "",
-  embeddingModel: ""
-}
-const dialogVisible = ref<boolean>(false)
-const formData = ref<CreateOrUpdateTableRequestData>(cloneDeep(DEFAULT_FORM_DATA))
+// const DEFAULT_FORM_DATA: CreateOrUpdateTableRequestData = {
+//   id: undefined,
+//   username: "",
+//   chatModel: "",
+//   embeddingModel: ""
+// }
+// const dialogVisible = ref<boolean>(false)
+// const formData = ref<CreateOrUpdateTableRequestData>(cloneDeep(DEFAULT_FORM_DATA))

 //
 function handleDelete() {
-  ElMessage.success("如需删除租户配置,可直接删除负责人账号")
+  ElMessage.success("如需删除用户配置信息,请直接在前台登录用户账号进行操作")
 }

 //
-function handleUpdate(row: TableData) {
-  dialogVisible.value = true
-  formData.value = cloneDeep({
-    id: row.id,
-    username: row.username,
-    chatModel: row.chatModel,
-    embeddingModel: row.embeddingModel
-  })
+function handleUpdate() {
+  ElMessage.success("如需修改用户配置信息,请直接在前台登录用户账号进行操作")
 }

 //
-function submitForm() {
-  loading.value = true
-  updateTableDataApi(formData.value)
-    .then(() => {
-      ElMessage.success("修改成功")
-      dialogVisible.value = false
-      getTableData() //
-    })
-    .catch(() => {
-      ElMessage.error("修改失败")
-    })
-    .finally(() => {
-      loading.value = false
-    })
-}
+// function submitForm() {
+//   loading.value = true
+//   updateTableDataApi(formData.value)
+//     .then(() => {
+//       ElMessage.success("修改成功")
+//       dialogVisible.value = false
+//       getTableData() //
+//     })
+//     .catch(() => {
+//       ElMessage.error("修改失败")
+//     })
+//     .finally(() => {
+//       loading.value = false
+//     })
+// }

 //
 const tableData = ref<TableData[]>([])
@@ -137,8 +130,8 @@ onActivated(() => {
         <el-table-column prop="embeddingModel" label="嵌入模型" align="center" />
         <el-table-column prop="updateTime" label="更新时间" align="center" />
         <el-table-column fixed="right" label="操作" width="150" align="center">
-          <template #default="scope">
-            <el-button type="primary" text bg size="small" @click="handleUpdate(scope.row)">
+          <template #default="">
+            <el-button type="primary" text bg size="small" @click="handleUpdate">
              修改
            </el-button>
            <el-button type="danger" text bg size="small" @click="handleDelete()">
@@ -163,7 +156,7 @@ onActivated(() => {
     </el-card>

     <!-- 修改对话框 -->
-    <el-dialog v-model="dialogVisible" title="修改配置" width="30%">
+    <!-- <el-dialog v-model="dialogVisible" title="修改配置" width="30%">
       <el-form :model="formData" label-width="100px">
         <el-form-item label="用户名">
           <el-input v-model="formData.username" disabled />
@@ -183,7 +176,7 @@ onActivated(() => {
             确认
           </el-button>
         </template>
-    </el-dialog>
+    </el-dialog> -->
   </div>
 </template>

View File

@ -32,16 +32,19 @@ def chunks_format(reference):
def get_value(d, k1, k2): def get_value(d, k1, k2):
return d.get(k1, d.get(k2)) return d.get(k1, d.get(k2))
return [{ return [
"id": get_value(chunk, "chunk_id", "id"), {
"content": get_value(chunk, "content", "content_with_weight"), "id": get_value(chunk, "chunk_id", "id"),
"document_id": get_value(chunk, "doc_id", "document_id"), "content": get_value(chunk, "content", "content_with_weight"),
"document_name": get_value(chunk, "docnm_kwd", "document_name"), "document_id": get_value(chunk, "doc_id", "document_id"),
"dataset_id": get_value(chunk, "kb_id", "dataset_id"), "document_name": get_value(chunk, "docnm_kwd", "document_name"),
"image_id": get_value(chunk, "image_id", "img_id"), "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
"positions": get_value(chunk, "positions", "position_int"), "image_id": get_value(chunk, "image_id", "img_id"),
"url": chunk.get("url") "positions": get_value(chunk, "positions", "position_int"),
} for chunk in reference.get("chunks", [])] "url": chunk.get("url"),
}
for chunk in reference.get("chunks", [])
]
def llm_id2llm_type(llm_id): def llm_id2llm_type(llm_id):
@ -57,21 +60,21 @@ def llm_id2llm_type(llm_id):
def message_fit_in(msg, max_length=4000): def message_fit_in(msg, max_length=4000):
""" """
调整消息列表使其token总数不超过max_length限制 调整消息列表使其token总数不超过max_length限制
参数: 参数:
msg: 消息列表每个元素为包含role和content的字典 msg: 消息列表每个元素为包含role和content的字典
max_length: 最大token数限制默认4000 max_length: 最大token数限制默认4000
返回: 返回:
tuple: (实际token数, 调整后的消息列表) tuple: (实际token数, 调整后的消息列表)
""" """
def count(): def count():
"""计算当前消息列表的总token数""" """计算当前消息列表的总token数"""
nonlocal msg nonlocal msg
tks_cnts = [] tks_cnts = []
for m in msg: for m in msg:
tks_cnts.append( tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
{"role": m["role"], "count": num_tokens_from_string(m["content"])})
total = 0 total = 0
for m in tks_cnts: for m in tks_cnts:
total += m["count"] total += m["count"]
@ -81,7 +84,7 @@ def message_fit_in(msg, max_length=4000):
# If within the limit, return as-is # If within the limit, return as-is
if c < max_length: if c < max_length:
return c, msg return c, msg
# First pass: keep the system message and the last message # First pass: keep the system message and the last message
msg_ = [m for m in msg if m["role"] == "system"] msg_ = [m for m in msg if m["role"] == "system"]
if len(msg) > 1: if len(msg) > 1:
@ -90,20 +93,20 @@ def message_fit_in(msg, max_length=4000):
c = count() c = count()
if c < max_length: if c < max_length:
return c, msg return c, msg
# Count the tokens of the system message and the last message # Count the tokens of the system message and the last message
ll = num_tokens_from_string(msg_[0]["content"]) ll = num_tokens_from_string(msg_[0]["content"])
ll2 = num_tokens_from_string(msg_[-1]["content"]) ll2 = num_tokens_from_string(msg_[-1]["content"])
# If the system message takes more than 80% of the total, truncate it # If the system message takes more than 80% of the total, truncate it
if ll / (ll + ll2) > 0.8: if ll / (ll + ll2) > 0.8:
m = msg_[0]["content"] m = msg_[0]["content"]
m = encoder.decode(encoder.encode(m)[:max_length - ll2]) m = encoder.decode(encoder.encode(m)[: max_length - ll2])
msg[0]["content"] = m msg[0]["content"] = m
return max_length, msg return max_length, msg
# Otherwise truncate the last message # Otherwise truncate the last message
m = msg_[-1]["content"] m = msg_[-1]["content"]
m = encoder.decode(encoder.encode(m)[:max_length - ll2]) m = encoder.decode(encoder.encode(m)[: max_length - ll2])
msg[-1]["content"] = m msg[-1]["content"] = m
return max_length, msg return max_length, msg
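The truncation strategy above (keep the system and final messages, then trim whichever side dominates the budget) can be summarised with a hedged, standalone sketch. The whitespace token counter is an assumption standing in for num_tokens_from_string and the encoder, so the numbers will not match the real tokenizer.

```python
def count_tokens(text: str) -> int:
    # Stand-in tokenizer: whitespace split, purely for illustration.
    return len(text.split())

def fit_in(msg, max_length=4000):
    total = sum(count_tokens(m["content"]) for m in msg)
    if total < max_length:
        return total, msg
    # First pass: keep only the system message(s) and the final message.
    kept = [m for m in msg if m["role"] == "system"]
    if len(msg) > 1:
        kept.append(msg[-1])
    msg = kept
    total = sum(count_tokens(m["content"]) for m in msg)
    if total < max_length:
        return total, msg
    sys_tokens = count_tokens(msg[0]["content"])
    last_tokens = count_tokens(msg[-1]["content"])
    if sys_tokens / (sys_tokens + last_tokens) > 0.8:
        # System prompt dominates: cut it down, leaving room for the last message.
        words = msg[0]["content"].split()
        msg[0]["content"] = " ".join(words[: max_length - last_tokens])
    else:
        # Otherwise cut the last message down to the same budget.
        words = msg[-1]["content"].split()
        msg[-1]["content"] = " ".join(words[: max_length - last_tokens])
    return max_length, msg

used, trimmed = fit_in(
    [{"role": "system", "content": "ctx " * 5000}, {"role": "user", "content": "hi there"}],
    max_length=100,
)
print(used, len(trimmed))  # 100 2
```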
@ -111,18 +114,18 @@ def message_fit_in(msg, max_length=4000):
def kb_prompt(kbinfos, max_tokens): def kb_prompt(kbinfos, max_tokens):
""" """
Format retrieved knowledge-base content into a prompt suitable for the LLM Format retrieved knowledge-base content into a prompt suitable for the LLM
Args: Args:
kbinfos (dict): retrieval result, containing chunks and related info kbinfos (dict): retrieval result, containing chunks and related info
max_tokens (int): the model's maximum token limit max_tokens (int): the model's maximum token limit
Steps: Steps:
1. Extract the content of every retrieved document chunk 1. Extract the content of every retrieved document chunk
2. Count tokens to make sure the model limit is not exceeded 2. Count tokens to make sure the model limit is not exceeded
3. Fetch document metadata 3. Fetch document metadata
4. Group the chunks by document name 4. Group the chunks by document name
5. Format everything into a structured prompt 5. Format everything into a structured prompt
Returns: Returns:
list: formatted knowledge-base entries, one item per document list: formatted knowledge-base entries, one item per document
""" """
@ -134,7 +137,7 @@ def kb_prompt(kbinfos, max_tokens):
chunks_num += 1 chunks_num += 1
if max_tokens * 0.97 < used_token_count: if max_tokens * 0.97 < used_token_count:
knowledges = knowledges[:i] knowledges = knowledges[:i]
logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}") logging.warning(f"Not all the retrieval into prompt: {i + 1}/{len(knowledges)}")
break break
docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]]) docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
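A hedged sketch of the grouping the docstring describes: chunks are kept until roughly 97% of the token budget is used, then concatenated per document. The "Document/Relevant fragments" wording and the whitespace counter are illustrative assumptions, not the exact template kb_prompt builds.

```python
from collections import defaultdict

def group_chunks_by_doc(chunks, max_tokens, count_tokens=lambda s: len(s.split())):
    used, kept = 0, []
    for ck in chunks:
        used += count_tokens(ck["content_with_weight"])
        if used > max_tokens * 0.97:
            break  # mirrors the "not all retrieval fits into prompt" warning path above
        kept.append(ck)
    grouped = defaultdict(list)
    for ck in kept:
        grouped[ck["docnm_kwd"]].append(ck["content_with_weight"])
    return [
        f"Document: {name}\nRelevant fragments:\n" + "\n".join(frags)
        for name, frags in grouped.items()
    ]

print(group_chunks_by_doc(
    [{"docnm_kwd": "guide.pdf", "content_with_weight": "alpha beta"},
     {"docnm_kwd": "guide.pdf", "content_with_weight": "gamma"}],
    max_tokens=100,
)[0])
```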
@ -163,6 +166,10 @@ def citation_prompt():
- 以格式 '##i$$ ##j$$'插入引用其中 i, j 是所引用内容的 ID并用 '##' '$$' 包裹 - 以格式 '##i$$ ##j$$'插入引用其中 i, j 是所引用内容的 ID并用 '##' '$$' 包裹
- 在句子末尾插入引用每个句子最多 4 个引用 - 在句子末尾插入引用每个句子最多 4 个引用
- 如果答案内容不来自检索到的文本块则不要插入引用 - 如果答案内容不来自检索到的文本块则不要插入引用
- 不要使用独立的文档 ID例如 `#ID#`)。
- 在任何情况下不得使用其他引用样式或格式例如 `~~i==``[i]``(i)`
- 引用必须始终使用 `##i$$` 格式。
- 任何未能遵守上述规则的情况包括但不限于格式错误使用禁止的样式或不支持的引用都将被视为错误应跳过为该句添加引用
--- 示例 --- --- 示例 ---
<SYSTEM>: 以下是知识库: <SYSTEM>: 以下是知识库:
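The added rules pin citations to the `##i$$` form and forbid variants such as `[i]`, `(i)`, `~~i==`, and bare `#ID#`. A small, hypothetical post-check along those lines could look like the snippet below; it is not part of this PR, and the 0-based chunk ID assumption is mine.

```python
import re

CITATION = re.compile(r"##(\d+)\$\$")
FORBIDDEN = re.compile(r"~~\d+==|\[\d+\]|\(\d+\)|#\d+#")

def citations_are_valid(answer: str, num_chunks: int) -> bool:
    # Every cited ID must reference an existing chunk, and no alternate style may remain.
    ids = [int(i) for i in CITATION.findall(answer)]
    if any(i >= num_chunks for i in ids):
        return False
    return FORBIDDEN.search(CITATION.sub("", answer)) is None

print(citations_are_valid("RAGFlow supports streaming output ##0$$.", num_chunks=2))  # True
print(citations_are_valid("RAGFlow supports streaming output [1].", num_chunks=2))    # False
```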
@ -210,10 +217,7 @@ def keyword_extraction(chat_mdl, content, topn=3):
### 文本内容 ### 文本内容
{content} {content}
""" """
msg = [ msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length) _, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple): if isinstance(kwd, tuple):
@ -240,10 +244,7 @@ Requirements:
{content} {content}
""" """
msg = [ msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length) _, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple): if isinstance(kwd, tuple):
@ -368,10 +369,7 @@ Output:
{content} {content}
""" """
msg = [ msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length) _, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5}) kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple): if isinstance(kwd, tuple):
@ -384,8 +382,8 @@ Output:
return json_repair.loads(kwd) return json_repair.loads(kwd)
except json_repair.JSONDecodeError: except json_repair.JSONDecodeError:
try: try:
result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip() result = kwd.replace(prompt[:-1], "").replace("user", "").replace("model", "").strip()
result = '{' + result.split('{')[1].split('}')[0] + '}' result = "{" + result.split("{")[1].split("}")[0] + "}"
return json_repair.loads(result) return json_repair.loads(result)
except Exception as e: except Exception as e:
logging.exception(f"JSON parsing error: {result} -> {e}") logging.exception(f"JSON parsing error: {result} -> {e}")

View File

@ -27,7 +27,8 @@ import {
UploadProps, UploadProps,
} from 'antd'; } from 'antd';
import get from 'lodash/get'; import get from 'lodash/get';
import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react'; // import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react';
import { CircleStop, SendHorizontal } from 'lucide-react';
import { import {
ChangeEventHandler, ChangeEventHandler,
memo, memo,
@ -339,9 +340,9 @@ const MessageInput = ({
return false; return false;
}} }}
> >
<Button type={'primary'} disabled={disabled}> {/* <Button type={'primary'} disabled={disabled}>
<Paperclip className="size-4" /> <Paperclip className="size-4" />
</Button> </Button> */}
</Upload> </Upload>
)} )}
{sendLoading ? ( {sendLoading ? (

View File

@ -17,13 +17,11 @@
line-height: 1.2; line-height: 1.2;
border-bottom: 2px solid #eaeaea; border-bottom: 2px solid #eaeaea;
padding-bottom: 0.25em; padding-bottom: 0.25em;
font-size: 24px;
margin: 0.25em 0.25em; margin: 0.25em 0.25em;
} }
section { section {
margin-top: 1em; margin-top: 1em;
margin-bottom: 1em; margin-bottom: 1em;
font-size: 20px;
p { p {
margin-left: 0; margin-left: 0;
} }
@ -32,7 +30,6 @@
margin-top: 1em; margin-top: 1em;
margin-bottom: 1em; margin-bottom: 1em;
margin-left: 1em; margin-left: 1em;
font-size: 20px;
} }
ul, ul,
ol { ol {
@ -40,7 +37,6 @@
padding-left: 1.8em; padding-left: 1.8em;
li { li {
margin-bottom: 0.25em; margin-bottom: 0.25em;
font-size: 20px;
} }
} }
table { table {
@ -60,7 +56,6 @@
border: none; border: none;
padding: 12px; padding: 12px;
text-align: left; text-align: left;
font-size: 20px;
border-bottom: 1px solid #ddd; border-bottom: 1px solid #ddd;
} }

View File

@ -24,9 +24,10 @@ import styles from './index.less';
interface IProps { interface IProps {
controller: AbortController; controller: AbortController;
fontSize: number;
} }
const ChatContainer = ({ controller }: IProps) => { const ChatContainer = ({ controller, fontSize = 16 }: IProps) => {
const { conversationId } = useGetChatSearchParams(); const { conversationId } = useGetChatSearchParams();
const { data: conversation } = useFetchNextConversation(); const { data: conversation } = useFetchNextConversation();
@ -55,7 +56,12 @@ const ChatContainer = ({ controller }: IProps) => {
return ( return (
<> <>
<Flex flex={1} className={styles.chatContainer} vertical> <Flex flex={1} className={styles.chatContainer} vertical>
<Flex flex={1} vertical className={styles.messageContainer}> <Flex
flex={1}
vertical
className={styles.messageContainer}
style={{ fontSize: `${fontSize}px` }}
>
<div> <div>
<Spin spinning={loading}> <Spin spinning={loading}>
{derivedMessages?.map((message, i) => { {derivedMessages?.map((message, i) => {

View File

@ -1,6 +1,10 @@
import { ReactComponent as ChatAppCube } from '@/assets/svg/chat-app-cube.svg'; import { ReactComponent as ChatAppCube } from '@/assets/svg/chat-app-cube.svg';
import RenameModal from '@/components/rename-modal'; import RenameModal from '@/components/rename-modal';
import { DeleteOutlined, EditOutlined } from '@ant-design/icons'; import {
DeleteOutlined,
EditOutlined,
SettingOutlined,
} from '@ant-design/icons';
import { import {
Avatar, Avatar,
Button, Button,
@ -12,12 +16,12 @@ import {
Space, Space,
Spin, Spin,
Tag, Tag,
Tooltip, // Tooltip,
Typography, Typography,
} from 'antd'; } from 'antd';
import { MenuItemProps } from 'antd/lib/menu/MenuItem'; import { MenuItemProps } from 'antd/lib/menu/MenuItem';
import classNames from 'classnames'; import classNames from 'classnames';
import { useCallback, useState } from 'react'; import { useCallback, useEffect, useState } from 'react';
import ChatConfigurationModal from './chat-configuration-modal'; import ChatConfigurationModal from './chat-configuration-modal';
import ChatContainer from './chat-container'; import ChatContainer from './chat-container';
import { import {
@ -43,9 +47,9 @@ import {
import { useTranslate } from '@/hooks/common-hooks'; import { useTranslate } from '@/hooks/common-hooks';
import { useSetSelectedRecord } from '@/hooks/logic-hooks'; import { useSetSelectedRecord } from '@/hooks/logic-hooks';
import { IDialog } from '@/interfaces/database/chat'; import { IDialog } from '@/interfaces/database/chat';
import { Modal, Slider } from 'antd';
import { PictureInPicture2 } from 'lucide-react'; import { PictureInPicture2 } from 'lucide-react';
import styles from './index.less'; import styles from './index.less';
const { Text } = Typography; const { Text } = Typography;
const Chat = () => { const Chat = () => {
@ -161,6 +165,17 @@ const Chat = () => {
addTemporaryConversation(); addTemporaryConversation();
}, [addTemporaryConversation]); }, [addTemporaryConversation]);
const [fontSizeModalVisible, setFontSizeModalVisible] = useState(false);
const [fontSize, setFontSize] = useState(16); // default font size
// Load the saved font size setting from localStorage
useEffect(() => {
const savedFontSize = localStorage.getItem('chatFontSize');
if (savedFontSize) {
setFontSize(parseInt(savedFontSize));
}
}, []);
const buildAppItems = (dialog: IDialog) => { const buildAppItems = (dialog: IDialog) => {
const dialogId = dialog.id; const dialogId = dialog.id;
@ -286,6 +301,7 @@ const Chat = () => {
</Flex> </Flex>
</Flex> </Flex>
<Divider type={'vertical'} className={styles.divider}></Divider> <Divider type={'vertical'} className={styles.divider}></Divider>
<Flex className={styles.chatTitleWrapper}> <Flex className={styles.chatTitleWrapper}>
<Flex flex={1} vertical> <Flex flex={1} vertical>
<Flex <Flex
@ -297,15 +313,23 @@ const Chat = () => {
<b>{t('chat')}</b> <b>{t('chat')}</b>
<Tag>{conversationList.length}</Tag> <Tag>{conversationList.length}</Tag>
</Space> </Space>
<Tooltip title={t('newChat')}> {/* <Tooltip title={t('newChat')}> */}
<div> <div>
<SvgIcon <SettingOutlined
name="plus-circle-fill" style={{
width={20} marginRight: '8px',
onClick={handleCreateTemporaryConversation} fontSize: '20px',
></SvgIcon> cursor: 'pointer',
</div> }}
</Tooltip> onClick={() => setFontSizeModalVisible(true)}
/>
<SvgIcon
name="plus-circle-fill"
width={20}
onClick={handleCreateTemporaryConversation}
></SvgIcon>
</div>
{/* </Tooltip> */}
</Flex> </Flex>
<Divider></Divider> <Divider></Divider>
<Flex vertical gap={10} className={styles.chatTitleContent}> <Flex vertical gap={10} className={styles.chatTitleContent}>
@ -356,7 +380,10 @@ const Chat = () => {
</Flex> </Flex>
</Flex> </Flex>
<Divider type={'vertical'} className={styles.divider}></Divider> <Divider type={'vertical'} className={styles.divider}></Divider>
<ChatContainer controller={controller}></ChatContainer> <ChatContainer
controller={controller}
fontSize={fontSize}
></ChatContainer>
{dialogEditVisible && ( {dialogEditVisible && (
<ChatConfigurationModal <ChatConfigurationModal
visible={dialogEditVisible} visible={dialogEditVisible}
@ -386,6 +413,27 @@ const Chat = () => {
isAgent={false} isAgent={false}
></EmbedModal> ></EmbedModal>
)} )}
{fontSizeModalVisible && (
<Modal
title={'设置字体大小'}
open={fontSizeModalVisible}
onCancel={() => setFontSizeModalVisible(false)}
footer={null}
>
<Flex vertical gap="middle" align="center">
{'当前字体大小'}: {fontSize}px
<Slider
min={12}
max={24}
step={1}
defaultValue={fontSize}
style={{ width: '80%' }}
onChange={(value) => {
  setFontSize(value);
  localStorage.setItem('chatFontSize', String(value));
}}
/>
</Flex>
</Modal>
)}
</Flex> </Flex>
); );
}; };