refactor: 临时移除对话界面中的上传文件功能
This commit is contained in:
parent
a2d9490b59
commit
8a7174a256
|
@ -47,20 +47,14 @@ class DocumentService(CommonService):
|
|||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords, id, name):
|
||||
def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name):
|
||||
docs = cls.model.select().where(cls.model.kb_id == kb_id)
|
||||
if id:
|
||||
docs = docs.where(
|
||||
cls.model.id == id)
|
||||
docs = docs.where(cls.model.id == id)
|
||||
if name:
|
||||
docs = docs.where(
|
||||
cls.model.name == name
|
||||
)
|
||||
docs = docs.where(cls.model.name == name)
|
||||
if keywords:
|
||||
docs = docs.where(
|
||||
fn.LOWER(cls.model.name).contains(keywords.lower())
|
||||
)
|
||||
docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
if desc:
|
||||
docs = docs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
|
@ -72,13 +66,9 @@ class DocumentService(CommonService):
|
|||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords):
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords):
|
||||
if keywords:
|
||||
docs = cls.model.select().where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
|
||||
else:
|
||||
docs = cls.model.select().where(cls.model.kb_id == kb_id)
|
||||
count = docs.count()
|
||||
|
@ -97,8 +87,7 @@ class DocumentService(CommonService):
|
|||
if not cls.save(**doc):
|
||||
raise RuntimeError("Database error (Document)!")
|
||||
e, kb = KnowledgebaseService.get_by_id(doc["kb_id"])
|
||||
if not KnowledgebaseService.update_by_id(
|
||||
kb.id, {"doc_num": kb.doc_num + 1}):
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"doc_num": kb.doc_num + 1}):
|
||||
raise RuntimeError("Database error (Knowledgebase)!")
|
||||
return Document(**doc)
|
||||
|
||||
|
@ -108,14 +97,16 @@ class DocumentService(CommonService):
|
|||
cls.clear_chunk_num(doc.id)
|
||||
try:
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
|
||||
{"remove": {"source_id": doc.id}},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
|
||||
{"removed_kwd": "Y"},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.update(
|
||||
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
|
||||
{"remove": {"source_id": doc.id}},
|
||||
search.index_name(tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.delete(
|
||||
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}}, search.index_name(tenant_id), doc.kb_id
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return cls.delete_by_id(doc.id)
|
||||
|
@ -136,67 +127,54 @@ class DocumentService(CommonService):
|
|||
Tenant.embd_id,
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
cls.model.update_time]
|
||||
docs = cls.model.select(*fields) \
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
|
||||
cls.model.update_time,
|
||||
]
|
||||
docs = (
|
||||
cls.model.select(*fields)
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress == 0,
|
||||
cls.model.update_time >= current_timestamp() - 1000 * 600,
|
||||
cls.model.run == TaskStatus.RUNNING.value) \
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress == 0,
|
||||
cls.model.update_time >= current_timestamp() - 1000 * 600,
|
||||
cls.model.run == TaskStatus.RUNNING.value,
|
||||
)
|
||||
.order_by(cls.model.update_time.asc())
|
||||
)
|
||||
return list(docs.dicts())
|
||||
|
||||
@classmethod
@DB.connection_context()
def get_unfinished_docs(cls):
    """Return documents whose progress is strictly between 0 and 1.

    Only rows with a valid status and a non-virtual file type are selected.

    Returns:
        A list of dicts with the keys id, process_begin_at, parser_config,
        progress_msg, run and parser_id.
    """
    model = cls.model
    wanted_columns = (
        model.id,
        model.process_begin_at,
        model.parser_config,
        model.progress_msg,
        model.run,
        model.parser_id,
    )
    # Conditions are AND-ed together by peewee when passed as separate arguments.
    in_progress = (
        model.status == StatusEnum.VALID.value,
        ~(model.type == FileType.VIRTUAL.value),
        model.progress < 1,
        model.progress > 0,
    )
    query = model.select(*wanted_columns).where(*in_progress)
    return list(query.dicts())
