diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index 470f24a..3f665da 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -47,20 +47,14 @@ class DocumentService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def get_list(cls, kb_id, page_number, items_per_page,
-                 orderby, desc, keywords, id, name):
+    def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name):
         docs = cls.model.select().where(cls.model.kb_id == kb_id)
         if id:
-            docs = docs.where(
-                cls.model.id == id)
+            docs = docs.where(cls.model.id == id)
         if name:
-            docs = docs.where(
-                cls.model.name == name
-            )
+            docs = docs.where(cls.model.name == name)
         if keywords:
-            docs = docs.where(
-                fn.LOWER(cls.model.name).contains(keywords.lower())
-            )
+            docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
         if desc:
             docs = docs.order_by(cls.model.getter_by(orderby).desc())
         else:
@@ -72,13 +66,9 @@ class DocumentService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
-                     orderby, desc, keywords):
+    def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords):
         if keywords:
-            docs = cls.model.select().where(
-                (cls.model.kb_id == kb_id),
-                (fn.LOWER(cls.model.name).contains(keywords.lower()))
-            )
+            docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
         else:
             docs = cls.model.select().where(cls.model.kb_id == kb_id)
         count = docs.count()
@@ -97,8 +87,7 @@ class DocumentService(CommonService):
         if not cls.save(**doc):
             raise RuntimeError("Database error (Document)!")
         e, kb = KnowledgebaseService.get_by_id(doc["kb_id"])
-        if not KnowledgebaseService.update_by_id(
-                kb.id, {"doc_num": kb.doc_num + 1}):
+        if not KnowledgebaseService.update_by_id(kb.id, {"doc_num": kb.doc_num + 1}):
             raise RuntimeError("Database error (Knowledgebase)!")
         return Document(**doc)
 
@@ -108,14 +97,16 @@ class DocumentService(CommonService):
         cls.clear_chunk_num(doc.id)
         try:
             settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
-            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
-                                         {"remove": {"source_id": doc.id}},
-                                         search.index_name(tenant_id), doc.kb_id)
-            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
-                                         {"removed_kwd": "Y"},
-                                         search.index_name(tenant_id), doc.kb_id)
-            settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
-                                         search.index_name(tenant_id), doc.kb_id)
+            settings.docStoreConn.update(
+                {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
+                {"remove": {"source_id": doc.id}},
+                search.index_name(tenant_id),
+                doc.kb_id,
+            )
+            settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id)
+            settings.docStoreConn.delete(
+                {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}}, search.index_name(tenant_id), doc.kb_id
+            )
         except Exception:
             pass
         return cls.delete_by_id(doc.id)
@@ -136,67 +127,54 @@ class DocumentService(CommonService):
             Tenant.embd_id,
             Tenant.img2txt_id,
             Tenant.asr_id,
-            cls.model.update_time]
-        docs = cls.model.select(*fields) \
-            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
-            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
+            cls.model.update_time,
+        ]
+        docs = (
+            cls.model.select(*fields)
+            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
+            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
             .where(
-                cls.model.status == StatusEnum.VALID.value,
-                ~(cls.model.type == FileType.VIRTUAL.value),
-                cls.model.progress == 0,
-                cls.model.update_time >= current_timestamp() - 1000 * 600,
-                cls.model.run == TaskStatus.RUNNING.value) \
+                cls.model.status == StatusEnum.VALID.value,
+                ~(cls.model.type == FileType.VIRTUAL.value),
+                cls.model.progress == 0,
+                cls.model.update_time >= current_timestamp() - 1000 * 600,
+                cls.model.run == TaskStatus.RUNNING.value,
+            )
             .order_by(cls.model.update_time.asc())
+        )
         return list(docs.dicts())
 
     @classmethod
     @DB.connection_context()
     def get_unfinished_docs(cls):
