Merge pull request #100 from zstar1003/dev

This commit is contained in:
zstar 2025-05-17 15:30:12 +08:00 committed by GitHub
commit fd7f1140cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 349 additions and 464 deletions

View File

@ -37,7 +37,8 @@ from api.db.services.file_service import FileService
from flask import jsonify, request, Response
@manager.route('/chats/<chat_id>/sessions', methods=['POST']) # noqa: F821
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id, chat_id):
req = request.json
@ -50,7 +51,7 @@ def create(tenant_id, chat_id):
"dialog_id": req["dialog_id"],
"name": req.get("name", "New session"),
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
"user_id": req.get("user_id", "")
"user_id": req.get("user_id", ""),
}
if not conv.get("name"):
return get_error_data_result(message="`name` can not be empty.")
@ -59,20 +60,20 @@ def create(tenant_id, chat_id):
if not e:
return get_error_data_result(message="Fail to create a session!")
conv = conv.to_dict()
conv['messages'] = conv.pop("message")
conv["messages"] = conv.pop("message")
conv["chat_id"] = conv.pop("dialog_id")
del conv["reference"]
return get_result(data=conv)
@manager.route('/agents/<agent_id>/sessions', methods=['POST']) # noqa: F821
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
def create_agent_session(tenant_id, agent_id):
req = request.json
if not request.is_json:
req = request.form
files = request.files
user_id = request.args.get('user_id', '')
user_id = request.args.get("user_id", "")
e, cvs = UserCanvasService.get_by_id(agent_id)
if not e:
@ -113,7 +114,7 @@ def create_agent_session(tenant_id, agent_id):
ele.pop("value")
else:
if req is not None and req.get(ele["key"]):
ele["value"] = req[ele['key']]
ele["value"] = req[ele["key"]]
else:
if "value" in ele:
ele.pop("value")
@ -121,20 +122,13 @@ def create_agent_session(tenant_id, agent_id):
for ans in canvas.run(stream=False):
pass
cvs.dsl = json.loads(str(canvas))
conv = {
"id": get_uuid(),
"dialog_id": cvs.id,
"user_id": user_id,
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent",
"dsl": cvs.dsl
}
conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
API4ConversationService.save(**conv)
conv["agent_id"] = conv.pop("dialog_id")
return get_result(data=conv)
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT']) # noqa: F821
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, chat_id, session_id):
req = request.json
@ -156,14 +150,14 @@ def update(tenant_id, chat_id, session_id):
return get_result()
@manager.route('/chats/<chat_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def chat_completion(tenant_id, chat_id):
req = request.json
if not req:
req = {"question": ""}
if not req.get("session_id"):
req["question"]=""
req["question"] = ""
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(f"You don't own the chat {chat_id}")
if req.get("session_id"):
@ -185,7 +179,7 @@ def chat_completion(tenant_id, chat_id):
return get_result(data=answer)
@manager.route('chats_openai/<chat_id>/chat/completions', methods=['POST']) # noqa: F821
@manager.route("chats_openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
@validate_request("model", "messages") # noqa: F821
@token_required
def chat_completion_openai_like(tenant_id, chat_id):
@ -259,39 +253,23 @@ def chat_completion_openai_like(tenant_id, chat_id):
# The choices field on the last chunk will always be an empty array [].
def streamed_response_generator(chat_id, dia, msg):
token_used = 0
should_split_index = 0
answer_cache = ""
response = {
"id": f"chatcmpl-{chat_id}",
"choices": [
{
"delta": {
"content": "",
"role": "assistant",
"function_call": None,
"tool_calls": None
},
"finish_reason": None,
"index": 0,
"logprobs": None
}
],
"choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None}, "finish_reason": None, "index": 0, "logprobs": None}],
"created": int(time.time()),
"model": "model",
"object": "chat.completion.chunk",
"system_fingerprint": "",
"usage": None
"usage": None,
}
try:
for ans in chat(dia, msg, True):
answer = ans["answer"]
incremental = answer[should_split_index:]
incremental = answer.replace(answer_cache, "", 1)
answer_cache = answer.rstrip("</think>")
token_used += len(incremental)
if incremental.endswith("</think>"):
response_data_len = len(incremental.rstrip("</think>"))
else:
response_data_len = len(incremental)
should_split_index += response_data_len
response["choices"][0]["delta"]["content"] = incremental
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
except Exception as e:
@ -301,15 +279,10 @@ def chat_completion_openai_like(tenant_id, chat_id):
# The last chunk
response["choices"][0]["delta"]["content"] = None
response["choices"][0]["finish_reason"] = "stop"
response["usage"] = {
"prompt_tokens": len(prompt),
"completion_tokens": token_used,
"total_tokens": len(prompt) + token_used
}
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
yield "data:[DONE]\n\n"
resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
@ -324,7 +297,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
break
content = answer["answer"]
response = {
response = {
"id": f"chatcmpl-{chat_id}",
"object": "chat.completion",
"created": int(time.time()),
@ -336,25 +309,15 @@ def chat_completion_openai_like(tenant_id, chat_id):
"completion_tokens_details": {
"reasoning_tokens": context_token_used,
"accepted_prediction_tokens": len(content),
"rejected_prediction_tokens": 0 # 0 for simplicity
}
"rejected_prediction_tokens": 0, # 0 for simplicity
},
},
"choices": [
{
"message": {
"role": "assistant",
"content": content
},
"logprobs": None,
"finish_reason": "stop",
"index": 0
}
]
"choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
}
return jsonify(response)
@manager.route('/agents/<agent_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def agent_completions(tenant_id, agent_id):
req = request.json
@ -365,8 +328,8 @@ def agent_completions(tenant_id, agent_id):
dsl = cvs[0].dsl
if not isinstance(dsl, str):
dsl = json.dumps(dsl)
#canvas = Canvas(dsl, tenant_id)
#if canvas.get_preset_param():
# canvas = Canvas(dsl, tenant_id)
# if canvas.get_preset_param():
# req["question"] = ""
conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
if not conv:
@ -380,9 +343,7 @@ def agent_completions(tenant_id, agent_id):
states = {field: current_dsl.get(field, []) for field in state_fields}
current_dsl.update(new_dsl)
current_dsl.update(states)
API4ConversationService.update_by_id(req["session_id"], {
"dsl": current_dsl
})
API4ConversationService.update_by_id(req["session_id"], {"dsl": current_dsl})
else:
req["question"] = ""
if req.get("stream", True):
@ -399,7 +360,7 @@ def agent_completions(tenant_id, agent_id):
return get_error_data_result(str(e))
@manager.route('/chats/<chat_id>/sessions', methods=['GET']) # noqa: F821
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
@token_required
def list_session(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
@ -418,7 +379,7 @@ def list_session(tenant_id, chat_id):
if not convs:
return get_result(data=[])
for conv in convs:
conv['messages'] = conv.pop("message")
conv["messages"] = conv.pop("message")
infos = conv["messages"]
for info in infos:
if "prompt" in info:
@ -452,7 +413,7 @@ def list_session(tenant_id, chat_id):
return get_result(data=convs)
@manager.route('/agents/<agent_id>/sessions', methods=['GET']) # noqa: F821
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
@token_required
def list_agent_session(tenant_id, agent_id):
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
@ -468,12 +429,11 @@ def list_agent_session(tenant_id, agent_id):
desc = True
# dsl defaults to True in all cases except for False and false
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
user_id, include_dsl)
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
if not convs:
return get_result(data=[])
for conv in convs:
conv['messages'] = conv.pop("message")
conv["messages"] = conv.pop("message")
infos = conv["messages"]
for info in infos:
if "prompt" in info:
@ -506,7 +466,7 @@ def list_agent_session(tenant_id, agent_id):
return get_result(data=convs)
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"]) # noqa: F821
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id, chat_id):
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
@ -532,14 +492,14 @@ def delete(tenant_id, chat_id):
return get_result()
@manager.route('/agents/<agent_id>/sessions', methods=["DELETE"]) # noqa: F821
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete_agent_session(tenant_id, agent_id):
req = request.json
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}")
convs = API4ConversationService.query(dialog_id=agent_id)
if not convs:
return get_error_data_result(f"Agent {agent_id} has no sessions")
@ -555,16 +515,16 @@ def delete_agent_session(tenant_id, agent_id):
conv_list.append(conv.id)
else:
conv_list = ids
for session_id in conv_list:
conv = API4ConversationService.query(id=session_id, dialog_id=agent_id)
if not conv:
return get_error_data_result(f"The agent doesn't own the session ${session_id}")
API4ConversationService.delete_by_id(session_id)
return get_result()
@manager.route('/sessions/ask', methods=['POST']) # noqa: F821
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
@token_required
def ask_about(tenant_id):
req = request.json
@ -590,9 +550,7 @@ def ask_about(tenant_id):
for ans in ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream")
@ -603,7 +561,7 @@ def ask_about(tenant_id):
return resp
@manager.route('/sessions/related_questions', methods=['POST']) # noqa: F821
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
@token_required
def related_questions(tenant_id):
req = request.json
@ -635,18 +593,27 @@ Reason:
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
"""
ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
ans = chat_mdl.chat(
prompt,
[
{
"role": "user",
"content": f"""
Keywords: {question}
Related search terms:
"""}], {"temperature": 0.9})
""",
}
],
{"temperature": 0.9},
)
return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
@manager.route('/chatbots/<dialog_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
def chatbot_completions(dialog_id):
req = request.json
token = request.headers.get('Authorization').split()
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
token = token[1]
@ -669,11 +636,11 @@ def chatbot_completions(dialog_id):
return get_result(data=answer)
@manager.route('/agentbots/<agent_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
def agent_bot_completions(agent_id):
req = request.json
token = request.headers.get('Authorization').split()
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
token = token[1]

View File

@ -47,20 +47,14 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_list(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords, id, name):
def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name):
docs = cls.model.select().where(cls.model.kb_id == kb_id)
if id:
docs = docs.where(
cls.model.id == id)
docs = docs.where(cls.model.id == id)
if name:
docs = docs.where(
cls.model.name == name
)
docs = docs.where(cls.model.name == name)
if keywords:
docs = docs.where(
fn.LOWER(cls.model.name).contains(keywords.lower())
)
docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
if desc:
docs = docs.order_by(cls.model.getter_by(orderby).desc())
else:
@ -72,13 +66,9 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords):
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords):
if keywords:
docs = cls.model.select().where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
else:
docs = cls.model.select().where(cls.model.kb_id == kb_id)
count = docs.count()
@ -97,8 +87,7 @@ class DocumentService(CommonService):
if not cls.save(**doc):
raise RuntimeError("Database error (Document)!")
e, kb = KnowledgebaseService.get_by_id(doc["kb_id"])
if not KnowledgebaseService.update_by_id(
kb.id, {"doc_num": kb.doc_num + 1}):
if not KnowledgebaseService.update_by_id(kb.id, {"doc_num": kb.doc_num + 1}):
raise RuntimeError("Database error (Knowledgebase)!")
return Document(**doc)
@ -108,14 +97,16 @@ class DocumentService(CommonService):
cls.clear_chunk_num(doc.id)
try:
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update(
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id),
doc.kb_id,
)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete(
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}}, search.index_name(tenant_id), doc.kb_id
)
except Exception:
pass
return cls.delete_by_id(doc.id)
@ -136,67 +127,54 @@ class DocumentService(CommonService):
Tenant.embd_id,
Tenant.img2txt_id,
Tenant.asr_id,
cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
cls.model.update_time,
]
docs = (
cls.model.select(*fields)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value) \
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value,
)
.order_by(cls.model.update_time.asc())
)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def get_unfinished_docs(cls):
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
cls.model.run, cls.model.parser_id]
docs = cls.model.select(*fields) \
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress < 1,
cls.model.progress > 0)
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id]
docs = cls.model.select(*fields).where(cls.model.status == StatusEnum.VALID.value, ~(cls.model.type == FileType.VIRTUAL.value), cls.model.progress < 1, cls.model.progress > 0)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
num = cls.model.update(token_num=cls.model.token_num + token_num,
chunk_num=cls.model.chunk_num + chunk_num,
process_duation=cls.model.process_duation + duation).where(
cls.model.id == doc_id).execute()
num = (
cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duation=cls.model.process_duation + duation)
.where(cls.model.id == doc_id)
.execute()
)
if num == 0:
raise LookupError(
"Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num +
token_num,
chunk_num=Knowledgebase.chunk_num +
chunk_num).where(
Knowledgebase.id == kb_id).execute()
raise LookupError("Document not found which is supposed to be there")
num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
num = cls.model.update(token_num=cls.model.token_num - token_num,
chunk_num=cls.model.chunk_num - chunk_num,
process_duation=cls.model.process_duation + duation).where(
cls.model.id == doc_id).execute()
num = (
cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duation=cls.model.process_duation + duation)
.where(cls.model.id == doc_id)
.execute()
)
if num == 0:
raise LookupError(
"Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
token_num,
chunk_num=Knowledgebase.chunk_num -
chunk_num
).where(
Knowledgebase.id == kb_id).execute()
raise LookupError("Document not found which is supposed to be there")
num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute()
return num
@classmethod
@ -205,24 +183,17 @@ class DocumentService(CommonService):
doc = cls.model.get_by_id(doc_id)
assert doc, "Can't fine document in database."
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
doc.token_num,
chunk_num=Knowledgebase.chunk_num -
doc.chunk_num,
doc_num=Knowledgebase.doc_num - 1
).where(
Knowledgebase.id == doc.kb_id).execute()
num = (
Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1)
.where(Knowledgebase.id == doc.kb_id)
.execute()
)
return num
@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
@ -240,11 +211,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_tenant_id_by_name(cls, name):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
@ -253,12 +220,13 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def accessible(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = (
cls.model.select(cls.model.id)
.join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
.join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id))
.where(cls.model.id == doc_id, UserTenant.user_id == user_id)
.paginate(0, 1)
)
docs = docs.dicts()
if not docs:
return False
@ -267,11 +235,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def accessible4deletion(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
@ -280,11 +244,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_embd_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.embd_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
@ -306,9 +266,9 @@ class DocumentService(CommonService):
Tenant.asr_id,
Tenant.llm_id,
)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == doc_id)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == doc_id)
)
configs = configs.dicts()
if not configs:
@ -319,8 +279,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
fields = [cls.model.id]
doc_id = cls.model.select(*fields) \
.where(cls.model.name == doc_name)
doc_id = cls.model.select(*fields).where(cls.model.name == doc_name)
doc_id = doc_id.dicts()
if not doc_id:
return
@ -330,8 +289,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_thumbnails(cls, docids):
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
return list(cls.model.select(
*fields).where(cls.model.id.in_(docids)).dicts())
return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
@classmethod
@DB.connection_context()
@ -359,19 +317,14 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_doc_count(cls, tenant_id):
docs = cls.model.select(cls.model.id).join(Knowledgebase,
on=(Knowledgebase.id == cls.model.kb_id)).where(
Knowledgebase.tenant_id == tenant_id)
docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id)
return len(docs)
@classmethod
@DB.connection_context()
def begin2parse(cls, docid):
cls.update_by_id(
docid, {"progress": random.random() * 1 / 100.,
"progress_msg": "Task is queued...",
"process_begin_at": get_format_time()
})
cls.update_by_id(docid, {"progress": random.random() * 1 / 100.0, "progress_msg": "Task is queued...", "process_begin_at": get_format_time()})
@classmethod
@DB.connection_context()
def update_meta_fields(cls, doc_id, meta_fields):
@ -420,11 +373,7 @@ class DocumentService(CommonService):
status = TaskStatus.DONE.value
msg = "\n".join(sorted(msg))
info = {
"process_duation": datetime.timestamp(
datetime.now()) -
d["process_begin_at"].timestamp(),
"run": status}
info = {"process_duation": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(), "run": status}
if prg != 0:
info["progress"] = prg
if msg:
@ -437,8 +386,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_kb_doc_count(cls, kb_id):
return len(cls.model.select(cls.model.id).where(
cls.model.kb_id == kb_id).dicts())
return len(cls.model.select(cls.model.id).where(cls.model.kb_id == kb_id).dicts())
@classmethod
@DB.connection_context()
@ -459,14 +407,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
def new_task():
nonlocal doc
return {
"id": get_uuid(),
"doc_id": doc["id"],
"from_page": 100000000,
"to_page": 100000000,
"task_type": ty,
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
}
return {"id": get_uuid(), "doc_id": doc["id"], "from_page": 100000000, "to_page": 100000000, "task_type": ty, "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty}
task = new_task()
for field in ["doc_id", "from_page", "to_page"]:
@ -478,6 +419,25 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
def doc_upload_and_parse(conversation_id, file_objs, user_id):
"""
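Upload and parse documents, storing their content in the knowledge base.
Args:
conversation_id: conversation ID
file_objs: list of file objects
user_id: user ID
Returns:
List of IDs of the successfully processed documents
Processing flow:
1. Validate the conversation and its knowledge base
2. Initialize the embedding model
3. Upload the files to storage
4. Parse the file contents with multiple threads
5. Generate embedding vectors for the content
6. Store the results in the document store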
"""
from rag.app import presentation, picture, naive, audio, email
from api.db.services.dialog_service import DialogService
from api.db.services.file_service import FileService
@ -493,8 +453,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
e, dia = DialogService.get_by_id(conv.dialog_id)
if not dia.kb_ids:
raise LookupError("No knowledge base associated with this conversation. "
"Please add a knowledge base before uploading documents")
raise LookupError("No knowledge base associated with this conversation. Please add a knowledge base before uploading documents")
kb_id = dia.kb_ids[0]
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
@ -508,12 +467,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
def dummy(prog=None, msg=""):
pass
FACTORY = {
ParserType.PRESENTATION.value: presentation,
ParserType.PICTURE.value: picture,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
# Run the parsing tasks on a thread pool
exe = ThreadPoolExecutor(max_workers=12)
@ -522,22 +476,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
for d, blob in files:
doc_nm[d["id"]] = d["name"]
for d, blob in files:
kwargs = {
"callback": dummy,
"parser_config": parser_config,
"from_page": 0,
"to_page": 100000,
"tenant_id": kb.tenant_id,
"lang": kb.language
}
kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language}
threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
for (docinfo, _), th in zip(files, threads):
docs = []
doc = {
"doc_id": docinfo["id"],
"kb_id": [kb.id]
}
doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]}
for ck in th.result():
d = deepcopy(doc)
d.update(ck)
@ -552,7 +496,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if isinstance(d["image"], bytes):
output_buffer = BytesIO(d["image"])
else:
d["image"].save(output_buffer, format='JPEG')
d["image"].save(output_buffer, format="JPEG")
STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["id"])
@ -569,9 +513,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
nonlocal embd_mdl, chunk_counts, token_counts
vects = []
for i in range(0, len(cnts), batch_size):
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
vts, c = embd_mdl.encode(cnts[i : i + batch_size])
vects.extend(vts.tolist())
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
chunk_counts[doc_id] += len(cnts[i : i + batch_size])
token_counts[doc_id] += c
return vects
@ -585,22 +529,25 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if parser_ids[doc_id] != ParserType.PICTURE.value:
from graphrag.general.mind_map_extractor import MindMapExtractor
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
if len(mind_map) < 32:
raise Exception("Few content: " + mind_map)
cks.append({
"id": get_uuid(),
"doc_id": doc_id,
"kb_id": [kb.id],
"docnm_kwd": doc_nm[doc_id],
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map"
})
cks.append(
{
"id": get_uuid(),
"doc_id": doc_id,
"kb_id": [kb.id],
"docnm_kwd": doc_nm[doc_id],
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map",
}
)
except Exception as e:
logging.exception("Mind map generation error")
@ -614,9 +561,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id)
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
return [d["id"] for d, _ in files]

View File

@ -1,5 +1,5 @@
from flask import jsonify, request
from services.conversation.service import get_conversations_by_user_id, get_messages_by_conversation_id, get_conversation_detail
from services.conversation.service import get_conversations_by_user_id, get_messages_by_conversation_id
from .. import conversation_bp
@ -44,20 +44,3 @@ def get_messages(conversation_id):
except Exception as e:
# 错误处理
return jsonify({"code": 500, "message": f"获取消息列表失败: {str(e)}"}), 500
@conversation_bp.route("/<conversation_id>", methods=["GET"])
def get_conversation(conversation_id):
    """Return the details of one conversation as a JSON response.

    Args:
        conversation_id: Conversation identifier taken from the URL path.

    Returns:
        A ``{"code": 0, "data": ..., "message": ...}`` JSON body when the
        service returns a truthy result; a JSON body with HTTP status 404 when
        it returns a falsy one; a JSON body with HTTP status 500 carrying the
        exception text when anything inside the handler raises.
    """
    try:
        # Delegate the lookup to the service layer.
        detail = get_conversation_detail(conversation_id)
        if detail:
            # Envelope shape (code / data / message) expected by the front end.
            payload = {"code": 0, "data": detail, "message": "获取对话详情成功"}
            return jsonify(payload)
        # A falsy result from the service is reported as "not found".
        return jsonify({"code": 404, "message": "对话不存在"}), 404
    except Exception as e:
        # Any failure is surfaced to the caller as a 500 with the error text.
        failure = {"code": 500, "message": f"获取对话详情失败: {str(e)}"}
        return jsonify(failure), 500

View File

@ -23,8 +23,6 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time"
# 直接使用user_id作为tenant_id
tenant_id = user_id
print(f"查询用户ID: {user_id}, 租户ID: {tenant_id}")
# 查询总记录数
count_sql = """
SELECT COUNT(*) as total
@ -34,7 +32,7 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time"
cursor.execute(count_sql, (tenant_id,))
total = cursor.fetchone()["total"]
print(f"查询到总记录数: {total}")
# print(f"查询到总记录数: {total}")
# 计算分页偏移量
offset = (page - 1) * size
@ -59,8 +57,8 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time"
LIMIT %s OFFSET %s
"""
print(f"执行查询: {query}")
print(f"参数: tenant_id={tenant_id}, size={size}, offset={offset}")
# print(f"执行查询: {query}")
# print(f"参数: tenant_id={tenant_id}, size={size}, offset={offset}")
cursor.execute(query, (tenant_id, size, offset))
results = cursor.fetchall()
@ -200,68 +198,3 @@ def get_messages_by_conversation_id(conversation_id, page=1, size=30):
traceback.print_exc()
return None, 0
def get_conversation_detail(conversation_id):
    """
    Fetch a single conversation row together with its dialog's name and icon.

    Args:
        conversation_id (str): ID of the conversation to look up.

    Returns:
        dict | None: A dict with the keys ``id``, ``name``, ``dialogId``,
        ``dialogName``, ``dialogIcon``, ``createTime``, ``updateTime`` and
        ``messages`` (the ``message`` column passed through as returned by the
        driver). ``None`` when no row matches or when any error occurs; errors
        are printed together with a traceback and are never raised.
    """
    import traceback

    # Both handles start as None so the ``finally`` block can tell which
    # resources were actually acquired before a failure.
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**DB_CONFIG)
        cursor = conn.cursor(dictionary=True)

        # Look up the conversation; LEFT JOIN keeps the row even when the
        # referenced dialog no longer exists.
        query = """
            SELECT c.*, d.name as dialog_name, d.icon as dialog_icon
            FROM conversation c
            LEFT JOIN dialog d ON c.dialog_id = d.id
            WHERE c.id = %s
        """
        cursor.execute(query, (conversation_id,))
        result = cursor.fetchone()

        if not result:
            print(f"未找到对话ID: {conversation_id}")
            return None

        # Shape the row into the camelCase structure returned to callers.
        conversation = {
            "id": result["id"],
            "name": result.get("name", ""),
            "dialogId": result.get("dialog_id", ""),
            "dialogName": result.get("dialog_name", ""),
            "dialogIcon": result.get("dialog_icon", ""),
            "createTime": result["create_date"].strftime("%Y-%m-%d %H:%M:%S") if result.get("create_date") else "",
            "updateTime": result["update_date"].strftime("%Y-%m-%d %H:%M:%S") if result.get("update_date") else "",
            "messages": result.get("message", []),
        }

        # Debug output.
        print(f"获取到对话详情: ID={conversation_id}")
        print(f"消息数量: {len(conversation['messages']) if conversation['messages'] else 0}")

        return conversation
    except mysql.connector.Error as err:
        print(f"数据库错误: {err}")
        traceback.print_exc()
        return None
    except Exception as e:
        print(f"未知错误: {e}")
        traceback.print_exc()
        return None
    finally:
        # Release the cursor and connection on every path. Previously they
        # were closed only after a successful lookup, so the "not found"
        # return and every error path leaked a connection.
        for resource in (cursor, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception as close_err:
                # A failed close must not mask the result computed above.
                print(f"关闭数据库资源失败: {close_err}")

View File

@ -1,11 +1,8 @@
<script lang="ts" setup>
import type { CreateOrUpdateTableRequestData, TableData } from "@@/apis/tables/type"
import type { FormInstance, FormRules } from "element-plus"
import { createTableDataApi, deleteTableDataApi, getTableDataApi, resetPasswordApi, updateTableDataApi } from "@@/apis/tables"
import { usePagination } from "@@/composables/usePagination"
import { ChatDotRound, CirclePlus, Delete, Edit, Key, Refresh, RefreshRight, Search, User } from "@element-plus/icons-vue"
import type { TableData } from "@@/apis/tables/type"
import { getTableDataApi } from "@@/apis/tables"
import { ChatDotRound, User } from "@element-plus/icons-vue"
import axios from "axios"
import { cloneDeep } from "lodash-es"
defineOptions({
//
@ -79,10 +76,16 @@ function loadMoreUsers() {
/**
* 监听用户列表滚动事件
* @param event 滚动事件
* @param event DOM滚动事件对象
*/
function handleUserListScroll(event) {
const { scrollTop, scrollHeight, clientHeight } = event.target
function handleUserListScroll(event: Event) {
// event.target HTMLElement
const target = event.target as HTMLElement
if (!target) return
//
const { scrollTop, scrollHeight, clientHeight } = target
// 100px
if (scrollHeight - scrollTop - clientHeight < 100 && userHasMore.value && !userLoading.value) {
loadMoreUsers()
@ -203,10 +206,16 @@ function loadMoreConversations() {
/**
* 监听对话列表滚动事件
* @param event 滚动事件
* @param event DOM滚动事件对象
*/
function handleConversationListScroll(event) {
const { scrollTop, scrollHeight, clientHeight } = event.target
function handleConversationListScroll(event: Event) {
// event.target HTMLElement
const target = event.target as HTMLElement
if (!target) return
//
const { scrollTop, scrollHeight, clientHeight } = target
// 100px
if (scrollHeight - scrollTop - clientHeight < 100 && conversationHasMore.value && !conversationLoading.value) {
loadMoreConversations()
@ -247,7 +256,7 @@ function getMessagesByConversationId(conversationId: string, isLoadMore = false)
const parsedMessages = JSON.parse(conversation.messages)
//
processedMessages = parsedMessages.map((msg, index) => {
processedMessages = parsedMessages.map((msg: { id: any, role: any, content: any, created_at: number }, index: any) => {
return {
id: msg.id || `msg-${index}`,
conversation_id: conversationId,
@ -268,7 +277,7 @@ function getMessagesByConversationId(conversationId: string, isLoadMore = false)
if (isLoadMore) {
//
const existingIds = new Set(messageList.value.map(msg => msg.id))
const uniqueNewMessages = processedMessages.filter(msg => !existingIds.has(msg.id))
const uniqueNewMessages = processedMessages.filter((msg: { id: number }) => !existingIds.has(msg.id))
//
messageList.value = [...messageList.value, ...uniqueNewMessages]
@ -307,7 +316,7 @@ function getMessagesByConversationId(conversationId: string, isLoadMore = false)
* @param content 消息内容
* @returns 处理后的HTML内容
*/
function renderMessageContent(content) {
function renderMessageContent(content: string) {
if (!content) return ""
// Markdown
@ -331,10 +340,16 @@ function loadMoreMessages() {
/**
* 监听消息列表滚动事件
* @param event 滚动事件
* @param event DOM滚动事件对象
*/
function handleMessageListScroll(event) {
const { scrollTop, scrollHeight, clientHeight } = event.target
function handleMessageListScroll(event: Event) {
// event.target HTMLElement
const target = event.target as HTMLElement
if (!target) return
//
const { scrollTop, scrollHeight, clientHeight } = target
// 100px
if (scrollHeight - scrollTop - clientHeight < 100 && messageHasMore.value && !messageLoading.value) {
loadMoreMessages()

View File

@ -1,10 +1,9 @@
<script lang="ts" setup>
import type { CreateOrUpdateTableRequestData, TableData } from "@@/apis/configs/type"
import type { FormInstance, FormRules } from "element-plus"
import { getTableDataApi, updateTableDataApi } from "@@/apis/configs"
import type { TableData } from "@@/apis/configs/type"
import type { FormInstance } from "element-plus"
import { getTableDataApi } from "@@/apis/configs"
import { usePagination } from "@@/composables/usePagination"
import { CirclePlus, Delete, Refresh, RefreshRight, Search } from "@element-plus/icons-vue"
import { cloneDeep } from "lodash-es"
import { Refresh, Search } from "@element-plus/icons-vue"
defineOptions({
//
@ -15,47 +14,41 @@ const loading = ref<boolean>(false)
const { paginationData, handleCurrentChange, handleSizeChange } = usePagination()
// #region
const DEFAULT_FORM_DATA: CreateOrUpdateTableRequestData = {
id: undefined,
username: "",
chatModel: "",
embeddingModel: ""
}
const dialogVisible = ref<boolean>(false)
const formData = ref<CreateOrUpdateTableRequestData>(cloneDeep(DEFAULT_FORM_DATA))
// const DEFAULT_FORM_DATA: CreateOrUpdateTableRequestData = {
// id: undefined,
// username: "",
// chatModel: "",
// embeddingModel: ""
// }
// const dialogVisible = ref<boolean>(false)
// const formData = ref<CreateOrUpdateTableRequestData>(cloneDeep(DEFAULT_FORM_DATA))
//
function handleDelete() {
ElMessage.success("如需删除租户配置,可直接删除负责人账号")
ElMessage.success("如需删除用户配置信息,请直接在前台登录用户账号进行操作")
}
//
function handleUpdate(row: TableData) {
dialogVisible.value = true
formData.value = cloneDeep({
id: row.id,
username: row.username,
chatModel: row.chatModel,
embeddingModel: row.embeddingModel
})
//
function handleUpdate() {
ElMessage.success("如需修改用户配置信息,请直接在前台登录用户账号进行操作")
}
//
/**
 * Send the edited form data to updateTableDataApi and report the outcome.
 * `loading` is raised before the request and lowered once it settles,
 * whether the request succeeded or failed.
 */
function submitForm() {
  // Handlers are named up front so the promise chain below stays on one line.
  const onSaved = () => {
    ElMessage.success("修改成功")
    dialogVisible.value = false
    getTableData() // re-run the table query after a successful save
  }
  const onFailed = () => {
    ElMessage.error("修改失败")
  }
  const onSettled = () => {
    loading.value = false
  }

  loading.value = true
  updateTableDataApi(formData.value).then(onSaved).catch(onFailed).finally(onSettled)
}
// function submitForm() {
// loading.value = true
// updateTableDataApi(formData.value)
// .then(() => {
// ElMessage.success("")
// dialogVisible.value = false
// getTableData() //
// })
// .catch(() => {
// ElMessage.error("")
// })
// .finally(() => {
// loading.value = false
// })
// }
//
const tableData = ref<TableData[]>([])
@ -137,8 +130,8 @@ onActivated(() => {
<el-table-column prop="embeddingModel" label="嵌入模型" align="center" />
<el-table-column prop="updateTime" label="更新时间" align="center" />
<el-table-column fixed="right" label="操作" width="150" align="center">
<template #default="scope">
<el-button type="primary" text bg size="small" @click="handleUpdate(scope.row)">
<template #default="">
<el-button type="primary" text bg size="small" @click="handleUpdate">
修改
</el-button>
<el-button type="danger" text bg size="small" @click="handleDelete()">
@ -163,7 +156,7 @@ onActivated(() => {
</el-card>
<!-- 修改对话框 -->
<el-dialog v-model="dialogVisible" title="修改配置" width="30%">
<!-- <el-dialog v-model="dialogVisible" title="修改配置" width="30%">
<el-form :model="formData" label-width="100px">
<el-form-item label="用户名">
<el-input v-model="formData.username" disabled />
@ -183,7 +176,7 @@ onActivated(() => {
确认
</el-button>
</template>
</el-dialog>
</el-dialog> -->
</div>
</template>

View File

@ -32,16 +32,19 @@ def chunks_format(reference):
def get_value(d, k1, k2):
    """Return ``d[k1]`` when ``k1`` is present, otherwise ``d.get(k2)``.

    A key that is present with a ``None`` value still wins over the fallback
    key; ``None`` is also returned when neither key exists.
    """
    # Resolve the fallback first, then let the primary key override it.
    fallback = d.get(k2)
    return d.get(k1, fallback)
return [{
"id": get_value(chunk, "chunk_id", "id"),
"content": get_value(chunk, "content", "content_with_weight"),
"document_id": get_value(chunk, "doc_id", "document_id"),
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
"image_id": get_value(chunk, "image_id", "img_id"),
"positions": get_value(chunk, "positions", "position_int"),
"url": chunk.get("url")
} for chunk in reference.get("chunks", [])]
return [
{
"id": get_value(chunk, "chunk_id", "id"),
"content": get_value(chunk, "content", "content_with_weight"),
"document_id": get_value(chunk, "doc_id", "document_id"),
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
"image_id": get_value(chunk, "image_id", "img_id"),
"positions": get_value(chunk, "positions", "position_int"),
"url": chunk.get("url"),
}
for chunk in reference.get("chunks", [])
]
def llm_id2llm_type(llm_id):
@ -57,21 +60,21 @@ def llm_id2llm_type(llm_id):
def message_fit_in(msg, max_length=4000):
"""
调整消息列表使其token总数不超过max_length限制
参数:
msg: 消息列表每个元素为包含role和content的字典
max_length: 最大token数限制默认4000
返回:
tuple: (实际token数, 调整后的消息列表)
"""
def count():
"""计算当前消息列表的总token数"""
nonlocal msg
tks_cnts = []
for m in msg:
tks_cnts.append(
{"role": m["role"], "count": num_tokens_from_string(m["content"])})
tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
total = 0
for m in tks_cnts:
total += m["count"]
@ -81,7 +84,7 @@ def message_fit_in(msg, max_length=4000):
# 如果不超限制,直接返回
if c < max_length:
return c, msg
# 第一次精简:保留系统消息和最后一条消息
msg_ = [m for m in msg if m["role"] == "system"]
if len(msg) > 1:
@ -90,20 +93,20 @@ def message_fit_in(msg, max_length=4000):
c = count()
if c < max_length:
return c, msg
# 计算系统消息和最后一条消息的token数
ll = num_tokens_from_string(msg_[0]["content"])
ll2 = num_tokens_from_string(msg_[-1]["content"])
# 如果系统消息占比超过80%,则截断系统消息
if ll / (ll + ll2) > 0.8:
m = msg_[0]["content"]
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
msg[0]["content"] = m
return max_length, msg
# 否则截断最后一条消息
m = msg_[-1]["content"]
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
msg[-1]["content"] = m
return max_length, msg
@ -111,18 +114,18 @@ def message_fit_in(msg, max_length=4000):
def kb_prompt(kbinfos, max_tokens):
"""
将检索到的知识库内容格式化为适合大语言模型的提示词
参数:
kbinfos (dict): 检索结果包含chunks等信息
max_tokens (int): 模型的最大token限制
流程:
1. 提取所有检索到的文档片段内容
2. 计算token数量确保不超过模型限制
3. 获取文档元数据
4. 按文档名组织文档片段
5. 格式化为结构化提示词
返回:
list: 格式化后的知识库内容列表每个元素是一个文档的相关信息
"""
@ -134,7 +137,7 @@ def kb_prompt(kbinfos, max_tokens):
chunks_num += 1
if max_tokens * 0.97 < used_token_count:
knowledges = knowledges[:i]
logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}")
logging.warning(f"Not all the retrieval into prompt: {i + 1}/{len(knowledges)}")
break
docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
@ -163,6 +166,10 @@ def citation_prompt():
- 以格式 '##i$$ ##j$$'插入引用其中 i, j 是所引用内容的 ID并用 '##' '$$' 包裹
- 在句子末尾插入引用每个句子最多 4 个引用
- 如果答案内容不来自检索到的文本块则不要插入引用
- 不要使用独立的文档 ID例如 `#ID#`)。
- 在任何情况下不得使用其他引用样式或格式例如 `~~i==``[i]``(i)`
- 引用必须始终使用 `##i$$` 格式。
- 任何未能遵守上述规则的情况包括但不限于格式错误使用禁止的样式或不支持的引用都将被视为错误应跳过为该句添加引用
--- 示例 ---
<SYSTEM>: 以下是知识库:
@ -210,10 +217,7 @@ def keyword_extraction(chat_mdl, content, topn=3):
### 文本内容
{content}
"""
msg = [
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
@ -240,10 +244,7 @@ Requirements:
{content}
"""
msg = [
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
@ -368,10 +369,7 @@ Output:
{content}
"""
msg = [
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple):
@ -384,8 +382,8 @@ Output:
return json_repair.loads(kwd)
except json_repair.JSONDecodeError:
try:
result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
result = kwd.replace(prompt[:-1], "").replace("user", "").replace("model", "").strip()
result = "{" + result.split("{")[1].split("}")[0] + "}"
return json_repair.loads(result)
except Exception as e:
logging.exception(f"JSON parsing error: {result} -> {e}")

View File

@ -27,7 +27,8 @@ import {
UploadProps,
} from 'antd';
import get from 'lodash/get';
import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react';
// import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react';
import { CircleStop, SendHorizontal } from 'lucide-react';
import {
ChangeEventHandler,
memo,
@ -339,9 +340,9 @@ const MessageInput = ({
return false;
}}
>
<Button type={'primary'} disabled={disabled}>
{/* <Button type={'primary'} disabled={disabled}>
<Paperclip className="size-4" />
</Button>
</Button> */}
</Upload>
)}
{sendLoading ? (

View File

@ -17,13 +17,11 @@
line-height: 1.2;
border-bottom: 2px solid #eaeaea;
padding-bottom: 0.25em;
font-size: 24px;
margin: 0.25em 0.25em;
}
section {
margin-top: 1em;
margin-bottom: 1em;
font-size: 20px;
p {
margin-left: 0;
}
@ -32,7 +30,6 @@
margin-top: 1em;
margin-bottom: 1em;
margin-left: 1em;
font-size: 20px;
}
ul,
ol {
@ -40,7 +37,6 @@
padding-left: 1.8em;
li {
margin-bottom: 0.25em;
font-size: 20px;
}
}
table {
@ -60,7 +56,6 @@
border: none;
padding: 12px;
text-align: left;
font-size: 20px;
border-bottom: 1px solid #ddd;
}

View File

@ -24,9 +24,10 @@ import styles from './index.less';
interface IProps {
controller: AbortController;
fontSize: number;
}
const ChatContainer = ({ controller }: IProps) => {
const ChatContainer = ({ controller, fontSize = 16 }: IProps) => {
const { conversationId } = useGetChatSearchParams();
const { data: conversation } = useFetchNextConversation();
@ -55,7 +56,12 @@ const ChatContainer = ({ controller }: IProps) => {
return (
<>
<Flex flex={1} className={styles.chatContainer} vertical>
<Flex flex={1} vertical className={styles.messageContainer}>
<Flex
flex={1}
vertical
className={styles.messageContainer}
style={{ fontSize: `${fontSize}px` }}
>
<div>
<Spin spinning={loading}>
{derivedMessages?.map((message, i) => {

View File

@ -1,6 +1,10 @@
import { ReactComponent as ChatAppCube } from '@/assets/svg/chat-app-cube.svg';
import RenameModal from '@/components/rename-modal';
import { DeleteOutlined, EditOutlined } from '@ant-design/icons';
import {
DeleteOutlined,
EditOutlined,
SettingOutlined,
} from '@ant-design/icons';
import {
Avatar,
Button,
@ -12,12 +16,12 @@ import {
Space,
Spin,
Tag,
Tooltip,
// Tooltip,
Typography,
} from 'antd';
import { MenuItemProps } from 'antd/lib/menu/MenuItem';
import classNames from 'classnames';
import { useCallback, useState } from 'react';
import { useCallback, useEffect, useState } from 'react';
import ChatConfigurationModal from './chat-configuration-modal';
import ChatContainer from './chat-container';
import {
@ -43,9 +47,9 @@ import {
import { useTranslate } from '@/hooks/common-hooks';
import { useSetSelectedRecord } from '@/hooks/logic-hooks';
import { IDialog } from '@/interfaces/database/chat';
import { Modal, Slider } from 'antd';
import { PictureInPicture2 } from 'lucide-react';
import styles from './index.less';
const { Text } = Typography;
const Chat = () => {
@ -161,6 +165,17 @@ const Chat = () => {
addTemporaryConversation();
}, [addTemporaryConversation]);
const [fontSizeModalVisible, setFontSizeModalVisible] = useState(false);
const [fontSize, setFontSize] = useState(16); // 默认字体大小
// 从localStorage加载字体大小设置
useEffect(() => {
const savedFontSize = localStorage.getItem('chatFontSize');
if (savedFontSize) {
setFontSize(parseInt(savedFontSize));
}
}, []);
const buildAppItems = (dialog: IDialog) => {
const dialogId = dialog.id;
@ -286,6 +301,7 @@ const Chat = () => {
</Flex>
</Flex>
<Divider type={'vertical'} className={styles.divider}></Divider>
<Flex className={styles.chatTitleWrapper}>
<Flex flex={1} vertical>
<Flex
@ -297,15 +313,23 @@ const Chat = () => {
<b>{t('chat')}</b>
<Tag>{conversationList.length}</Tag>
</Space>
<Tooltip title={t('newChat')}>
<div>
<SvgIcon
name="plus-circle-fill"
width={20}
onClick={handleCreateTemporaryConversation}
></SvgIcon>
</div>
</Tooltip>
{/* <Tooltip title={t('newChat')}> */}
<div>
<SettingOutlined
style={{
marginRight: '8px',
fontSize: '20px',
cursor: 'pointer',
}}
onClick={() => setFontSizeModalVisible(true)}
/>
<SvgIcon
name="plus-circle-fill"
width={20}
onClick={handleCreateTemporaryConversation}
></SvgIcon>
</div>
{/* </Tooltip> */}
</Flex>
<Divider></Divider>
<Flex vertical gap={10} className={styles.chatTitleContent}>
@ -356,7 +380,10 @@ const Chat = () => {
</Flex>
</Flex>
<Divider type={'vertical'} className={styles.divider}></Divider>
<ChatContainer controller={controller}></ChatContainer>
<ChatContainer
controller={controller}
fontSize={fontSize}
></ChatContainer>
{dialogEditVisible && (
<ChatConfigurationModal
visible={dialogEditVisible}
@ -386,6 +413,27 @@ const Chat = () => {
isAgent={false}
></EmbedModal>
)}
{fontSizeModalVisible && (
<Modal
title={'设置字体大小'}
open={fontSizeModalVisible}
onCancel={() => setFontSizeModalVisible(false)}
footer={null}
>
<Flex vertical gap="middle" align="center">
{'当前字体大小'}: {fontSize}px
<Slider
min={12}
max={24}
step={1}
defaultValue={fontSize}
style={{ width: '80%' }}
onChange={(value) => setFontSize(value)}
/>
</Flex>
</Modal>
)}
</Flex>
);
};