|
||||
|
||||
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
    """Add token and chunk counts to a document and to its knowledge base.

    Args:
        doc_id: Id of the document row to update.
        kb_id: Id of the knowledge base row to update.
        token_num: Amount added to ``token_num`` on both rows.
        chunk_num: Amount added to ``chunk_num`` on both rows.
        duation: Amount added to the document's ``process_duation``.

    Returns:
        The value returned by executing the knowledge base update.

    Raises:
        LookupError: If the document update reports 0 affected rows.
    """
    doc_model = cls.model
    doc_changes = {
        "token_num": doc_model.token_num + token_num,
        "chunk_num": doc_model.chunk_num + chunk_num,
        "process_duation": doc_model.process_duation + duation,
    }
    touched_docs = doc_model.update(**doc_changes).where(doc_model.id == doc_id).execute()
    if touched_docs == 0:
        raise LookupError("Document not found which is supposed to be there")

    kb_changes = {
        "token_num": Knowledgebase.token_num + token_num,
        "chunk_num": Knowledgebase.chunk_num + chunk_num,
    }
    kb_query = Knowledgebase.update(**kb_changes).where(Knowledgebase.id == kb_id)
    return kb_query.execute()
|
||||
|
||||
@classmethod
@DB.connection_context()
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
    """Subtract token and chunk counts from a document and its knowledge base.

    Args:
        doc_id: Id of the document row to update.
        kb_id: Id of the knowledge base row to update.
        token_num: Amount subtracted from ``token_num`` on both rows.
        chunk_num: Amount subtracted from ``chunk_num`` on both rows.
        duation: Amount *added* to the document's ``process_duation``
            (the duration is accumulated, not subtracted).

    Returns:
        The value returned by executing the knowledge base update.

    Raises:
        LookupError: If the document update reports 0 affected rows.
    """
    doc_model = cls.model
    doc_changes = {
        "token_num": doc_model.token_num - token_num,
        "chunk_num": doc_model.chunk_num - chunk_num,
        "process_duation": doc_model.process_duation + duation,
    }
    touched_docs = doc_model.update(**doc_changes).where(doc_model.id == doc_id).execute()
    if touched_docs == 0:
        raise LookupError("Document not found which is supposed to be there")

    kb_changes = {
        "token_num": Knowledgebase.token_num - token_num,
        "chunk_num": Knowledgebase.chunk_num - chunk_num,
    }
    kb_query = Knowledgebase.update(**kb_changes).where(Knowledgebase.id == kb_id)
    return kb_query.execute()
|
||||
|
||||
@classmethod
|
||||
|
@ -205,24 +183,17 @@ class DocumentService(CommonService):
|
|||
doc = cls.model.get_by_id(doc_id)
|
||||
assert doc, "Can't fine document in database."
|
||||
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num -
|
||||
doc.token_num,
|
||||
chunk_num=Knowledgebase.chunk_num -
|
||||
doc.chunk_num,
|
||||
doc_num=Knowledgebase.doc_num - 1
|
||||
).where(
|
||||
Knowledgebase.id == doc.kb_id).execute()
|
||||
num = (
|
||||
Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1)
|
||||
.where(Knowledgebase.id == doc.kb_id)
|
||||
.execute()
|
||||
)
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id(cls, doc_id):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.tenant_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
|
@ -240,11 +211,7 @@ class DocumentService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id_by_name(cls, name):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.tenant_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
|
@ -253,12 +220,13 @@ class DocumentService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible(cls, doc_id, user_id):
|
||||
docs = cls.model.select(
|
||||
cls.model.id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)
|
||||
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
||||
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
|
||||
docs = (
|
||||
cls.model.select(cls.model.id)
|
||||
.join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
|
||||
.join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id))
|
||||
.where(cls.model.id == doc_id, UserTenant.user_id == user_id)
|
||||
.paginate(0, 1)
|
||||
)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return False
|
||||
|
@ -267,11 +235,7 @@ class DocumentService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible4deletion(cls, doc_id, user_id):
|
||||
docs = cls.model.select(
|
||||
cls.model.id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)
|
||||
).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
|
||||
docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return False
|
||||
|
@ -280,11 +244,7 @@ class DocumentService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_embd_id(cls, doc_id):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.embd_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
|
@ -306,9 +266,9 @@ class DocumentService(CommonService):
|
|||
Tenant.asr_id,
|
||||
Tenant.llm_id,
|
||||
)
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == doc_id)
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == doc_id)
|
||||
)
|
||||
configs = configs.dicts()
|
||||
if not configs:
|
||||
|
@ -319,8 +279,7 @@ class DocumentService(CommonService):
|
|||
@DB.connection_context()
|
||||
def get_doc_id_by_doc_name(cls, doc_name):
|
||||
fields = [cls.model.id]
|
||||
doc_id = cls.model.select(*fields) \
|
||||
.where(cls.model.name == doc_name)
|
||||
doc_id = cls.model.select(*fields).where(cls.model.name == doc_name)
|
||||
doc_id = doc_id.dicts()
|
||||
if not doc_id:
|
||||
return
|
||||
|
@ -330,8 +289,7 @@ class DocumentService(CommonService):
|
|||
@DB.connection_context()
|
||||
def get_thumbnails(cls, docids):
|
||||
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
|
||||
return list(cls.model.select(
|
||||
*fields).where(cls.model.id.in_(docids)).dicts())
|
||||
return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
|
@ -359,19 +317,14 @@ class DocumentService(CommonService):
|
|||
@classmethod
@DB.connection_context()
def get_doc_count(cls, tenant_id):
    """Count the documents in knowledge bases owned by a tenant.

    Args:
        tenant_id: Matched against ``Knowledgebase.tenant_id``.

    Returns:
        The number of document rows joined to a knowledge base of that tenant.
    """
    docs = (
        cls.model.select(cls.model.id)
        .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
        .where(Knowledgebase.tenant_id == tenant_id)
    )
    # Count in the database rather than fetching every row just to take len().
    return docs.count()