-        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
-                  cls.model.run, cls.model.parser_id]
-        docs = cls.model.select(*fields) \
-            .where(
-                cls.model.status == StatusEnum.VALID.value,
-                ~(cls.model.type == FileType.VIRTUAL.value),
-                cls.model.progress < 1,
-                cls.model.progress > 0)
+        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id]
+        docs = cls.model.select(*fields).where(cls.model.status == StatusEnum.VALID.value, ~(cls.model.type == FileType.VIRTUAL.value), cls.model.progress < 1, cls.model.progress > 0)
         return list(docs.dicts())
 
     @classmethod
     @DB.connection_context()
     def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
-        num = cls.model.update(token_num=cls.model.token_num + token_num,
-                               chunk_num=cls.model.chunk_num + chunk_num,
-                               process_duation=cls.model.process_duation + duation).where(
-            cls.model.id == doc_id).execute()
+        num = (
+            cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duation=cls.model.process_duation + duation)
+            .where(cls.model.id == doc_id)
+            .execute()
+        )
         if num == 0:
-            raise LookupError(
-                "Document not found which is supposed to be there")
-        num = Knowledgebase.update(
-            token_num=Knowledgebase.token_num +
-            token_num,
-            chunk_num=Knowledgebase.chunk_num +
-            chunk_num).where(
-            Knowledgebase.id == kb_id).execute()
+            raise LookupError("Document not found which is supposed to be there")
+        num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute()
         return num
 
     @classmethod
     @DB.connection_context()
     def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
-        num = cls.model.update(token_num=cls.model.token_num - token_num,
-                               chunk_num=cls.model.chunk_num - chunk_num,
-                               process_duation=cls.model.process_duation + duation).where(
-            cls.model.id == doc_id).execute()
+        num = (
+            cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duation=cls.model.process_duation + duation)
+            .where(cls.model.id == doc_id)
+            .execute()
+        )
         if num == 0:
-            raise LookupError(
-                "Document not found which is supposed to be there")
-        num = Knowledgebase.update(
-            token_num=Knowledgebase.token_num -
-            token_num,
-            chunk_num=Knowledgebase.chunk_num -
-            chunk_num
-        ).where(
-            Knowledgebase.id == kb_id).execute()
+            raise LookupError("Document not found which is supposed to be there")
LookupError("Document not found which is supposed to be there") + num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute() return num @classmethod @@ -205,24 +183,17 @@ class DocumentService(CommonService): doc = cls.model.get_by_id(doc_id) assert doc, "Can't fine document in database." - num = Knowledgebase.update( - token_num=Knowledgebase.token_num - - doc.token_num, - chunk_num=Knowledgebase.chunk_num - - doc.chunk_num, - doc_num=Knowledgebase.doc_num - 1 - ).where( - Knowledgebase.id == doc.kb_id).execute() + num = ( + Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1) + .where(Knowledgebase.id == doc.kb_id) + .execute() + ) return num @classmethod @DB.connection_context() def get_tenant_id(cls, doc_id): - docs = cls.model.select( - Knowledgebase.tenant_id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( - cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) + docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: return @@ -240,11 +211,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_tenant_id_by_name(cls, name): - docs = cls.model.select( - Knowledgebase.tenant_id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( - cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value) + docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: return @@ -253,12 +220,13 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def accessible(cls, doc_id, user_id): - docs = cls.model.select( - cls.model.id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id) - ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1) + docs = ( + cls.model.select(cls.model.id) + .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)) + .join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)) + .where(cls.model.id == doc_id, UserTenant.user_id == user_id) + .paginate(0, 1) + ) docs = docs.dicts() if not docs: return False @@ -267,11 +235,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def accessible4deletion(cls, doc_id, user_id): - docs = cls.model.select( - cls.model.id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id) - ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1) + docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1) docs = docs.dicts() if not docs: return False @@ -280,11 +244,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_embd_id(cls, doc_id): - docs = cls.model.select( - Knowledgebase.embd_id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( - cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) + 
+        docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
         docs = docs.dicts()
         if not docs:
             return
@@ -306,9 +266,9 @@ class DocumentService(CommonService):
                 Tenant.asr_id,
                 Tenant.llm_id,
             )
