Merge pull request #100 from zstar1003/dev
Commit fd7f1140cd
@ -37,7 +37,8 @@ from api.db.services.file_service import FileService
|
|||
|
||||
from flask import jsonify, request, Response
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions', methods=['POST']) # noqa: F821
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id, chat_id):
|
||||
req = request.json
|
||||
|
@ -50,7 +51,7 @@ def create(tenant_id, chat_id):
|
|||
"dialog_id": req["dialog_id"],
|
||||
"name": req.get("name", "New session"),
|
||||
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
|
||||
"user_id": req.get("user_id", "")
|
||||
"user_id": req.get("user_id", ""),
|
||||
}
|
||||
if not conv.get("name"):
|
||||
return get_error_data_result(message="`name` can not be empty.")
|
||||
|
@ -59,20 +60,20 @@ def create(tenant_id, chat_id):
|
|||
if not e:
|
||||
return get_error_data_result(message="Fail to create a session!")
|
||||
conv = conv.to_dict()
|
||||
conv['messages'] = conv.pop("message")
|
||||
conv["messages"] = conv.pop("message")
|
||||
conv["chat_id"] = conv.pop("dialog_id")
|
||||
del conv["reference"]
|
||||
return get_result(data=conv)
|
||||
|
||||
|
||||
@manager.route('/agents/<agent_id>/sessions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create_agent_session(tenant_id, agent_id):
|
||||
req = request.json
|
||||
if not request.is_json:
|
||||
req = request.form
|
||||
files = request.files
|
||||
user_id = request.args.get('user_id', '')
|
||||
user_id = request.args.get("user_id", "")
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not e:
|
||||
|
@ -113,7 +114,7 @@ def create_agent_session(tenant_id, agent_id):
|
|||
ele.pop("value")
|
||||
else:
|
||||
if req is not None and req.get(ele["key"]):
|
||||
ele["value"] = req[ele['key']]
|
||||
ele["value"] = req[ele["key"]]
|
||||
else:
|
||||
if "value" in ele:
|
||||
ele.pop("value")
|
||||
|
@ -121,20 +122,13 @@ def create_agent_session(tenant_id, agent_id):
|
|||
for ans in canvas.run(stream=False):
|
||||
pass
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": user_id,
|
||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||
"source": "agent",
|
||||
"dsl": cvs.dsl
|
||||
}
|
||||
conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
|
||||
API4ConversationService.save(**conv)
|
||||
conv["agent_id"] = conv.pop("dialog_id")
|
||||
return get_result(data=conv)
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT']) # noqa: F821
|
||||
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, chat_id, session_id):
|
||||
req = request.json
|
||||
|
@ -156,14 +150,14 @@ def update(tenant_id, chat_id, session_id):
|
|||
return get_result()
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/completions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def chat_completion(tenant_id, chat_id):
|
||||
req = request.json
|
||||
if not req:
|
||||
req = {"question": ""}
|
||||
if not req.get("session_id"):
|
||||
req["question"]=""
|
||||
req["question"] = ""
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(f"You don't own the chat {chat_id}")
|
||||
if req.get("session_id"):
|
||||
|
@ -185,7 +179,7 @@ def chat_completion(tenant_id, chat_id):
|
|||
return get_result(data=answer)
|
||||
|
||||
|
||||
@manager.route('chats_openai/<chat_id>/chat/completions', methods=['POST']) # noqa: F821
|
||||
@manager.route("chats_openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
def chat_completion_openai_like(tenant_id, chat_id):
|
||||
|
@ -259,39 +253,23 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
|||
# The choices field on the last chunk will always be an empty array [].
|
||||
def streamed_response_generator(chat_id, dia, msg):
|
||||
token_used = 0
|
||||
should_split_index = 0
|
||||
answer_cache = ""
|
||||
response = {
|
||||
"id": f"chatcmpl-{chat_id}",
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": "",
|
||||
"role": "assistant",
|
||||
"function_call": None,
|
||||
"tool_calls": None
|
||||
},
|
||||
"finish_reason": None,
|
||||
"index": 0,
|
||||
"logprobs": None
|
||||
}
|
||||
],
|
||||
"choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None}, "finish_reason": None, "index": 0, "logprobs": None}],
|
||||
"created": int(time.time()),
|
||||
"model": "model",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "",
|
||||
"usage": None
|
||||
"usage": None,
|
||||
}
|
||||
|
||||
try:
|
||||
for ans in chat(dia, msg, True):
|
||||
answer = ans["answer"]
|
||||
incremental = answer[should_split_index:]
|
||||
incremental = answer.replace(answer_cache, "", 1)
|
||||
answer_cache = answer.rstrip("</think>")
|
||||
token_used += len(incremental)
|
||||
if incremental.endswith("</think>"):
|
||||
response_data_len = len(incremental.rstrip("</think>"))
|
||||
else:
|
||||
response_data_len = len(incremental)
|
||||
should_split_index += response_data_len
|
||||
response["choices"][0]["delta"]["content"] = incremental
|
||||
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
|
@ -301,15 +279,10 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
|||
# The last chunk
|
||||
response["choices"][0]["delta"]["content"] = None
|
||||
response["choices"][0]["finish_reason"] = "stop"
|
||||
response["usage"] = {
|
||||
"prompt_tokens": len(prompt),
|
||||
"completion_tokens": token_used,
|
||||
"total_tokens": len(prompt) + token_used
|
||||
}
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
|
||||
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
|
||||
yield "data:[DONE]\n\n"
|
||||
|
||||
|
||||
resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
|
@ -324,7 +297,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
|||
break
|
||||
content = answer["answer"]
|
||||
|
||||
response = {
|
||||
response = {
|
||||
"id": f"chatcmpl-{chat_id}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
|
@ -336,25 +309,15 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
|||
"completion_tokens_details": {
|
||||
"reasoning_tokens": context_token_used,
|
||||
"accepted_prediction_tokens": len(content),
|
||||
"rejected_prediction_tokens": 0 # 0 for simplicity
|
||||
}
|
||||
"rejected_prediction_tokens": 0, # 0 for simplicity
|
||||
},
|
||||
},
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
"index": 0
|
||||
}
|
||||
]
|
||||
"choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
|
||||
}
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
@manager.route('/agents/<agent_id>/completions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def agent_completions(tenant_id, agent_id):
|
||||
req = request.json
|
||||
|
@ -365,8 +328,8 @@ def agent_completions(tenant_id, agent_id):
|
|||
dsl = cvs[0].dsl
|
||||
if not isinstance(dsl, str):
|
||||
dsl = json.dumps(dsl)
|
||||
#canvas = Canvas(dsl, tenant_id)
|
||||
#if canvas.get_preset_param():
|
||||
# canvas = Canvas(dsl, tenant_id)
|
||||
# if canvas.get_preset_param():
|
||||
# req["question"] = ""
|
||||
conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
|
||||
if not conv:
|
||||
|
@ -380,9 +343,7 @@ def agent_completions(tenant_id, agent_id):
|
|||
states = {field: current_dsl.get(field, []) for field in state_fields}
|
||||
current_dsl.update(new_dsl)
|
||||
current_dsl.update(states)
|
||||
API4ConversationService.update_by_id(req["session_id"], {
|
||||
"dsl": current_dsl
|
||||
})
|
||||
API4ConversationService.update_by_id(req["session_id"], {"dsl": current_dsl})
|
||||
else:
|
||||
req["question"] = ""
|
||||
if req.get("stream", True):
|
||||
|
@ -399,7 +360,7 @@ def agent_completions(tenant_id, agent_id):
|
|||
return get_error_data_result(str(e))
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions', methods=['GET']) # noqa: F821
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_session(tenant_id, chat_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
|
@ -418,7 +379,7 @@ def list_session(tenant_id, chat_id):
|
|||
if not convs:
|
||||
return get_result(data=[])
|
||||
for conv in convs:
|
||||
conv['messages'] = conv.pop("message")
|
||||
conv["messages"] = conv.pop("message")
|
||||
infos = conv["messages"]
|
||||
for info in infos:
|
||||
if "prompt" in info:
|
||||
|
@ -452,7 +413,7 @@ def list_session(tenant_id, chat_id):
|
|||
return get_result(data=convs)
|
||||
|
||||
|
||||
@manager.route('/agents/<agent_id>/sessions', methods=['GET']) # noqa: F821
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_agent_session(tenant_id, agent_id):
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
|
@ -468,12 +429,11 @@ def list_agent_session(tenant_id, agent_id):
|
|||
desc = True
|
||||
# dsl defaults to True in all cases except for False and false
|
||||
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
||||
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
|
||||
user_id, include_dsl)
|
||||
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
|
||||
if not convs:
|
||||
return get_result(data=[])
|
||||
for conv in convs:
|
||||
conv['messages'] = conv.pop("message")
|
||||
conv["messages"] = conv.pop("message")
|
||||
infos = conv["messages"]
|
||||
for info in infos:
|
||||
if "prompt" in info:
|
||||
|
@ -506,7 +466,7 @@ def list_agent_session(tenant_id, agent_id):
|
|||
return get_result(data=convs)
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"]) # noqa: F821
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id, chat_id):
|
||||
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
|
@ -532,14 +492,14 @@ def delete(tenant_id, chat_id):
|
|||
return get_result()
|
||||
|
||||
|
||||
@manager.route('/agents/<agent_id>/sessions', methods=["DELETE"]) # noqa: F821
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete_agent_session(tenant_id, agent_id):
|
||||
req = request.json
|
||||
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
|
||||
if not cvs:
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
|
||||
|
||||
convs = API4ConversationService.query(dialog_id=agent_id)
|
||||
if not convs:
|
||||
return get_error_data_result(f"Agent {agent_id} has no sessions")
|
||||
|
@ -555,16 +515,16 @@ def delete_agent_session(tenant_id, agent_id):
|
|||
conv_list.append(conv.id)
|
||||
else:
|
||||
conv_list = ids
|
||||
|
||||
|
||||
for session_id in conv_list:
|
||||
conv = API4ConversationService.query(id=session_id, dialog_id=agent_id)
|
||||
if not conv:
|
||||
return get_error_data_result(f"The agent doesn't own the session ${session_id}")
|
||||
API4ConversationService.delete_by_id(session_id)
|
||||
return get_result()
|
||||
|
||||
|
||||
@manager.route('/sessions/ask', methods=['POST']) # noqa: F821
|
||||
|
||||
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def ask_about(tenant_id):
|
||||
req = request.json
|
||||
|
@ -590,9 +550,7 @@ def ask_about(tenant_id):
|
|||
for ans in ask(req["question"], req["kb_ids"], uid):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
|
@ -603,7 +561,7 @@ def ask_about(tenant_id):
|
|||
return resp
|
||||
|
||||
|
||||
@manager.route('/sessions/related_questions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def related_questions(tenant_id):
|
||||
req = request.json
|
||||
|
@ -635,18 +593,27 @@ Reason:
|
|||
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
||||
|
||||
"""
|
||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
|
||||
ans = chat_mdl.chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Keywords: {question}
|
||||
Related search terms:
|
||||
"""}], {"temperature": 0.9})
|
||||
""",
|
||||
}
|
||||
],
|
||||
{"temperature": 0.9},
|
||||
)
|
||||
return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
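The return line above keeps only the lines of the model's reply that look like a numbered list and strips the `N. ` prefix. A tiny sketch with a made-up reply (the real `ans` comes from `chat_mdl.chat(...)`):

```python
import re

# Hypothetical model reply; in the endpoint above `ans` is the LLM output.
ans = "Sure, here are ideas:\n1. vector search\n2. hybrid retrieval\n3. rerank models\nHope this helps!"

related = [re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]
print(related)  # ['vector search', 'hybrid retrieval', 'rerank models']
```

Note that the pattern only matches a single leading digit, which is enough for the handful of suggestions requested here.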
|
||||
|
||||
|
||||
@manager.route('/chatbots/<dialog_id>/completions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
|
||||
def chatbot_completions(dialog_id):
|
||||
req = request.json
|
||||
|
||||
token = request.headers.get('Authorization').split()
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
token = token[1]
|
||||
|
@ -669,11 +636,11 @@ def chatbot_completions(dialog_id):
|
|||
return get_result(data=answer)
|
||||
|
||||
|
||||
@manager.route('/agentbots/<agent_id>/completions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
def agent_bot_completions(agent_id):
|
||||
req = request.json
|
||||
|
||||
token = request.headers.get('Authorization').split()
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
token = token[1]
|
||||
|
|
|
@ -47,20 +47,14 @@ class DocumentService(CommonService):
|
|||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords, id, name):
|
||||
def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name):
|
||||
docs = cls.model.select().where(cls.model.kb_id == kb_id)
|
||||
if id:
|
||||
docs = docs.where(
|
||||
cls.model.id == id)
|
||||
docs = docs.where(cls.model.id == id)
|
||||
if name:
|
||||
docs = docs.where(
|
||||
cls.model.name == name
|
||||
)
|
||||
docs = docs.where(cls.model.name == name)
|
||||
if keywords:
|
||||
docs = docs.where(
|
||||
fn.LOWER(cls.model.name).contains(keywords.lower())
|
||||
)
|
||||
docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
if desc:
|
||||
docs = docs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
|
@ -72,13 +66,9 @@ class DocumentService(CommonService):
|
|||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords):
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords):
|
||||
if keywords:
|
||||
docs = cls.model.select().where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
|
||||
else:
|
||||
docs = cls.model.select().where(cls.model.kb_id == kb_id)
|
||||
count = docs.count()
|
||||
|
@ -97,8 +87,7 @@ class DocumentService(CommonService):
|
|||
if not cls.save(**doc):
|
||||
raise RuntimeError("Database error (Document)!")
|
||||
e, kb = KnowledgebaseService.get_by_id(doc["kb_id"])
|
||||
if not KnowledgebaseService.update_by_id(
|
||||
kb.id, {"doc_num": kb.doc_num + 1}):
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"doc_num": kb.doc_num + 1}):
|
||||
raise RuntimeError("Database error (Knowledgebase)!")
|
||||
return Document(**doc)
|
||||
|
||||
|
@ -108,14 +97,16 @@ class DocumentService(CommonService):
|
|||
cls.clear_chunk_num(doc.id)
|
||||
try:
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
|
||||
{"remove": {"source_id": doc.id}},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
|
||||
{"removed_kwd": "Y"},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.update(
|
||||
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
|
||||
{"remove": {"source_id": doc.id}},
|
||||
search.index_name(tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.delete(
|
||||
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}}, search.index_name(tenant_id), doc.kb_id
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return cls.delete_by_id(doc.id)
|
||||
|
@ -136,67 +127,54 @@ class DocumentService(CommonService):
|
|||
Tenant.embd_id,
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
cls.model.update_time]
|
||||
docs = cls.model.select(*fields) \
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
|
||||
cls.model.update_time,
|
||||
]
|
||||
docs = (
|
||||
cls.model.select(*fields)
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress == 0,
|
||||
cls.model.update_time >= current_timestamp() - 1000 * 600,
|
||||
cls.model.run == TaskStatus.RUNNING.value) \
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress == 0,
|
||||
cls.model.update_time >= current_timestamp() - 1000 * 600,
|
||||
cls.model.run == TaskStatus.RUNNING.value,
|
||||
)
|
||||
.order_by(cls.model.update_time.asc())
|
||||
)
|
||||
return list(docs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_unfinished_docs(cls):
|
||||
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
|
||||
cls.model.run, cls.model.parser_id]
|
||||
docs = cls.model.select(*fields) \
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress < 1,
|
||||
cls.model.progress > 0)
|
||||
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id]
|
||||
docs = cls.model.select(*fields).where(cls.model.status == StatusEnum.VALID.value, ~(cls.model.type == FileType.VIRTUAL.value), cls.model.progress < 1, cls.model.progress > 0)
|
||||
return list(docs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
|
||||
num = cls.model.update(token_num=cls.model.token_num + token_num,
|
||||
chunk_num=cls.model.chunk_num + chunk_num,
|
||||
process_duation=cls.model.process_duation + duation).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
num = (
|
||||
cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duation=cls.model.process_duation + duation)
|
||||
.where(cls.model.id == doc_id)
|
||||
.execute()
|
||||
)
|
||||
if num == 0:
|
||||
raise LookupError(
|
||||
"Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num +
|
||||
token_num,
|
||||
chunk_num=Knowledgebase.chunk_num +
|
||||
chunk_num).where(
|
||||
Knowledgebase.id == kb_id).execute()
|
||||
raise LookupError("Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
|
||||
num = cls.model.update(token_num=cls.model.token_num - token_num,
|
||||
chunk_num=cls.model.chunk_num - chunk_num,
|
||||
process_duation=cls.model.process_duation + duation).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
num = (
|
||||
cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duation=cls.model.process_duation + duation)
|
||||
.where(cls.model.id == doc_id)
|
||||
.execute()
|
||||
)
|
||||
if num == 0:
|
||||
raise LookupError(
|
||||
"Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num -
|
||||
token_num,
|
||||
chunk_num=Knowledgebase.chunk_num -
|
||||
chunk_num
|
||||
).where(
|
||||
Knowledgebase.id == kb_id).execute()
|
||||
raise LookupError("Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
|
@ -205,24 +183,17 @@ class DocumentService(CommonService):
|
|||
doc = cls.model.get_by_id(doc_id)
|
||||
assert doc, "Can't find document in database."
|
||||
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num -
|
||||
doc.token_num,
|
||||
chunk_num=Knowledgebase.chunk_num -
|
||||
doc.chunk_num,
|
||||
doc_num=Knowledgebase.doc_num - 1
|
||||
).where(
|
||||
Knowledgebase.id == doc.kb_id).execute()
|
||||
num = (
|
||||
Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1)
|
||||
.where(Knowledgebase.id == doc.kb_id)
|
||||
.execute()
|
||||
)
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id(cls, doc_id):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.tenant_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
|
@ -240,11 +211,7 @@ class DocumentService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id_by_name(cls, name):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.tenant_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
|
@ -253,12 +220,13 @@ class DocumentService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible(cls, doc_id, user_id):
|
||||
docs = cls.model.select(
|
||||
cls.model.id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)
|
||||
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
||||
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
|
||||
docs = (
|
||||
cls.model.select(cls.model.id)
|
||||
.join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
|
||||
.join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id))
|
||||
.where(cls.model.id == doc_id, UserTenant.user_id == user_id)
|
||||
.paginate(0, 1)
|
||||
)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return False
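The refactor above settles on peewee's fluent chain (select → join → where → paginate → dicts) instead of the old line-wrapped call style. An illustrative sketch of the same shape with a toy schema; the models and fields below are invented, not the project's:

```python
# Toy peewee schema to show the select/join/where/paginate/dicts chain style.
from peewee import CharField, Model, SqliteDatabase

db = SqliteDatabase(":memory:")


class Knowledgebase(Model):
    tenant_id = CharField()

    class Meta:
        database = db


class Document(Model):
    kb_id = CharField()
    name = CharField()

    class Meta:
        database = db


def accessible_like(doc_id, tenant_id):
    query = (
        Document.select(Document.id)
        .join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
        .where(Document.id == doc_id, Knowledgebase.tenant_id == tenant_id)
        .paginate(0, 1)  # mirror the original: first page, a single row
    )
    return bool(query.dicts())


db.create_tables([Knowledgebase, Document])
print(accessible_like("some-doc-id", "tenant-1"))  # False on an empty database
```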
|
||||
|
@ -267,11 +235,7 @@ class DocumentService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible4deletion(cls, doc_id, user_id):
|
||||
docs = cls.model.select(
|
||||
cls.model.id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)
|
||||
).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
|
||||
docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return False
|
||||
|
@ -280,11 +244,7 @@ class DocumentService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_embd_id(cls, doc_id):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.embd_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
|
@ -306,9 +266,9 @@ class DocumentService(CommonService):
|
|||
Tenant.asr_id,
|
||||
Tenant.llm_id,
|
||||
)
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == doc_id)
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == doc_id)
|
||||
)
|
||||
configs = configs.dicts()
|
||||
if not configs:
|
||||
|
@ -319,8 +279,7 @@ class DocumentService(CommonService):
|
|||
@DB.connection_context()
|
||||
def get_doc_id_by_doc_name(cls, doc_name):
|
||||
fields = [cls.model.id]
|
||||
doc_id = cls.model.select(*fields) \
|
||||
.where(cls.model.name == doc_name)
|
||||
doc_id = cls.model.select(*fields).where(cls.model.name == doc_name)
|
||||
doc_id = doc_id.dicts()
|
||||
if not doc_id:
|
||||
return
|
||||
|
@ -330,8 +289,7 @@ class DocumentService(CommonService):
|
|||
@DB.connection_context()
|
||||
def get_thumbnails(cls, docids):
|
||||
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
|
||||
return list(cls.model.select(
|
||||
*fields).where(cls.model.id.in_(docids)).dicts())
|
||||
return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
|
@ -359,19 +317,14 @@ class DocumentService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_doc_count(cls, tenant_id):
|
||||
docs = cls.model.select(cls.model.id).join(Knowledgebase,
|
||||
on=(Knowledgebase.id == cls.model.kb_id)).where(
|
||||
Knowledgebase.tenant_id == tenant_id)
|
||||
docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id)
|
||||
return len(docs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def begin2parse(cls, docid):
|
||||
cls.update_by_id(
|
||||
docid, {"progress": random.random() * 1 / 100.,
|
||||
"progress_msg": "Task is queued...",
|
||||
"process_begin_at": get_format_time()
|
||||
})
|
||||
cls.update_by_id(docid, {"progress": random.random() * 1 / 100.0, "progress_msg": "Task is queued...", "process_begin_at": get_format_time()})
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_meta_fields(cls, doc_id, meta_fields):
|
||||
|
@ -420,11 +373,7 @@ class DocumentService(CommonService):
|
|||
status = TaskStatus.DONE.value
|
||||
|
||||
msg = "\n".join(sorted(msg))
|
||||
info = {
|
||||
"process_duation": datetime.timestamp(
|
||||
datetime.now()) -
|
||||
d["process_begin_at"].timestamp(),
|
||||
"run": status}
|
||||
info = {"process_duation": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(), "run": status}
|
||||
if prg != 0:
|
||||
info["progress"] = prg
|
||||
if msg:
|
||||
|
@ -437,8 +386,7 @@ class DocumentService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_doc_count(cls, kb_id):
|
||||
return len(cls.model.select(cls.model.id).where(
|
||||
cls.model.kb_id == kb_id).dicts())
|
||||
return len(cls.model.select(cls.model.id).where(cls.model.kb_id == kb_id).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
|
@ -459,14 +407,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
|
|||
|
||||
def new_task():
|
||||
nonlocal doc
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"],
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"task_type": ty,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
|
||||
}
|
||||
return {"id": get_uuid(), "doc_id": doc["id"], "from_page": 100000000, "to_page": 100000000, "task_type": ty, "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty}
|
||||
|
||||
task = new_task()
|
||||
for field in ["doc_id", "from_page", "to_page"]:
|
||||
|
@ -478,6 +419,25 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
|
|||
|
||||
|
||||
def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
"""
|
||||
Upload and parse documents, then store their content in the knowledge base.

Args:
conversation_id: conversation ID
file_objs: list of file objects
user_id: user ID

Returns:
list of IDs of the successfully processed documents

Workflow:
1. Validate the conversation and its knowledge base
2. Initialize the embedding model
3. Upload the files to storage
4. Parse the file contents with a thread pool
5. Generate embedding vectors for the content
6. Write the results into the document store
|
||||
"""
|
||||
from rag.app import presentation, picture, naive, audio, email
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.file_service import FileService
|
||||
|
@ -493,8 +453,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not dia.kb_ids:
|
||||
raise LookupError("No knowledge base associated with this conversation. "
|
||||
"Please add a knowledge base before uploading documents")
|
||||
raise LookupError("No knowledge base associated with this conversation. Please add a knowledge base before uploading documents")
|
||||
kb_id = dia.kb_ids[0]
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
|
@ -508,12 +467,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
FACTORY = {
|
||||
ParserType.PRESENTATION.value: presentation,
|
||||
ParserType.PICTURE.value: picture,
|
||||
ParserType.AUDIO.value: audio,
|
||||
ParserType.EMAIL.value: email
|
||||
}
|
||||
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
|
||||
# Run the parsing tasks with a thread pool
|
||||
exe = ThreadPoolExecutor(max_workers=12)
|
||||
|
@ -522,22 +476,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
for d, blob in files:
|
||||
doc_nm[d["id"]] = d["name"]
|
||||
for d, blob in files:
|
||||
kwargs = {
|
||||
"callback": dummy,
|
||||
"parser_config": parser_config,
|
||||
"from_page": 0,
|
||||
"to_page": 100000,
|
||||
"tenant_id": kb.tenant_id,
|
||||
"lang": kb.language
|
||||
}
|
||||
kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language}
|
||||
threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
|
||||
|
||||
for (docinfo, _), th in zip(files, threads):
|
||||
docs = []
|
||||
doc = {
|
||||
"doc_id": docinfo["id"],
|
||||
"kb_id": [kb.id]
|
||||
}
|
||||
doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]}
|
||||
for ck in th.result():
|
||||
d = deepcopy(doc)
|
||||
d.update(ck)
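The loop above picks a parser from the `FACTORY` dict by `parser_id` (falling back to `naive`) and fans the work out over a `ThreadPoolExecutor`, later collecting each future with `th.result()`. A toy, self-contained version of the pattern; the parser functions and keys below are invented for illustration:

```python
from concurrent.futures import ThreadPoolExecutor


def naive_chunk(name, blob, **kwargs):
    # Default parser: treat the blob as plain text.
    return [{"source": name, "content": blob.decode(errors="ignore")}]


def audio_chunk(name, blob, **kwargs):
    # Pretend parser for audio files.
    return [{"source": name, "content": "<transcript placeholder>"}]


FACTORY = {"audio": audio_chunk}  # unknown parser_id falls back to naive_chunk

files = [
    ({"id": "1", "name": "a.txt", "parser_id": "naive"}, b"hello world"),
    ({"id": "2", "name": "b.wav", "parser_id": "audio"}, b"\x00\x01"),
]

with ThreadPoolExecutor(max_workers=4) as exe:
    threads = [
        exe.submit(FACTORY.get(d["parser_id"], naive_chunk), d["name"], blob)
        for d, blob in files
    ]
    for (d, _), th in zip(files, threads):
        chunks = th.result()  # re-raises any exception from the worker thread
        print(d["id"], len(chunks))
```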
|
||||
|
@ -552,7 +496,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
if isinstance(d["image"], bytes):
|
||||
output_buffer = BytesIO(d["image"])
|
||||
else:
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
d["image"].save(output_buffer, format="JPEG")
|
||||
|
||||
STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(kb.id, d["id"])
|
||||
|
@ -569,9 +513,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
nonlocal embd_mdl, chunk_counts, token_counts
|
||||
vects = []
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
|
||||
vts, c = embd_mdl.encode(cnts[i : i + batch_size])
|
||||
vects.extend(vts.tolist())
|
||||
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
|
||||
chunk_counts[doc_id] += len(cnts[i : i + batch_size])
|
||||
token_counts[doc_id] += c
|
||||
return vects
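The embedding helper above walks chunk texts in fixed-size batches, accumulating vectors and token counts as it goes. A small sketch of the same batching loop with a stand-in encoder; the real `embd_mdl.encode()` returns a vector array plus a token count for the batch:

```python
import numpy as np


class FakeEmbedder:
    # Stand-in for the real embedding model.
    def encode(self, texts):
        vects = np.random.rand(len(texts), 8)        # pretend 8-dim embeddings
        tokens = sum(len(t.split()) for t in texts)  # pretend token count
        return vects, tokens


def embed_in_batches(cnts, embd_mdl, batch_size=16):
    vects, token_count = [], 0
    for i in range(0, len(cnts), batch_size):
        vts, c = embd_mdl.encode(cnts[i : i + batch_size])
        vects.extend(vts.tolist())
        token_count += c
    return vects, token_count


vecs, used = embed_in_batches([f"chunk {n}" for n in range(40)], FakeEmbedder())
print(len(vecs), used)  # 40 vectors and the total token count
```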
|
||||
|
||||
|
@ -585,22 +529,25 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
|
||||
if parser_ids[doc_id] != ParserType.PICTURE.value:
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
|
||||
mindmap = MindMapExtractor(llm_bdl)
|
||||
try:
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
|
||||
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
||||
if len(mind_map) < 32:
|
||||
raise Exception("Few content: " + mind_map)
|
||||
cks.append({
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc_id,
|
||||
"kb_id": [kb.id],
|
||||
"docnm_kwd": doc_nm[doc_id],
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
|
||||
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
|
||||
"content_with_weight": mind_map,
|
||||
"knowledge_graph_kwd": "mind_map"
|
||||
})
|
||||
cks.append(
|
||||
{
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc_id,
|
||||
"kb_id": [kb.id],
|
||||
"docnm_kwd": doc_nm[doc_id],
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
|
||||
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
|
||||
"content_with_weight": mind_map,
|
||||
"knowledge_graph_kwd": "mind_map",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception("Mind map generation error")
|
||||
|
||||
|
@ -614,9 +561,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
|
||||
return [d["id"] for d, _ in files]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from flask import jsonify, request
|
||||
from services.conversation.service import get_conversations_by_user_id, get_messages_by_conversation_id, get_conversation_detail
|
||||
from services.conversation.service import get_conversations_by_user_id, get_messages_by_conversation_id
|
||||
from .. import conversation_bp
|
||||
|
||||
|
||||
|
@ -44,20 +44,3 @@ def get_messages(conversation_id):
|
|||
except Exception as e:
|
||||
# 错误处理
|
||||
return jsonify({"code": 500, "message": f"获取消息列表失败: {str(e)}"}), 500
|
||||
|
||||
|
||||
@conversation_bp.route("/<conversation_id>", methods=["GET"])
|
||||
def get_conversation(conversation_id):
|
||||
"""获取特定对话的详细信息"""
|
||||
try:
|
||||
# 调用服务函数获取对话详情
|
||||
conversation = get_conversation_detail(conversation_id)
|
||||
|
||||
if not conversation:
|
||||
return jsonify({"code": 404, "message": "对话不存在"}), 404
|
||||
|
||||
# 返回符合前端期望格式的数据
|
||||
return jsonify({"code": 0, "data": conversation, "message": "获取对话详情成功"})
|
||||
except Exception as e:
|
||||
# 错误处理
|
||||
return jsonify({"code": 500, "message": f"获取对话详情失败: {str(e)}"}), 500
|
||||
|
|
|
@ -23,8 +23,6 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time"
|
|||
# 直接使用user_id作为tenant_id
|
||||
tenant_id = user_id
|
||||
|
||||
print(f"查询用户ID: {user_id}, 租户ID: {tenant_id}")
|
||||
|
||||
# 查询总记录数
|
||||
count_sql = """
|
||||
SELECT COUNT(*) as total
|
||||
|
@ -34,7 +32,7 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time"
|
|||
cursor.execute(count_sql, (tenant_id,))
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
print(f"查询到总记录数: {total}")
|
||||
# print(f"查询到总记录数: {total}")
|
||||
|
||||
# 计算分页偏移量
|
||||
offset = (page - 1) * size
|
||||
|
@ -59,8 +57,8 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time"
|
|||
LIMIT %s OFFSET %s
|
||||
"""
|
||||
|
||||
print(f"执行查询: {query}")
|
||||
print(f"参数: tenant_id={tenant_id}, size={size}, offset={offset}")
|
||||
# print(f"执行查询: {query}")
|
||||
# print(f"参数: tenant_id={tenant_id}, size={size}, offset={offset}")
|
||||
|
||||
cursor.execute(query, (tenant_id, size, offset))
|
||||
results = cursor.fetchall()
|
||||
|
@ -200,68 +198,3 @@ def get_messages_by_conversation_id(conversation_id, page=1, size=30):
|
|||
|
||||
traceback.print_exc()
|
||||
return None, 0
|
||||
|
||||
|
||||
def get_conversation_detail(conversation_id):
|
||||
"""
|
||||
获取特定对话的详细信息
|
||||
|
||||
参数:
|
||||
conversation_id (str): 对话ID
|
||||
|
||||
返回:
|
||||
dict: 对话详情
|
||||
"""
|
||||
try:
|
||||
conn = mysql.connector.connect(**DB_CONFIG)
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询对话信息
|
||||
query = """
|
||||
SELECT c.*, d.name as dialog_name, d.icon as dialog_icon
|
||||
FROM conversation c
|
||||
LEFT JOIN dialog d ON c.dialog_id = d.id
|
||||
WHERE c.id = %s
|
||||
"""
|
||||
cursor.execute(query, (conversation_id,))
|
||||
result = cursor.fetchone()
|
||||
|
||||
if not result:
|
||||
print(f"未找到对话ID: {conversation_id}")
|
||||
return None
|
||||
|
||||
# 格式化对话详情
|
||||
conversation = {
|
||||
"id": result["id"],
|
||||
"name": result.get("name", ""),
|
||||
"dialogId": result.get("dialog_id", ""),
|
||||
"dialogName": result.get("dialog_name", ""),
|
||||
"dialogIcon": result.get("dialog_icon", ""),
|
||||
"createTime": result["create_date"].strftime("%Y-%m-%d %H:%M:%S") if result.get("create_date") else "",
|
||||
"updateTime": result["update_date"].strftime("%Y-%m-%d %H:%M:%S") if result.get("update_date") else "",
|
||||
"messages": result.get("message", []),
|
||||
}
|
||||
|
||||
# 打印调试信息
|
||||
print(f"获取到对话详情: ID={conversation_id}")
|
||||
print(f"消息数量: {len(conversation['messages']) if conversation['messages'] else 0}")
|
||||
|
||||
# 关闭连接
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
return conversation
|
||||
|
||||
except mysql.connector.Error as err:
|
||||
print(f"数据库错误: {err}")
|
||||
# 更详细的错误日志
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"未知错误: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
<script lang="ts" setup>
|
||||
import type { CreateOrUpdateTableRequestData, TableData } from "@@/apis/tables/type"
|
||||
import type { FormInstance, FormRules } from "element-plus"
|
||||
import { createTableDataApi, deleteTableDataApi, getTableDataApi, resetPasswordApi, updateTableDataApi } from "@@/apis/tables"
|
||||
import { usePagination } from "@@/composables/usePagination"
|
||||
import { ChatDotRound, CirclePlus, Delete, Edit, Key, Refresh, RefreshRight, Search, User } from "@element-plus/icons-vue"
|
||||
import type { TableData } from "@@/apis/tables/type"
|
||||
import { getTableDataApi } from "@@/apis/tables"
|
||||
import { ChatDotRound, User } from "@element-plus/icons-vue"
|
||||
import axios from "axios"
|
||||
import { cloneDeep } from "lodash-es"
|
||||
|
||||
defineOptions({
|
||||
// 命名当前组件
|
||||
|
@ -79,10 +76,16 @@ function loadMoreUsers() {
|
|||
|
||||
/**
|
||||
* 监听用户列表滚动事件
|
||||
* @param event 滚动事件
|
||||
* @param event DOM滚动事件对象
|
||||
*/
|
||||
function handleUserListScroll(event) {
|
||||
const { scrollTop, scrollHeight, clientHeight } = event.target
|
||||
function handleUserListScroll(event: Event) {
|
||||
// 将 event.target 断言为 HTMLElement 并检查是否存在
|
||||
const target = event.target as HTMLElement
|
||||
if (!target) return
|
||||
|
||||
// 获取滚动相关属性
|
||||
const { scrollTop, scrollHeight, clientHeight } = target
|
||||
|
||||
// 当滚动到距离底部100px时,加载更多数据
|
||||
if (scrollHeight - scrollTop - clientHeight < 100 && userHasMore.value && !userLoading.value) {
|
||||
loadMoreUsers()
|
||||
|
@ -203,10 +206,16 @@ function loadMoreConversations() {
|
|||
|
||||
/**
|
||||
* 监听对话列表滚动事件
|
||||
* @param event 滚动事件
|
||||
* @param event DOM滚动事件对象
|
||||
*/
|
||||
function handleConversationListScroll(event) {
|
||||
const { scrollTop, scrollHeight, clientHeight } = event.target
|
||||
function handleConversationListScroll(event: Event) {
|
||||
// 将 event.target 断言为 HTMLElement 并检查是否存在
|
||||
const target = event.target as HTMLElement
|
||||
if (!target) return
|
||||
|
||||
// 获取滚动相关属性
|
||||
const { scrollTop, scrollHeight, clientHeight } = target
|
||||
|
||||
// 当滚动到距离底部100px时,加载更多数据
|
||||
if (scrollHeight - scrollTop - clientHeight < 100 && conversationHasMore.value && !conversationLoading.value) {
|
||||
loadMoreConversations()
|
||||
|
@ -247,7 +256,7 @@ function getMessagesByConversationId(conversationId: string, isLoadMore = false)
|
|||
const parsedMessages = JSON.parse(conversation.messages)
|
||||
|
||||
// 格式化消息数据
|
||||
processedMessages = parsedMessages.map((msg, index) => {
|
||||
processedMessages = parsedMessages.map((msg: { id: any, role: any, content: any, created_at: number }, index: any) => {
|
||||
return {
|
||||
id: msg.id || `msg-${index}`,
|
||||
conversation_id: conversationId,
|
||||
|
@ -268,7 +277,7 @@ function getMessagesByConversationId(conversationId: string, isLoadMore = false)
|
|||
if (isLoadMore) {
|
||||
// 防止重复加载:检查新消息是否已存在
|
||||
const existingIds = new Set(messageList.value.map(msg => msg.id))
|
||||
const uniqueNewMessages = processedMessages.filter(msg => !existingIds.has(msg.id))
|
||||
const uniqueNewMessages = processedMessages.filter((msg: { id: number }) => !existingIds.has(msg.id))
|
||||
|
||||
// 追加新的唯一消息
|
||||
messageList.value = [...messageList.value, ...uniqueNewMessages]
|
||||
|
@ -307,7 +316,7 @@ function getMessagesByConversationId(conversationId: string, isLoadMore = false)
|
|||
* @param content 消息内容
|
||||
* @returns 处理后的HTML内容
|
||||
*/
|
||||
function renderMessageContent(content) {
|
||||
function renderMessageContent(content: string) {
|
||||
if (!content) return ""
|
||||
|
||||
// 处理Markdown格式的图片
|
||||
|
@ -331,10 +340,16 @@ function loadMoreMessages() {
|
|||
|
||||
/**
|
||||
* 监听消息列表滚动事件
|
||||
* @param event 滚动事件
|
||||
* @param event DOM滚动事件对象
|
||||
*/
|
||||
function handleMessageListScroll(event) {
|
||||
const { scrollTop, scrollHeight, clientHeight } = event.target
|
||||
function handleMessageListScroll(event: Event) {
|
||||
// 将 event.target 断言为 HTMLElement 并检查是否存在
|
||||
const target = event.target as HTMLElement
|
||||
if (!target) return
|
||||
|
||||
// 获取滚动相关属性
|
||||
const { scrollTop, scrollHeight, clientHeight } = target
|
||||
|
||||
// 当滚动到距离底部100px时,加载更多数据(向下滚动加载更多)
|
||||
if (scrollHeight - scrollTop - clientHeight < 100 && messageHasMore.value && !messageLoading.value) {
|
||||
loadMoreMessages()
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
<script lang="ts" setup>
|
||||
import type { CreateOrUpdateTableRequestData, TableData } from "@@/apis/configs/type"
|
||||
import type { FormInstance, FormRules } from "element-plus"
|
||||
import { getTableDataApi, updateTableDataApi } from "@@/apis/configs"
|
||||
import type { TableData } from "@@/apis/configs/type"
|
||||
import type { FormInstance } from "element-plus"
|
||||
import { getTableDataApi } from "@@/apis/configs"
|
||||
import { usePagination } from "@@/composables/usePagination"
|
||||
import { CirclePlus, Delete, Refresh, RefreshRight, Search } from "@element-plus/icons-vue"
|
||||
import { cloneDeep } from "lodash-es"
|
||||
import { Refresh, Search } from "@element-plus/icons-vue"
|
||||
|
||||
defineOptions({
|
||||
// 命名当前组件
|
||||
|
@ -15,47 +14,41 @@ const loading = ref<boolean>(false)
|
|||
const { paginationData, handleCurrentChange, handleSizeChange } = usePagination()
|
||||
|
||||
// #region 增
|
||||
const DEFAULT_FORM_DATA: CreateOrUpdateTableRequestData = {
|
||||
id: undefined,
|
||||
username: "",
|
||||
chatModel: "",
|
||||
embeddingModel: ""
|
||||
}
|
||||
const dialogVisible = ref<boolean>(false)
|
||||
const formData = ref<CreateOrUpdateTableRequestData>(cloneDeep(DEFAULT_FORM_DATA))
|
||||
// const DEFAULT_FORM_DATA: CreateOrUpdateTableRequestData = {
|
||||
// id: undefined,
|
||||
// username: "",
|
||||
// chatModel: "",
|
||||
// embeddingModel: ""
|
||||
// }
|
||||
// const dialogVisible = ref<boolean>(false)
|
||||
// const formData = ref<CreateOrUpdateTableRequestData>(cloneDeep(DEFAULT_FORM_DATA))
|
||||
|
||||
// 删除响应
|
||||
function handleDelete() {
|
||||
ElMessage.success("如需删除租户配置,可直接删除负责人账号")
|
||||
ElMessage.success("如需删除用户配置信息,请直接在前台登录用户账号进行操作")
|
||||
}
|
||||
|
||||
// 改
|
||||
function handleUpdate(row: TableData) {
|
||||
dialogVisible.value = true
|
||||
formData.value = cloneDeep({
|
||||
id: row.id,
|
||||
username: row.username,
|
||||
chatModel: row.chatModel,
|
||||
embeddingModel: row.embeddingModel
|
||||
})
|
||||
// 修改
|
||||
function handleUpdate() {
|
||||
ElMessage.success("如需修改用户配置信息,请直接在前台登录用户账号进行操作")
|
||||
}
|
||||
|
||||
// 处理修改表单提交
|
||||
function submitForm() {
|
||||
loading.value = true
|
||||
updateTableDataApi(formData.value)
|
||||
.then(() => {
|
||||
ElMessage.success("修改成功")
|
||||
dialogVisible.value = false
|
||||
getTableData() // 刷新表格数据
|
||||
})
|
||||
.catch(() => {
|
||||
ElMessage.error("修改失败")
|
||||
})
|
||||
.finally(() => {
|
||||
loading.value = false
|
||||
})
|
||||
}
|
||||
// function submitForm() {
|
||||
// loading.value = true
|
||||
// updateTableDataApi(formData.value)
|
||||
// .then(() => {
|
||||
// ElMessage.success("修改成功")
|
||||
// dialogVisible.value = false
|
||||
// getTableData() // 刷新表格数据
|
||||
// })
|
||||
// .catch(() => {
|
||||
// ElMessage.error("修改失败")
|
||||
// })
|
||||
// .finally(() => {
|
||||
// loading.value = false
|
||||
// })
|
||||
// }
|
||||
|
||||
// 查
|
||||
const tableData = ref<TableData[]>([])
|
||||
|
@ -137,8 +130,8 @@ onActivated(() => {
|
|||
<el-table-column prop="embeddingModel" label="嵌入模型" align="center" />
|
||||
<el-table-column prop="updateTime" label="更新时间" align="center" />
|
||||
<el-table-column fixed="right" label="操作" width="150" align="center">
|
||||
<template #default="scope">
|
||||
<el-button type="primary" text bg size="small" @click="handleUpdate(scope.row)">
|
||||
<template #default="">
|
||||
<el-button type="primary" text bg size="small" @click="handleUpdate">
|
||||
修改
|
||||
</el-button>
|
||||
<el-button type="danger" text bg size="small" @click="handleDelete()">
|
||||
|
@ -163,7 +156,7 @@ onActivated(() => {
|
|||
</el-card>
|
||||
|
||||
<!-- 修改对话框 -->
|
||||
<el-dialog v-model="dialogVisible" title="修改配置" width="30%">
|
||||
<!-- <el-dialog v-model="dialogVisible" title="修改配置" width="30%">
|
||||
<el-form :model="formData" label-width="100px">
|
||||
<el-form-item label="用户名">
|
||||
<el-input v-model="formData.username" disabled />
|
||||
|
@ -183,7 +176,7 @@ onActivated(() => {
|
|||
确认
|
||||
</el-button>
|
||||
</template>
|
||||
</el-dialog>
|
||||
</el-dialog> -->
|
||||
</div>
|
||||
</template>
|
||||
|
||||
|
|
|
@ -32,16 +32,19 @@ def chunks_format(reference):
|
|||
def get_value(d, k1, k2):
|
||||
return d.get(k1, d.get(k2))
|
||||
|
||||
return [{
|
||||
"id": get_value(chunk, "chunk_id", "id"),
|
||||
"content": get_value(chunk, "content", "content_with_weight"),
|
||||
"document_id": get_value(chunk, "doc_id", "document_id"),
|
||||
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
|
||||
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
|
||||
"image_id": get_value(chunk, "image_id", "img_id"),
|
||||
"positions": get_value(chunk, "positions", "position_int"),
|
||||
"url": chunk.get("url")
|
||||
} for chunk in reference.get("chunks", [])]
|
||||
return [
|
||||
{
|
||||
"id": get_value(chunk, "chunk_id", "id"),
|
||||
"content": get_value(chunk, "content", "content_with_weight"),
|
||||
"document_id": get_value(chunk, "doc_id", "document_id"),
|
||||
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
|
||||
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
|
||||
"image_id": get_value(chunk, "image_id", "img_id"),
|
||||
"positions": get_value(chunk, "positions", "position_int"),
|
||||
"url": chunk.get("url"),
|
||||
}
|
||||
for chunk in reference.get("chunks", [])
|
||||
]
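`chunks_format()` above normalizes chunks that may carry either internal field names or API-facing ones, using a two-key fallback. A tiny illustration with made-up sample dicts:

```python
def get_value(d, k1, k2):
    # Same helper as above: prefer k1, fall back to k2.
    return d.get(k1, d.get(k2))


internal = {"chunk_id": "c1", "doc_id": "d1", "docnm_kwd": "spec.pdf"}
api_shape = {"id": "c2", "document_id": "d2", "document_name": "notes.md"}

for chunk in (internal, api_shape):
    print(
        get_value(chunk, "chunk_id", "id"),
        get_value(chunk, "doc_id", "document_id"),
        get_value(chunk, "docnm_kwd", "document_name"),
    )
```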
|
||||
|
||||
|
||||
def llm_id2llm_type(llm_id):
|
||||
|
@ -57,21 +60,21 @@ def llm_id2llm_type(llm_id):
|
|||
def message_fit_in(msg, max_length=4000):
|
||||
"""
|
||||
Trim the message list so that its total token count stays within max_length.

Args:
msg: list of messages, each a dict with "role" and "content"
max_length: maximum number of tokens allowed, 4000 by default

Returns:
tuple: (actual token count, adjusted message list)
|
||||
"""
|
||||
|
||||
def count():
|
||||
"""计算当前消息列表的总token数"""
|
||||
nonlocal msg
|
||||
tks_cnts = []
|
||||
for m in msg:
|
||||
tks_cnts.append(
|
||||
{"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
||||
tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
||||
total = 0
|
||||
for m in tks_cnts:
|
||||
total += m["count"]
|
||||
|
@ -81,7 +84,7 @@ def message_fit_in(msg, max_length=4000):
|
|||
# If the limit is not exceeded, return as is
|
||||
if c < max_length:
|
||||
return c, msg
|
||||
|
||||
|
||||
# First trimming pass: keep the system messages and the last message
|
||||
msg_ = [m for m in msg if m["role"] == "system"]
|
||||
if len(msg) > 1:
|
||||
|
@ -90,20 +93,20 @@ def message_fit_in(msg, max_length=4000):
|
|||
c = count()
|
||||
if c < max_length:
|
||||
return c, msg
|
||||
|
||||
|
||||
# Count the tokens of the system message and of the last message
|
||||
ll = num_tokens_from_string(msg_[0]["content"])
|
||||
ll2 = num_tokens_from_string(msg_[-1]["content"])
|
||||
# If the system message accounts for more than 80% of the tokens, truncate it
|
||||
if ll / (ll + ll2) > 0.8:
|
||||
m = msg_[0]["content"]
|
||||
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
|
||||
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
|
||||
msg[0]["content"] = m
|
||||
return max_length, msg
|
||||
|
||||
|
||||
# Otherwise truncate the last message
|
||||
m = msg_[-1]["content"]
|
||||
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
|
||||
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
|
||||
msg[-1]["content"] = m
|
||||
return max_length, msg
|
||||
|
||||
|
@ -111,18 +114,18 @@ def message_fit_in(msg, max_length=4000):
|
|||
def kb_prompt(kbinfos, max_tokens):
|
||||
"""
|
||||
Format the retrieved knowledge-base content into a prompt suitable for the LLM.

Args:
kbinfos (dict): retrieval result, including the chunks
max_tokens (int): maximum token limit of the model

Workflow:
1. Collect the content of all retrieved document chunks
2. Count tokens and make sure the model limit is not exceeded
3. Fetch document metadata
4. Group the chunks by document name
5. Format everything into a structured prompt

Returns:
list: formatted knowledge-base content, one entry per document
|
||||
"""
|
||||
|
@ -134,7 +137,7 @@ def kb_prompt(kbinfos, max_tokens):
|
|||
chunks_num += 1
|
||||
if max_tokens * 0.97 < used_token_count:
|
||||
knowledges = knowledges[:i]
|
||||
logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}")
|
||||
logging.warning(f"Not all the retrieval into prompt: {i + 1}/{len(knowledges)}")
|
||||
break
|
||||
|
||||
docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
|
||||
|
@ -163,6 +166,10 @@ def citation_prompt():
|
|||
- 以格式 '##i$$ ##j$$'插入引用,其中 i, j 是所引用内容的 ID,并用 '##' 和 '$$' 包裹。
|
||||
- 在句子末尾插入引用,每个句子最多 4 个引用。
|
||||
- 如果答案内容不来自检索到的文本块,则不要插入引用。
|
||||
- 不要使用独立的文档 ID(例如 `#ID#`)。
|
||||
- 在任何情况下,不得使用其他引用样式或格式(例如 `~~i==`、`[i]`、`(i)` 等)。
|
||||
- 引用必须始终使用 `##i$$` 格式。
|
||||
- 任何未能遵守上述规则的情况,包括但不限于格式错误、使用禁止的样式或不支持的引用,都将被视为错误,应跳过为该句添加引用。
|
||||
|
||||
--- 示例 ---
|
||||
<SYSTEM>: 以下是知识库:
|
||||
|
@ -210,10 +217,7 @@ def keyword_extraction(chat_mdl, content, topn=3):
|
|||
### 文本内容
|
||||
{content}
|
||||
"""
|
||||
msg = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": "Output: "}
|
||||
]
|
||||
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple):
|
||||
|
@ -240,10 +244,7 @@ Requirements:
|
|||
{content}
|
||||
|
||||
"""
|
||||
msg = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": "Output: "}
|
||||
]
|
||||
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple):
|
||||
|
@ -368,10 +369,7 @@ Output:
|
|||
{content}
|
||||
|
||||
"""
|
||||
msg = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": "Output: "}
|
||||
]
|
||||
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
|
||||
if isinstance(kwd, tuple):
|
||||
|
@ -384,8 +382,8 @@ Output:
|
|||
return json_repair.loads(kwd)
|
||||
except json_repair.JSONDecodeError:
|
||||
try:
|
||||
result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
|
||||
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
||||
result = kwd.replace(prompt[:-1], "").replace("user", "").replace("model", "").strip()
|
||||
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||
return json_repair.loads(result)
|
||||
except Exception as e:
|
||||
logging.exception(f"JSON parsing error: {result} -> {e}")
|
||||
|
|
|
@ -27,7 +27,8 @@ import {
|
|||
UploadProps,
|
||||
} from 'antd';
|
||||
import get from 'lodash/get';
|
||||
import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react';
|
||||
// import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react';
|
||||
import { CircleStop, SendHorizontal } from 'lucide-react';
|
||||
import {
|
||||
ChangeEventHandler,
|
||||
memo,
|
||||
|
@ -339,9 +340,9 @@ const MessageInput = ({
|
|||
return false;
|
||||
}}
|
||||
>
|
||||
<Button type={'primary'} disabled={disabled}>
|
||||
{/* <Button type={'primary'} disabled={disabled}>
|
||||
<Paperclip className="size-4" />
|
||||
</Button>
|
||||
</Button> */}
|
||||
</Upload>
|
||||
)}
|
||||
{sendLoading ? (
|
||||
|
|
|
@ -17,13 +17,11 @@
|
|||
line-height: 1.2;
|
||||
border-bottom: 2px solid #eaeaea;
|
||||
padding-bottom: 0.25em;
|
||||
font-size: 24px;
|
||||
margin: 0.25em 0.25em;
|
||||
}
|
||||
section {
|
||||
margin-top: 1em;
|
||||
margin-bottom: 1em;
|
||||
font-size: 20px;
|
||||
p {
|
||||
margin-left: 0;
|
||||
}
|
||||
|
@ -32,7 +30,6 @@
|
|||
margin-top: 1em;
|
||||
margin-bottom: 1em;
|
||||
margin-left: 1em;
|
||||
font-size: 20px;
|
||||
}
|
||||
ul,
|
||||
ol {
|
||||
|
@ -40,7 +37,6 @@
|
|||
padding-left: 1.8em;
|
||||
li {
|
||||
margin-bottom: 0.25em;
|
||||
font-size: 20px;
|
||||
}
|
||||
}
|
||||
table {
|
||||
|
@ -60,7 +56,6 @@
|
|||
border: none;
|
||||
padding: 12px;
|
||||
text-align: left;
|
||||
font-size: 20px;
|
||||
border-bottom: 1px solid #ddd;
|
||||
}
|
||||
|
||||
|
|
|
@ -24,9 +24,10 @@ import styles from './index.less';
|
|||
|
||||
interface IProps {
|
||||
controller: AbortController;
|
||||
fontSize: number;
|
||||
}
|
||||
|
||||
const ChatContainer = ({ controller }: IProps) => {
|
||||
const ChatContainer = ({ controller, fontSize = 16 }: IProps) => {
|
||||
const { conversationId } = useGetChatSearchParams();
|
||||
const { data: conversation } = useFetchNextConversation();
|
||||
|
||||
|
@ -55,7 +56,12 @@ const ChatContainer = ({ controller }: IProps) => {
|
|||
return (
|
||||
<>
|
||||
<Flex flex={1} className={styles.chatContainer} vertical>
|
||||
<Flex flex={1} vertical className={styles.messageContainer}>
|
||||
<Flex
|
||||
flex={1}
|
||||
vertical
|
||||
className={styles.messageContainer}
|
||||
style={{ fontSize: `${fontSize}px` }}
|
||||
>
|
||||
<div>
|
||||
<Spin spinning={loading}>
|
||||
{derivedMessages?.map((message, i) => {
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
import { ReactComponent as ChatAppCube } from '@/assets/svg/chat-app-cube.svg';
|
||||
import RenameModal from '@/components/rename-modal';
|
||||
import { DeleteOutlined, EditOutlined } from '@ant-design/icons';
|
||||
import {
|
||||
DeleteOutlined,
|
||||
EditOutlined,
|
||||
SettingOutlined,
|
||||
} from '@ant-design/icons';
|
||||
import {
|
||||
Avatar,
|
||||
Button,
|
||||
|
@ -12,12 +16,12 @@ import {
|
|||
Space,
|
||||
Spin,
|
||||
Tag,
|
||||
Tooltip,
|
||||
// Tooltip,
|
||||
Typography,
|
||||
} from 'antd';
|
||||
import { MenuItemProps } from 'antd/lib/menu/MenuItem';
|
||||
import classNames from 'classnames';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import ChatConfigurationModal from './chat-configuration-modal';
|
||||
import ChatContainer from './chat-container';
|
||||
import {
|
||||
|
@ -43,9 +47,9 @@ import {
|
|||
import { useTranslate } from '@/hooks/common-hooks';
|
||||
import { useSetSelectedRecord } from '@/hooks/logic-hooks';
|
||||
import { IDialog } from '@/interfaces/database/chat';
|
||||
import { Modal, Slider } from 'antd';
|
||||
import { PictureInPicture2 } from 'lucide-react';
|
||||
import styles from './index.less';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
const Chat = () => {
|
||||
|
@ -161,6 +165,17 @@ const Chat = () => {
|
|||
addTemporaryConversation();
|
||||
}, [addTemporaryConversation]);
|
||||
|
||||
const [fontSizeModalVisible, setFontSizeModalVisible] = useState(false);
|
||||
const [fontSize, setFontSize] = useState(16); // 默认字体大小
|
||||
|
||||
// 从localStorage加载字体大小设置
|
||||
useEffect(() => {
|
||||
const savedFontSize = localStorage.getItem('chatFontSize');
|
||||
if (savedFontSize) {
|
||||
setFontSize(parseInt(savedFontSize));
|
||||
}
|
||||
}, []);
|
||||
|
||||
const buildAppItems = (dialog: IDialog) => {
|
||||
const dialogId = dialog.id;
|
||||
|
||||
|
@ -286,6 +301,7 @@ const Chat = () => {
|
|||
</Flex>
|
||||
</Flex>
|
||||
<Divider type={'vertical'} className={styles.divider}></Divider>
|
||||
|
||||
<Flex className={styles.chatTitleWrapper}>
|
||||
<Flex flex={1} vertical>
|
||||
<Flex
|
||||
|
@ -297,15 +313,23 @@ const Chat = () => {
|
|||
<b>{t('chat')}</b>
|
||||
<Tag>{conversationList.length}</Tag>
|
||||
</Space>
|
||||
<Tooltip title={t('newChat')}>
|
||||
<div>
|
||||
<SvgIcon
|
||||
name="plus-circle-fill"
|
||||
width={20}
|
||||
onClick={handleCreateTemporaryConversation}
|
||||
></SvgIcon>
|
||||
</div>
|
||||
</Tooltip>
|
||||
{/* <Tooltip title={t('newChat')}> */}
|
||||
<div>
|
||||
<SettingOutlined
|
||||
style={{
|
||||
marginRight: '8px',
|
||||
fontSize: '20px',
|
||||
cursor: 'pointer',
|
||||
}}
|
||||
onClick={() => setFontSizeModalVisible(true)}
|
||||
/>
|
||||
<SvgIcon
|
||||
name="plus-circle-fill"
|
||||
width={20}
|
||||
onClick={handleCreateTemporaryConversation}
|
||||
></SvgIcon>
|
||||
</div>
|
||||
{/* </Tooltip> */}
|
||||
</Flex>
|
||||
<Divider></Divider>
|
||||
<Flex vertical gap={10} className={styles.chatTitleContent}>
|
||||
|
@ -356,7 +380,10 @@ const Chat = () => {
|
|||
</Flex>
|
||||
</Flex>
|
||||
<Divider type={'vertical'} className={styles.divider}></Divider>
|
||||
<ChatContainer controller={controller}></ChatContainer>
|
||||
<ChatContainer
|
||||
controller={controller}
|
||||
fontSize={fontSize}
|
||||
></ChatContainer>
|
||||
{dialogEditVisible && (
|
||||
<ChatConfigurationModal
|
||||
visible={dialogEditVisible}
|
||||
|
@ -386,6 +413,27 @@ const Chat = () => {
|
|||
isAgent={false}
|
||||
></EmbedModal>
|
||||
)}
|
||||
|
||||
{fontSizeModalVisible && (
|
||||
<Modal
|
||||
title={'设置字体大小'}
|
||||
open={fontSizeModalVisible}
|
||||
onCancel={() => setFontSizeModalVisible(false)}
|
||||
footer={null}
|
||||
>
|
||||
<Flex vertical gap="middle" align="center">
|
||||
{'当前字体大小'}: {fontSize}px
|
||||
<Slider
|
||||
min={12}
|
||||
max={24}
|
||||
step={1}
|
||||
defaultValue={fontSize}
|
||||
style={{ width: '80%' }}
|
||||
onChange={(value) => setFontSize(value)}
|
||||
/>
|
||||
</Flex>
|
||||
</Modal>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
|