|
||||
|
||||
@classmethod
@DB.connection_context()
def begin2parse(cls, docid):
    """Mark a document as queued for parsing.

    Sets ``progress`` to a random value below 0.01, ``progress_msg`` to the
    queued message, and ``process_begin_at`` to the current formatted time.

    Args:
        docid: Id of the document to update.
    """
    # Multiplying by 1 is exact, so this equals `random.random() * 1 / 100.0`.
    initial_progress = random.random() / 100.0
    queued_state = {
        "progress": initial_progress,
        "progress_msg": "Task is queued...",
        "process_begin_at": get_format_time(),
    }
    cls.update_by_id(docid, queued_state)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_meta_fields(cls, doc_id, meta_fields):
|
||||
|
@ -420,11 +373,7 @@ class DocumentService(CommonService):
|
|||
status = TaskStatus.DONE.value
|
||||
|
||||
msg = "\n".join(sorted(msg))
|
||||
info = {
|
||||
"process_duation": datetime.timestamp(
|
||||
datetime.now()) -
|
||||
d["process_begin_at"].timestamp(),
|
||||
"run": status}
|
||||
info = {"process_duation": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(), "run": status}
|
||||
if prg != 0:
|
||||
info["progress"] = prg
|
||||
if msg:
|
||||
|
@ -437,8 +386,7 @@ class DocumentService(CommonService):
|
|||
@classmethod
@DB.connection_context()
def get_kb_doc_count(cls, kb_id):
    """Count the documents that belong to one knowledge base.

    Args:
        kb_id: Matched against the document's ``kb_id``.

    Returns:
        The number of matching document rows.
    """
    # Count in the database rather than loading every id and calling len().
    return cls.model.select(cls.model.id).where(cls.model.kb_id == kb_id).count()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
