diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index d29666a..5fb41a7 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -37,7 +37,8 @@
 from api.db.services.file_service import FileService
 from flask import jsonify, request, Response
 
-@manager.route('/chats/<chat_id>/sessions', methods=['POST'])  # noqa: F821
+
+@manager.route("/chats/<chat_id>/sessions", methods=["POST"])  # noqa: F821
 @token_required
 def create(tenant_id, chat_id):
     req = request.json
@@ -50,7 +51,7 @@ def create(tenant_id, chat_id):
         "dialog_id": req["dialog_id"],
         "name": req.get("name", "New session"),
         "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
-        "user_id": req.get("user_id", "")
+        "user_id": req.get("user_id", ""),
     }
     if not conv.get("name"):
         return get_error_data_result(message="`name` can not be empty.")
@@ -59,20 +60,20 @@ def create(tenant_id, chat_id):
     if not e:
         return get_error_data_result(message="Fail to create a session!")
     conv = conv.to_dict()
-    conv['messages'] = conv.pop("message")
+    conv["messages"] = conv.pop("message")
     conv["chat_id"] = conv.pop("dialog_id")
     del conv["reference"]
     return get_result(data=conv)
 
 
-@manager.route('/agents/<agent_id>/sessions', methods=['POST'])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["POST"])  # noqa: F821
 @token_required
 def create_agent_session(tenant_id, agent_id):
     req = request.json
     if not request.is_json:
         req = request.form
     files = request.files
-    user_id = request.args.get('user_id', '')
+    user_id = request.args.get("user_id", "")
 
     e, cvs = UserCanvasService.get_by_id(agent_id)
     if not e:
@@ -113,7 +114,7 @@ def create_agent_session(tenant_id, agent_id):
                 ele.pop("value")
         else:
             if req is not None and req.get(ele["key"]):
-                ele["value"] = req[ele['key']]
+                ele["value"] = req[ele["key"]]
             else:
                 if "value" in ele:
                     ele.pop("value")
@@ -121,20 +122,13 @@ def create_agent_session(tenant_id, agent_id):
     for ans in canvas.run(stream=False):
         pass
     cvs.dsl = json.loads(str(canvas))
-    conv = {
-        "id": get_uuid(),
-        "dialog_id": cvs.id,
-        "user_id": user_id,
-        "message": [{"role": "assistant", "content": canvas.get_prologue()}],
-        "source": "agent",
-        "dsl": cvs.dsl
-    }
+    conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
     API4ConversationService.save(**conv)
     conv["agent_id"] = conv.pop("dialog_id")
     return get_result(data=conv)
 
 
-@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"])  # noqa: F821
 @token_required
 def update(tenant_id, chat_id, session_id):
     req = request.json
@@ -156,14 +150,14 @@ def update(tenant_id, chat_id, session_id):
     return get_result()
 
 
-@manager.route('/chats/<chat_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/chats/<chat_id>/completions", methods=["POST"])  # noqa: F821
 @token_required
 def chat_completion(tenant_id, chat_id):
     req = request.json
     if not req:
         req = {"question": ""}
     if not req.get("session_id"):
-        req["question"]=""
+        req["question"] = ""
     if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
         return get_error_data_result(f"You don't own the chat {chat_id}")
     if req.get("session_id"):
@@ -185,7 +179,7 @@ def chat_completion(tenant_id, chat_id):
     return get_result(data=answer)
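Note: the chats_openai/<chat_id>/chat/completions endpoint below mirrors the OpenAI chat-completions contract, so a stock OpenAI client can consume it. A minimal usage sketch; the host, port, URL prefix, and API key are illustrative assumptions, not values taken from this diff:

    from openai import OpenAI  # pip install openai

    client = OpenAI(
        api_key="YOUR_RAGFLOW_API_KEY",  # hypothetical key
        base_url="http://localhost:9380/api/v1/chats_openai/CHAT_ID",  # host/port/prefix assumed
    )
    stream = client.chat.completions.create(
        model="model",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    for chunk in stream:
        # intermediate chunks carry an incremental delta; the final chunk sets
        # finish_reason="stop", nulls the content, and reports token usage
        print(chunk.choices[0].delta.content or "", end="")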
 
 
-@manager.route('chats_openai/<chat_id>/chat/completions', methods=['POST'])  # noqa: F821
+@manager.route("chats_openai/<chat_id>/chat/completions", methods=["POST"])  # noqa: F821
 @validate_request("model", "messages")  # noqa: F821
 @token_required
 def chat_completion_openai_like(tenant_id, chat_id):
@@ -259,39 +253,23 @@ def chat_completion_openai_like(tenant_id, chat_id):
         # The choices field on the last chunk will always be an empty array [].
         def streamed_response_generator(chat_id, dia, msg):
             token_used = 0
-            should_split_index = 0
+            answer_cache = ""
             response = {
                 "id": f"chatcmpl-{chat_id}",
-                "choices": [
-                    {
-                        "delta": {
-                            "content": "",
-                            "role": "assistant",
-                            "function_call": None,
-                            "tool_calls": None
-                        },
-                        "finish_reason": None,
-                        "index": 0,
-                        "logprobs": None
-                    }
-                ],
+                "choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None}, "finish_reason": None, "index": 0, "logprobs": None}],
                 "created": int(time.time()),
                 "model": "model",
                 "object": "chat.completion.chunk",
                 "system_fingerprint": "",
-                "usage": None
+                "usage": None,
             }
 
             try:
                 for ans in chat(dia, msg, True):
                     answer = ans["answer"]
-                    incremental = answer[should_split_index:]
+                    incremental = answer.replace(answer_cache, "", 1)
+                    answer_cache = answer.rstrip("</think>")
                     token_used += len(incremental)
-                    if incremental.endswith("</think>"):
-                        response_data_len = len(incremental.rstrip("</think>"))
-                    else:
-                        response_data_len = len(incremental)
-                    should_split_index += response_data_len
                     response["choices"][0]["delta"]["content"] = incremental
                     yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
             except Exception as e:
@@ -301,15 +279,10 @@ def chat_completion_openai_like(tenant_id, chat_id):
             # The last chunk
             response["choices"][0]["delta"]["content"] = None
             response["choices"][0]["finish_reason"] = "stop"
-            response["usage"] = {
-                "prompt_tokens": len(prompt),
-                "completion_tokens": token_used,
-                "total_tokens": len(prompt) + token_used
-            }
+            response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
             yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
             yield "data:[DONE]\n\n"
 
-
         resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream")
         resp.headers.add_header("Cache-control", "no-cache")
         resp.headers.add_header("Connection", "keep-alive")
@@ -324,7 +297,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
                 break
             content = answer["answer"]
 
-        response = {
+        response = {
             "id": f"chatcmpl-{chat_id}",
             "object": "chat.completion",
             "created": int(time.time()),
@@ -336,25 +309,15 @@ def chat_completion_openai_like(tenant_id, chat_id):
                 "completion_tokens_details": {
                     "reasoning_tokens": context_token_used,
                     "accepted_prediction_tokens": len(content),
-                    "rejected_prediction_tokens": 0  # 0 for simplicity
-                }
+                    "rejected_prediction_tokens": 0,  # 0 for simplicity
+                },
             },
-            "choices": [
-                {
-                    "message": {
-                        "role": "assistant",
-                        "content": content
-                    },
-                    "logprobs": None,
-                    "finish_reason": "stop",
-                    "index": 0
-                }
-            ]
+            "choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
         }
         return jsonify(response)
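The rewrite above drops the index bookkeeping (should_split_index) in favor of caching the previously streamed answer and subtracting it from each new cumulative snapshot. A standalone sketch of just that bookkeeping, runnable on a toy stream:

    def deltas(snapshots):
        # each snapshot is the full answer so far; emit only the unseen part,
        # exactly as streamed_response_generator now does
        answer_cache = ""
        for answer in snapshots:
            incremental = answer.replace(answer_cache, "", 1)
            # caveat carried over from the patch: str.rstrip takes a set of
            # characters, not a suffix, so this strips any trailing run of
            # the characters in "</think>"
            answer_cache = answer.rstrip("</think>")
            yield incremental

    print(list(deltas(["He", "Hello", "Hello world"])))  # ['He', 'llo', ' world']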
 
 
-@manager.route('/agents/<agent_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/agents/<agent_id>/completions", methods=["POST"])  # noqa: F821
 @token_required
 def agent_completions(tenant_id, agent_id):
     req = request.json
@@ -365,8 +328,8 @@ def agent_completions(tenant_id, agent_id):
     dsl = cvs[0].dsl
     if not isinstance(dsl, str):
         dsl = json.dumps(dsl)
-    #canvas = Canvas(dsl, tenant_id)
-    #if canvas.get_preset_param():
+    # canvas = Canvas(dsl, tenant_id)
+    # if canvas.get_preset_param():
     #     req["question"] = ""
     conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
     if not conv:
@@ -380,9 +343,7 @@ def agent_completions(tenant_id, agent_id):
             states = {field: current_dsl.get(field, []) for field in state_fields}
             current_dsl.update(new_dsl)
             current_dsl.update(states)
-            API4ConversationService.update_by_id(req["session_id"], {
-                "dsl": current_dsl
-            })
+            API4ConversationService.update_by_id(req["session_id"], {"dsl": current_dsl})
     else:
         req["question"] = ""
     if req.get("stream", True):
@@ -399,7 +360,7 @@ def agent_completions(tenant_id, agent_id):
         return get_error_data_result(str(e))
 
 
-@manager.route('/chats/<chat_id>/sessions', methods=['GET'])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions", methods=["GET"])  # noqa: F821
 @token_required
 def list_session(tenant_id, chat_id):
     if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
@@ -418,7 +379,7 @@ def list_session(tenant_id, chat_id):
     if not convs:
         return get_result(data=[])
     for conv in convs:
-        conv['messages'] = conv.pop("message")
+        conv["messages"] = conv.pop("message")
         infos = conv["messages"]
         for info in infos:
             if "prompt" in info:
@@ -452,7 +413,7 @@ def list_session(tenant_id, chat_id):
     return get_result(data=convs)
 
 
-@manager.route('/agents/<agent_id>/sessions', methods=['GET'])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["GET"])  # noqa: F821
 @token_required
 def list_agent_session(tenant_id, agent_id):
     if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
@@ -468,12 +429,11 @@ def list_agent_session(tenant_id, agent_id):
         desc = True
     # dsl defaults to True in all cases except for False and false
     include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
-    convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
-                                             user_id, include_dsl)
+    convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
     if not convs:
         return get_result(data=[])
     for conv in convs:
-        conv['messages'] = conv.pop("message")
+        conv["messages"] = conv.pop("message")
         infos = conv["messages"]
         for info in infos:
             if "prompt" in info:
@@ -506,7 +466,7 @@ def list_agent_session(tenant_id, agent_id):
     return get_result(data=convs)
 
 
-@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"])  # noqa: F821
 @token_required
 def delete(tenant_id, chat_id):
     if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
@@ -532,14 +492,14 @@ def delete(tenant_id, chat_id):
     return get_result()
 
 
-@manager.route('/agents/<agent_id>/sessions', methods=["DELETE"])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"])  # noqa: F821
 @token_required
 def delete_agent_session(tenant_id, agent_id):
     req = request.json
     cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
     if not cvs:
         return get_error_data_result(f"You don't own the agent {agent_id}")
-    
+
     convs = API4ConversationService.query(dialog_id=agent_id)
     if not convs:
         return get_error_data_result(f"Agent {agent_id} has no sessions")
@@ -555,16 +515,16 @@ def delete_agent_session(tenant_id, agent_id):
             conv_list.append(conv.id)
     else:
         conv_list = ids
-    
+
     for session_id in conv_list:
         conv = API4ConversationService.query(id=session_id, dialog_id=agent_id)
         if not conv:
             return get_error_data_result(f"The agent doesn't own the session ${session_id}")
         API4ConversationService.delete_by_id(session_id)
     return get_result()
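For reference, both DELETE handlers accept an optional ids list in the JSON body and fall back to collecting every session they find when it is absent. A hedged sketch of the call with the requests library; the host, port, URL prefix, and key are assumptions:

    import requests

    resp = requests.delete(
        "http://localhost:9380/api/v1/agents/AGENT_ID/sessions",  # host/port/prefix assumed
        headers={"Authorization": "Bearer YOUR_API_KEY"},  # bearer auth per @token_required
        json={"ids": ["SESSION_ID_1", "SESSION_ID_2"]},  # omit "ids" to target all sessions
    )
    print(resp.json())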
-
-@manager.route('/sessions/ask', methods=['POST'])  # noqa: F821
+
+@manager.route("/sessions/ask", methods=["POST"])  # noqa: F821
 @token_required
 def ask_about(tenant_id):
     req = request.json
@@ -590,9 +550,7 @@ def ask_about(tenant_id):
             for ans in ask(req["question"], req["kb_ids"], uid):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
-            yield "data:" + json.dumps({"code": 500, "message": str(e),
-                                        "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
-                                       ensure_ascii=False) + "\n\n"
+            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
         yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
 
     resp = Response(stream(), mimetype="text/event-stream")
@@ -603,7 +561,7 @@ def ask_about(tenant_id):
     return resp
 
 
-@manager.route('/sessions/related_questions', methods=['POST'])  # noqa: F821
+@manager.route("/sessions/related_questions", methods=["POST"])  # noqa: F821
 @token_required
 def related_questions(tenant_id):
     req = request.json
@@ -635,18 +593,27 @@ def related_questions(tenant_id):
 Reason:
   - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
 
 """
-    ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
+    ans = chat_mdl.chat(
+        prompt,
+        [
+            {
+                "role": "user",
+                "content": f"""
 Keywords: {question}
 Related search terms:
-    """}], {"temperature": 0.9})
+    """,
+            }
+        ],
+        {"temperature": 0.9},
+    )
     return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
 
 
-@manager.route('/chatbots/<dialog_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"])  # noqa: F821
 def chatbot_completions(dialog_id):
     req = request.json
 
-    token = request.headers.get('Authorization').split()
+    token = request.headers.get("Authorization").split()
     if len(token) != 2:
         return get_error_data_result(message='Authorization is not valid!"')
     token = token[1]
@@ -669,11 +636,11 @@ def chatbot_completions(dialog_id):
     return get_result(data=answer)
 
 
-@manager.route('/agentbots/<agent_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/agentbots/<agent_id>/completions", methods=["POST"])  # noqa: F821
 def agent_bot_completions(agent_id):
     req = request.json
 
-    token = request.headers.get('Authorization').split()
+    token = request.headers.get("Authorization").split()
     if len(token) != 2:
         return get_error_data_result(message='Authorization is not valid!"')
     token = token[1]
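One behavior of related_questions above worth noting: the reply is filtered with the strict pattern ^[0-9]\. , so only lines numbered with a single digit survive the post-processing. A standalone illustration of that filtering:

    import re

    ans = "Sure:\n1. cheap mobile phones\n2. mobile phone reviews\nnot a list line"
    terms = [re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]
    print(terms)  # ['cheap mobile phones', 'mobile phone reviews']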
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index 470f24a..3f665da 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -47,20 +47,14 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
-    def get_list(cls, kb_id, page_number, items_per_page,
-                 orderby, desc, keywords, id, name):
+    def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name):
         docs = cls.model.select().where(cls.model.kb_id == kb_id)
         if id:
-            docs = docs.where(
-                cls.model.id == id)
+            docs = docs.where(cls.model.id == id)
         if name:
-            docs = docs.where(
-                cls.model.name == name
-            )
+            docs = docs.where(cls.model.name == name)
         if keywords:
-            docs = docs.where(
-                fn.LOWER(cls.model.name).contains(keywords.lower())
-            )
+            docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
         if desc:
             docs = docs.order_by(cls.model.getter_by(orderby).desc())
         else:
@@ -72,13 +66,9 @@ class DocumentService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
-                     orderby, desc, keywords):
+    def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords):
         if keywords:
-            docs = cls.model.select().where(
-                (cls.model.kb_id == kb_id),
-                (fn.LOWER(cls.model.name).contains(keywords.lower()))
-            )
+            docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
         else:
             docs = cls.model.select().where(cls.model.kb_id == kb_id)
         count = docs.count()
@@ -97,8 +87,7 @@ class DocumentService(CommonService):
         if not cls.save(**doc):
             raise RuntimeError("Database error (Document)!")
         e, kb = KnowledgebaseService.get_by_id(doc["kb_id"])
-        if not KnowledgebaseService.update_by_id(
-                kb.id, {"doc_num": kb.doc_num + 1}):
+        if not KnowledgebaseService.update_by_id(kb.id, {"doc_num": kb.doc_num + 1}):
             raise RuntimeError("Database error (Knowledgebase)!")
         return Document(**doc)
 
@@ -108,14 +97,16 @@ class DocumentService(CommonService):
         cls.clear_chunk_num(doc.id)
         try:
             settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
-            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
-                                         {"remove": {"source_id": doc.id}},
-                                         search.index_name(tenant_id), doc.kb_id)
-            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
-                                         {"removed_kwd": "Y"},
-                                         search.index_name(tenant_id), doc.kb_id)
-            settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
-                                         search.index_name(tenant_id), doc.kb_id)
+            settings.docStoreConn.update(
+                {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
+                {"remove": {"source_id": doc.id}},
+                search.index_name(tenant_id),
+                doc.kb_id,
+            )
+            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id)
+            settings.docStoreConn.delete(
+                {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}}, search.index_name(tenant_id), doc.kb_id
+            )
         except Exception:
             pass
         return cls.delete_by_id(doc.id)
@@ -136,67 +127,54 @@ class DocumentService(CommonService):
             Tenant.embd_id,
             Tenant.img2txt_id,
             Tenant.asr_id,
-            cls.model.update_time]
-        docs = cls.model.select(*fields) \
-            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
-            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
+            cls.model.update_time,
+        ]
+        docs = (
+            cls.model.select(*fields)
+            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
+            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
             .where(
-                cls.model.status == StatusEnum.VALID.value,
-                ~(cls.model.type == FileType.VIRTUAL.value),
-                cls.model.progress == 0,
-                cls.model.update_time >= current_timestamp() - 1000 * 600,
-                cls.model.run == TaskStatus.RUNNING.value) \
+                cls.model.status == StatusEnum.VALID.value,
+                ~(cls.model.type == FileType.VIRTUAL.value),
+                cls.model.progress == 0,
+                cls.model.update_time >= current_timestamp() - 1000 * 600,
+                cls.model.run == TaskStatus.RUNNING.value,
+            )
             .order_by(cls.model.update_time.asc())
+        )
         return list(docs.dicts())
 
     @classmethod
     @DB.connection_context()
     def get_unfinished_docs(cls):
-        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
-                  cls.model.run, cls.model.parser_id]
-        docs = cls.model.select(*fields) \
-            .where(
-                cls.model.status == StatusEnum.VALID.value,
-                ~(cls.model.type == FileType.VIRTUAL.value),
-                cls.model.progress < 1,
-                cls.model.progress > 0)
+        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id]
+        docs = cls.model.select(*fields).where(cls.model.status == StatusEnum.VALID.value, ~(cls.model.type == FileType.VIRTUAL.value), cls.model.progress < 1, cls.model.progress > 0)
         return list(docs.dicts())
 
     @classmethod
     @DB.connection_context()
     def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
-        num = cls.model.update(token_num=cls.model.token_num + token_num,
-                               chunk_num=cls.model.chunk_num + chunk_num,
-                               process_duation=cls.model.process_duation + duation).where(
-            cls.model.id == doc_id).execute()
+        num = (
+            cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duation=cls.model.process_duation + duation)
+            .where(cls.model.id == doc_id)
+            .execute()
+        )
         if num == 0:
-            raise LookupError(
-                "Document not found which is supposed to be there")
-        num = Knowledgebase.update(
-            token_num=Knowledgebase.token_num +
-            token_num,
-            chunk_num=Knowledgebase.chunk_num +
-            chunk_num).where(
-            Knowledgebase.id == kb_id).execute()
+            raise LookupError("Document not found which is supposed to be there")
+        num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute()
         return num
 
     @classmethod
     @DB.connection_context()
     def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
-        num = cls.model.update(token_num=cls.model.token_num - token_num,
-                               chunk_num=cls.model.chunk_num - chunk_num,
-                               process_duation=cls.model.process_duation + duation).where(
-            cls.model.id == doc_id).execute()
+        num = (
+            cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duation=cls.model.process_duation + duation)
+            .where(cls.model.id == doc_id)
+            .execute()
+        )
         if num == 0:
-            raise LookupError(
-                "Document not found which is supposed to be there")
-        num = Knowledgebase.update(
-            token_num=Knowledgebase.token_num -
-            token_num,
-            chunk_num=Knowledgebase.chunk_num -
-            chunk_num
-        ).where(
-            Knowledgebase.id == kb_id).execute()
+            raise LookupError("Document not found which is supposed to be there")
+        num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute()
         return num
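increment_chunk_num and decrement_chunk_num push the arithmetic into the UPDATE statement itself, so counters are adjusted atomically in SQL instead of being read, modified, and written back from Python. A generic peewee sketch of the pattern; the model and field names are illustrative, not from this codebase:

    import peewee as pw

    db = pw.SqliteDatabase(":memory:")

    class Counter(pw.Model):
        name = pw.CharField()
        chunk_num = pw.IntegerField(default=0)

        class Meta:
            database = db

    db.create_tables([Counter])
    Counter.create(name="kb1")

    # compiles to: UPDATE counter SET chunk_num = chunk_num + 5 WHERE name = 'kb1'
    n = Counter.update(chunk_num=Counter.chunk_num + 5).where(Counter.name == "kb1").execute()
    print(n)  # rows updated; the services above raise LookupError when this is 0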
 
     @classmethod
     @DB.connection_context()
@@ -205,24 +183,17 @@ def clear_chunk_num(cls, doc_id):
         doc = cls.model.get_by_id(doc_id)
         assert doc, "Can't fine document in database."
 
-        num = Knowledgebase.update(
-            token_num=Knowledgebase.token_num -
-            doc.token_num,
-            chunk_num=Knowledgebase.chunk_num -
-            doc.chunk_num,
-            doc_num=Knowledgebase.doc_num - 1
-        ).where(
-            Knowledgebase.id == doc.kb_id).execute()
+        num = (
+            Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1)
+            .where(Knowledgebase.id == doc.kb_id)
+            .execute()
+        )
         return num
 
     @classmethod
     @DB.connection_context()
     def get_tenant_id(cls, doc_id):
-        docs = cls.model.select(
-            Knowledgebase.tenant_id).join(
-            Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)).where(
-            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
+        docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
         docs = docs.dicts()
         if not docs:
             return
@@ -240,11 +211,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_tenant_id_by_name(cls, name):
-        docs = cls.model.select(
-            Knowledgebase.tenant_id).join(
-            Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)).where(
-            cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
+        docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
         docs = docs.dicts()
         if not docs:
             return
@@ -253,12 +220,13 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def accessible(cls, doc_id, user_id):
-        docs = cls.model.select(
-            cls.model.id).join(
-            Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)
-        ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
-               ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
+        docs = (
+            cls.model.select(cls.model.id)
+            .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
+            .join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id))
+            .where(cls.model.id == doc_id, UserTenant.user_id == user_id)
+            .paginate(0, 1)
+        )
         docs = docs.dicts()
         if not docs:
             return False
@@ -267,11 +235,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def accessible4deletion(cls, doc_id, user_id):
-        docs = cls.model.select(
-            cls.model.id).join(
-            Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)
-        ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
+        docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
         docs = docs.dicts()
         if not docs:
             return False
@@ -280,11 +244,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_embd_id(cls, doc_id):
-        docs = cls.model.select(
-            Knowledgebase.embd_id).join(
-            Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)).where(
-            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
+        docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
         docs = docs.dicts()
         if not docs:
             return
@@ -306,9 +266,9 @@ class DocumentService(CommonService):
             Tenant.asr_id,
             Tenant.llm_id,
         )
-            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
-            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
-            .where(cls.model.id == doc_id)
+            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
+            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
+            .where(cls.model.id == doc_id)
         )
         configs = configs.dicts()
         if not configs:
@@ -319,8 +279,7 @@ class DocumentService(CommonService):
     @DB.connection_context()
     def get_doc_id_by_doc_name(cls, doc_name):
         fields = [cls.model.id]
-        doc_id = cls.model.select(*fields) \
-            .where(cls.model.name == doc_name)
+        doc_id = cls.model.select(*fields).where(cls.model.name == doc_name)
         doc_id = doc_id.dicts()
         if not doc_id:
             return
@@ -330,8 +289,7 @@ class DocumentService(CommonService):
     @DB.connection_context()
     def get_thumbnails(cls, docids):
         fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
-        return list(cls.model.select(
-            *fields).where(cls.model.id.in_(docids)).dicts())
+        return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
 
     @classmethod
     @DB.connection_context()
@@ -359,19 +317,14 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_doc_count(cls, tenant_id):
-        docs = cls.model.select(cls.model.id).join(Knowledgebase,
-                                                   on=(Knowledgebase.id == cls.model.kb_id)).where(
-            Knowledgebase.tenant_id == tenant_id)
+        docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id)
         return len(docs)
 
     @classmethod
     @DB.connection_context()
     def begin2parse(cls, docid):
-        cls.update_by_id(
-            docid, {"progress": random.random() * 1 / 100.,
-                    "progress_msg": "Task is queued...",
-                    "process_begin_at": get_format_time()
-                    })
+        cls.update_by_id(docid, {"progress": random.random() * 1 / 100.0, "progress_msg": "Task is queued...", "process_begin_at": get_format_time()})
+
     @classmethod
     @DB.connection_context()
     def update_meta_fields(cls, doc_id, meta_fields):
@@ -420,11 +373,7 @@ class DocumentService(CommonService):
                     status = TaskStatus.DONE.value
 
             msg = "\n".join(sorted(msg))
-            info = {
-                "process_duation": datetime.timestamp(
-                    datetime.now()) -
-                d["process_begin_at"].timestamp(),
-                "run": status}
+            info = {"process_duation": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(), "run": status}
             if prg != 0:
                 info["progress"] = prg
             if msg:
@@ -437,8 +386,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_kb_doc_count(cls, kb_id):
-        return len(cls.model.select(cls.model.id).where(
-            cls.model.kb_id == kb_id).dicts())
+        return len(cls.model.select(cls.model.id).where(cls.model.kb_id == kb_id).dicts())
 
     @classmethod
     @DB.connection_context()
@@ -459,14 +407,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
 
     def new_task():
         nonlocal doc
-        return {
-            "id": get_uuid(),
-            "doc_id": doc["id"],
-            "from_page": 100000000,
-            "to_page": 100000000,
-            "task_type": ty,
-            "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
-        }
+        return {"id": get_uuid(), "doc_id": doc["id"], "from_page": 100000000, "to_page": 100000000, "task_type": ty, "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty}
 
     task = new_task()
     for field in ["doc_id", "from_page", "to_page"]:
@@ -478,6 +419,25 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
 
 
 def doc_upload_and_parse(conversation_id, file_objs, user_id):
+    """
+    Upload and parse documents, storing their content in the knowledge base.
+
+    Args:
+        conversation_id: conversation ID
+        file_objs: list of file objects
+        user_id: user ID
+
+    Returns:
+        List of IDs of the documents that were processed successfully.
+
+    Processing flow:
+        1. Validate the conversation and its knowledge base
+        2. Initialize the embedding model
+        3. Upload the files to storage
+        4. Parse the file contents in multiple threads
+        5. Generate embedding vectors for the content
+        6. Write the results into the document store
+    """
     from rag.app import presentation, picture, naive, audio, email
     from api.db.services.dialog_service import DialogService
     from api.db.services.file_service import FileService
@@ -493,8 +453,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
     e, dia = DialogService.get_by_id(conv.dialog_id)
     if not dia.kb_ids:
-        raise LookupError("No knowledge base associated with this conversation. "
-                          "Please add a knowledge base before uploading documents")
+        raise LookupError("No knowledge base associated with this conversation. Please add a knowledge base before uploading documents")
     kb_id = dia.kb_ids[0]
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
@@ -508,12 +467,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
     def dummy(prog=None, msg=""):
         pass
 
-    FACTORY = {
-        ParserType.PRESENTATION.value: presentation,
-        ParserType.PICTURE.value: picture,
-        ParserType.AUDIO.value: audio,
-        ParserType.EMAIL.value: email
-    }
+    FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
     parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
     # 使用线程池执行解析任务
     exe = ThreadPoolExecutor(max_workers=12)
@@ -522,22 +476,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
     threads = []
     doc_nm = {}
     for d, blob in files:
         doc_nm[d["id"]] = d["name"]
     for d, blob in files:
-        kwargs = {
-            "callback": dummy,
-            "parser_config": parser_config,
-            "from_page": 0,
-            "to_page": 100000,
-            "tenant_id": kb.tenant_id,
-            "lang": kb.language
-        }
+        kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language}
         threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
 
     for (docinfo, _), th in zip(files, threads):
         docs = []
-        doc = {
-            "doc_id": docinfo["id"],
-            "kb_id": [kb.id]
-        }
+        doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]}
         for ck in th.result():
             d = deepcopy(doc)
             d.update(ck)
@@ -552,7 +496,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
             if isinstance(d["image"], bytes):
                 output_buffer = BytesIO(d["image"])
             else:
-                d["image"].save(output_buffer, format='JPEG')
+                d["image"].save(output_buffer, format="JPEG")
             STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
 
             d["img_id"] = "{}-{}".format(kb.id, d["id"])
@@ -569,9 +513,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
         nonlocal embd_mdl, chunk_counts, token_counts
         vects = []
         for i in range(0, len(cnts), batch_size):
-            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
+            vts, c = embd_mdl.encode(cnts[i : i + batch_size])
             vects.extend(vts.tolist())
-            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
+            chunk_counts[doc_id] += len(cnts[i : i + batch_size])
             token_counts[doc_id] += c
         return vects
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"), - "content_with_weight": mind_map, - "knowledge_graph_kwd": "mind_map" - }) + cks.append( + { + "id": get_uuid(), + "doc_id": doc_id, + "kb_id": [kb.id], + "docnm_kwd": doc_nm[doc_id], + "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])), + "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"), + "content_with_weight": mind_map, + "knowledge_graph_kwd": "mind_map", + } + ) except Exception as e: logging.exception("Mind map generation error") @@ -614,9 +561,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): if not settings.docStoreConn.indexExist(idxnm, kb_id): settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) try_create_idx = False - settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) + settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id) - DocumentService.increment_chunk_num( - doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) + DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) return [d["id"] for d, _ in files] diff --git a/management/server/routes/conversation/routes.py b/management/server/routes/conversation/routes.py index ac750f6..08f53ec 100644 --- a/management/server/routes/conversation/routes.py +++ b/management/server/routes/conversation/routes.py @@ -1,5 +1,5 @@ from flask import jsonify, request -from services.conversation.service import get_conversations_by_user_id, get_messages_by_conversation_id, get_conversation_detail +from services.conversation.service import get_conversations_by_user_id, get_messages_by_conversation_id from .. import conversation_bp @@ -44,20 +44,3 @@ def get_messages(conversation_id): except Exception as e: # 错误处理 return jsonify({"code": 500, "message": f"获取消息列表失败: {str(e)}"}), 500 - - -@conversation_bp.route("/", methods=["GET"]) -def get_conversation(conversation_id): - """获取特定对话的详细信息""" - try: - # 调用服务函数获取对话详情 - conversation = get_conversation_detail(conversation_id) - - if not conversation: - return jsonify({"code": 404, "message": "对话不存在"}), 404 - - # 返回符合前端期望格式的数据 - return jsonify({"code": 0, "data": conversation, "message": "获取对话详情成功"}) - except Exception as e: - # 错误处理 - return jsonify({"code": 500, "message": f"获取对话详情失败: {str(e)}"}), 500 diff --git a/management/server/services/conversation/service.py b/management/server/services/conversation/service.py index b1d886e..dcff44d 100644 --- a/management/server/services/conversation/service.py +++ b/management/server/services/conversation/service.py @@ -23,8 +23,6 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time" # 直接使用user_id作为tenant_id tenant_id = user_id - print(f"查询用户ID: {user_id}, 租户ID: {tenant_id}") - # 查询总记录数 count_sql = """ SELECT COUNT(*) as total @@ -34,7 +32,7 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time" cursor.execute(count_sql, (tenant_id,)) total = cursor.fetchone()["total"] - print(f"查询到总记录数: {total}") + # print(f"查询到总记录数: {total}") # 计算分页偏移量 offset = (page - 1) * size @@ -59,8 +57,8 @@ def get_conversations_by_user_id(user_id, page=1, size=20, sort_by="update_time" LIMIT %s OFFSET %s """ - print(f"执行查询: {query}") - print(f"参数: tenant_id={tenant_id}, size={size}, offset={offset}") + # print(f"执行查询: {query}") + # print(f"参数: tenant_id={tenant_id}, size={size}, offset={offset}") cursor.execute(query, (tenant_id, size, offset)) results = cursor.fetchall() @@ -200,68 
@@ -200,68 +198,3 @@ def get_messages_by_conversation_id(conversation_id, page=1, size=30):
         traceback.print_exc()
 
         return None, 0
-
-
-def get_conversation_detail(conversation_id):
-    """
-    获取特定对话的详细信息
-
-    参数:
-        conversation_id (str): 对话ID
-
-    返回:
-        dict: 对话详情
-    """
-    try:
-        conn = mysql.connector.connect(**DB_CONFIG)
-        cursor = conn.cursor(dictionary=True)
-
-        # 查询对话信息
-        query = """
-            SELECT c.*, d.name as dialog_name, d.icon as dialog_icon
-            FROM conversation c
-            LEFT JOIN dialog d ON c.dialog_id = d.id
-            WHERE c.id = %s
-        """
-        cursor.execute(query, (conversation_id,))
-        result = cursor.fetchone()
-
-        if not result:
-            print(f"未找到对话ID: {conversation_id}")
-            return None
-
-        # 格式化对话详情
-        conversation = {
-            "id": result["id"],
-            "name": result.get("name", ""),
-            "dialogId": result.get("dialog_id", ""),
-            "dialogName": result.get("dialog_name", ""),
-            "dialogIcon": result.get("dialog_icon", ""),
-            "createTime": result["create_date"].strftime("%Y-%m-%d %H:%M:%S") if result.get("create_date") else "",
-            "updateTime": result["update_date"].strftime("%Y-%m-%d %H:%M:%S") if result.get("update_date") else "",
-            "messages": result.get("message", []),
-        }
-
-        # 打印调试信息
-        print(f"获取到对话详情: ID={conversation_id}")
-        print(f"消息数量: {len(conversation['messages']) if conversation['messages'] else 0}")
-
-        # 关闭连接
-        cursor.close()
-        conn.close()
-
-        return conversation
-
-    except mysql.connector.Error as err:
-        print(f"数据库错误: {err}")
-        # 更详细的错误日志
-        import traceback
-
-        traceback.print_exc()
-        return None
-    except Exception as e:
-        print(f"未知错误: {e}")
-        import traceback
-
-        traceback.print_exc()
-        return None
diff --git a/management/web/src/pages/conversation/index.vue b/management/web/src/pages/conversation/index.vue
index c4bf074..f15d7f2 100644
--- a/management/web/src/pages/conversation/index.vue
+++ b/management/web/src/pages/conversation/index.vue
@@ -1,11 +1,8 @@