-                .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
-                .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
-                .where(cls.model.id == doc_id)
+            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
+            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
+            .where(cls.model.id == doc_id)
         )
         configs = configs.dicts()
         if not configs:
@@ -319,8 +279,7 @@ class DocumentService(CommonService):
     @DB.connection_context()
     def get_doc_id_by_doc_name(cls, doc_name):
         fields = [cls.model.id]
-        doc_id = cls.model.select(*fields) \
-            .where(cls.model.name == doc_name)
+        doc_id = cls.model.select(*fields).where(cls.model.name == doc_name)
         doc_id = doc_id.dicts()
         if not doc_id:
             return
@@ -330,8 +289,7 @@ class DocumentService(CommonService):
     @DB.connection_context()
     def get_thumbnails(cls, docids):
         fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
-        return list(cls.model.select(
-            *fields).where(cls.model.id.in_(docids)).dicts())
+        return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
 
     @classmethod
     @DB.connection_context()
@@ -359,19 +317,14 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_doc_count(cls, tenant_id):
-        docs = cls.model.select(cls.model.id).join(Knowledgebase,
-                                                   on=(Knowledgebase.id == cls.model.kb_id)).where(
-            Knowledgebase.tenant_id == tenant_id)
+        docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id)
         return len(docs)
 
     @classmethod
     @DB.connection_context()
     def begin2parse(cls, docid):
-        cls.update_by_id(
-            docid, {"progress": random.random() * 1 / 100.,
-                    "progress_msg": "Task is queued...",
-                    "process_begin_at": get_format_time()
-                    })
+        cls.update_by_id(docid, {"progress": random.random() * 1 / 100.0, "progress_msg": "Task is queued...", "process_begin_at": get_format_time()})
+
     @classmethod
     @DB.connection_context()
     def update_meta_fields(cls, doc_id, meta_fields):
@@ -420,11 +373,7 @@ class DocumentService(CommonService):
                 status = TaskStatus.DONE.value
 
             msg = "\n".join(sorted(msg))
-            info = {
-                "process_duation": datetime.timestamp(
-                    datetime.now()) -
-                d["process_begin_at"].timestamp(),
-                "run": status}
+            info = {"process_duation": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(), "run": status}
             if prg != 0:
                 info["progress"] = prg
             if msg:
@@ -437,8 +386,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_kb_doc_count(cls, kb_id):
-        return len(cls.model.select(cls.model.id).where(
-            cls.model.kb_id == kb_id).dicts())
+        return len(cls.model.select(cls.model.id).where(cls.model.kb_id == kb_id).dicts())
 
     @classmethod
     @DB.connection_context()
@@ -459,14 +407,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
 
     def new_task():
         nonlocal doc