|
@ -459,14 +407,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
|
|||
|
||||
def new_task():
|
||||
nonlocal doc
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"],
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"task_type": ty,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
|
||||
}
|
||||
return {"id": get_uuid(), "doc_id": doc["id"], "from_page": 100000000, "to_page": 100000000, "task_type": ty, "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty}
|
||||
|
||||
task = new_task()
|
||||
for field in ["doc_id", "from_page", "to_page"]:
|
||||
|
@ -478,6 +419,25 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
|
|||
|
||||
|
||||
def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
"""
|
||||
上传并解析文档,将内容存入知识库
|
||||
|
||||
参数:
|
||||
conversation_id: 会话ID
|
||||
file_objs: 文件对象列表
|
||||
user_id: 用户ID
|
||||
|
||||
返回:
|
||||
处理成功的文档ID列表
|
||||
|
||||
处理流程:
|
||||
1. 验证会话和知识库
|
||||
2. 初始化嵌入模型
|
||||
3. 上传文件到存储
|
||||
4. 多线程解析文件内容
|
||||
5. 生成内容嵌入向量
|
||||
6. 存入文档存储系统
|
||||
"""
|
||||
from rag.app import presentation, picture, naive, audio, email
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.file_service import FileService
|
||||
|
@ -493,8 +453,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not dia.kb_ids:
|
||||
raise LookupError("No knowledge base associated with this conversation. "
|
||||
"Please add a knowledge base before uploading documents")
|
||||
raise LookupError("No knowledge base associated with this conversation. Please add a knowledge base before uploading documents")
|
||||
kb_id = dia.kb_ids[0]
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
|
@ -508,12 +467,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
FACTORY = {
|
||||
ParserType.PRESENTATION.value: presentation,
|
||||
ParserType.PICTURE.value: picture,
|
||||
ParserType.AUDIO.value: audio,
|
||||
ParserType.EMAIL.value: email
|
||||
}
|
||||
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
|
||||
# 使用线程池执行解析任务
|
||||
exe = ThreadPoolExecutor(max_workers=12)
|
||||
|
@ -522,22 +476,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
for d, blob in files:
|
||||
doc_nm[d["id"]] = d["name"]
|
||||
for d, blob in files:
|
||||
kwargs = {
|
||||
"callback": dummy,
|
||||
"parser_config": parser_config,
|
||||
"from_page": 0,
|
||||
"to_page": 100000,
|
||||
"tenant_id": kb.tenant_id,
|
||||
"lang": kb.language
|
||||
}
|
||||
kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language}
|
||||
threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
|
||||
|
||||
for (docinfo, _), th in zip(files, threads):
|
||||
docs = []
|
||||
doc = {
|
||||
"doc_id": docinfo["id"],
|
||||
"kb_id": [kb.id]
|
||||
}
|
||||
doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]}
|
||||
for ck in th.result():
|
||||
d = deepcopy(doc)
|
||||
d.update(ck)
|
||||
|
@ -552,7 +496,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
if isinstance(d["image"], bytes):
|
||||
output_buffer = BytesIO(d["image"])
|
||||
else:
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
d["image"].save(output_buffer, format="JPEG")
|
||||
|
||||
STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(kb.id, d["id"])
|
||||
|
@ -569,9 +513,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
nonlocal embd_mdl, chunk_counts, token_counts
|
||||
vects = []
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
|
||||
vts, c = embd_mdl.encode(cnts[i : i + batch_size])
|
||||
vects.extend(vts.tolist())
|
||||
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
|
||||
chunk_counts[doc_id] += len(cnts[i : i + batch_size])
|
||||
token_counts[doc_id] += c
|
||||
return vects
|
||||
|
||||
|
@ -585,22 +529,25 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
|
||||
if parser_ids[doc_id] != ParserType.PICTURE.value:
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
|
||||
mindmap = MindMapExtractor(llm_bdl)
|
||||
try:
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
|
||||
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
||||
if len(mind_map) < 32:
|
||||
raise Exception("Few content: " + mind_map)
|
||||
cks.append({
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc_id,
|
||||
"kb_id": [kb.id],
|
||||
"docnm_kwd": doc_nm[doc_id],
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
|
||||
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
|
||||
"content_with_weight": mind_map,
|
||||
"knowledge_graph_kwd": "mind_map"
|
||||
})
|
||||
cks.append(
|
||||
{
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc_id,
|
||||
"kb_id": [kb.id],
|
||||
"docnm_kwd": doc_nm[doc_id],
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
|
||||
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
|
||||
"content_with_weight": mind_map,
|
||||
"knowledge_graph_kwd": "mind_map",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception("Mind map generation error")
|
||||
|
||||
|
@ -614,9 +561,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
|
||||
return [d["id"] for d, _ in files]
|
||||
|
|
|
@ -27,7 +27,8 @@ import {
|
|||
UploadProps,
|
||||
} from 'antd';
|
||||
import get from 'lodash/get';
|
||||
import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react';
|
||||
// import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react';
|
||||
import { CircleStop, SendHorizontal } from 'lucide-react';
|
||||
import {
|
||||
ChangeEventHandler,
|
||||
memo,
|
||||
|
@ -339,9 +340,9 @@ const MessageInput = ({
|
|||
return false;
|
||||
}}
|
||||
>
|
||||
<Button type={'primary'} disabled={disabled}>
|
||||
{/* <Button type={'primary'} disabled={disabled}>
|
||||
<Paperclip className="size-4" />
|
||||
</Button>
|
||||
</Button> */}
|
||||
</Upload>
|
||||
)}
|
||||
{sendLoading ? (
|
||||
|
|
Loading…
Reference in New Issue