-        return {
-            "id": get_uuid(),
-            "doc_id": doc["id"],
-            "from_page": 100000000,
-            "to_page": 100000000,
-            "task_type": ty,
-            "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
-        }
+        return {"id": get_uuid(), "doc_id": doc["id"], "from_page": 100000000, "to_page": 100000000, "task_type": ty, "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty}
 
     task = new_task()
for field in ["doc_id", "from_page", "to_page"]: @@ -478,6 +419,25 @@ def queue_raptor_o_graphrag_tasks(doc, ty): def doc_upload_and_parse(conversation_id, file_objs, user_id): + """ + 上传并解析文档,将内容存入知识库 + + 参数: + conversation_id: 会话ID + file_objs: 文件对象列表 + user_id: 用户ID + + 返回: + 处理成功的文档ID列表 + + 处理流程: + 1. 验证会话和知识库 + 2. 初始化嵌入模型 + 3. 上传文件到存储 + 4. 多线程解析文件内容 + 5. 生成内容嵌入向量 + 6. 存入文档存储系统 + """ from rag.app import presentation, picture, naive, audio, email from api.db.services.dialog_service import DialogService from api.db.services.file_service import FileService @@ -493,8 +453,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): e, dia = DialogService.get_by_id(conv.dialog_id) if not dia.kb_ids: - raise LookupError("No knowledge base associated with this conversation. " - "Please add a knowledge base before uploading documents") + raise LookupError("No knowledge base associated with this conversation. Please add a knowledge base before uploading documents") kb_id = dia.kb_ids[0] e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: @@ -508,12 +467,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): def dummy(prog=None, msg=""): pass - FACTORY = { - ParserType.PRESENTATION.value: presentation, - ParserType.PICTURE.value: picture, - ParserType.AUDIO.value: audio, - ParserType.EMAIL.value: email - } + FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email} parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"} # 使用线程池执行解析任务 exe = ThreadPoolExecutor(max_workers=12) @@ -522,22 +476,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): for d, blob in files: doc_nm[d["id"]] = d["name"] for d, blob in files: - kwargs = { - "callback": dummy, - "parser_config": parser_config, - "from_page": 0, - "to_page": 100000, - "tenant_id": kb.tenant_id, - "lang": kb.language - } + kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language} threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs)) for (docinfo, _), th in zip(files, threads): docs = [] - doc = { - "doc_id": docinfo["id"], - "kb_id": [kb.id] - } + doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]} for ck in th.result(): d = deepcopy(doc) d.update(ck) @@ -552,7 +496,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): if isinstance(d["image"], bytes): output_buffer = BytesIO(d["image"]) else: - d["image"].save(output_buffer, format='JPEG') + d["image"].save(output_buffer, format="JPEG") STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue()) d["img_id"] = "{}-{}".format(kb.id, d["id"]) @@ -569,9 +513,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): nonlocal embd_mdl, chunk_counts, token_counts vects = [] for i in range(0, len(cnts), batch_size): - vts, c = embd_mdl.encode(cnts[i: i + batch_size]) + vts, c = embd_mdl.encode(cnts[i : i + batch_size]) vects.extend(vts.tolist()) - chunk_counts[doc_id] += len(cnts[i:i + batch_size]) + chunk_counts[doc_id] += len(cnts[i : i + batch_size]) token_counts[doc_id] += c return vects @@ -585,22 +529,25 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): if parser_ids[doc_id] != ParserType.PICTURE.value: from graphrag.general.mind_map_extractor import MindMapExtractor + mindmap = MindMapExtractor(llm_bdl) try: mind_map = trio.run(mindmap, 
[c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]) mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2) if len(mind_map) < 32: raise Exception("Few content: " + mind_map) - cks.append({ - "id": get_uuid(), - "doc_id": doc_id, - "kb_id": [kb.id], - "docnm_kwd": doc_nm[doc_id], - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])), - "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"), - "content_with_weight": mind_map, - "knowledge_graph_kwd": "mind_map" - }) + cks.append( + { + "id": get_uuid(), + "doc_id": doc_id, + "kb_id": [kb.id], + "docnm_kwd": doc_nm[doc_id], + "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])), + "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"), + "content_with_weight": mind_map, + "knowledge_graph_kwd": "mind_map", + } + ) except Exception as e: logging.exception("Mind map generation error") @@ -614,9 +561,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): if not settings.docStoreConn.indexExist(idxnm, kb_id): settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) try_create_idx = False - settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) + settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id) - DocumentService.increment_chunk_num( - doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) + DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) return [d["id"] for d, _ in files] diff --git a/web/src/components/message-input/index.tsx b/web/src/components/message-input/index.tsx index 529ce73..699cf59 100644 --- a/web/src/components/message-input/index.tsx +++ b/web/src/components/message-input/index.tsx @@ -27,7 +27,8 @@ import { UploadProps, } from 'antd'; import get from 'lodash/get'; -import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react'; +// import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react'; +import { CircleStop, SendHorizontal } from 'lucide-react'; import { ChangeEventHandler, memo, @@ -339,9 +340,9 @@ const MessageInput = ({ return false; }} > - + */} )} {sendLoading ? (