Fix: add parser-side tokenization; fix keyword matching failure at recall time (issue #133)

Commit 401e3d81c4
@@ -54,3 +54,4 @@ docker/models
 management/web/types/auto
 web/node_modules/.cache/logger/umi.log
 management/models--slanet_plus
+node_modules/.cache/logger/umi.log
README.md
@@ -76,14 +76,24 @@ ollama pull bge-m3:latest

 #### 1. Run with Docker Compose

-Run the following command in the project root directory
+- Run with GPU (make sure the first GPU card has at least 6 GB of free VRAM):

+1. Install nvidia-container-runtime on the host so that Docker mounts the GPU devices and drivers automatically:
+
+```bash
+sudo apt install -y nvidia-container-runtime
+```
+
+2. Run the following command in the project root directory

-Run with GPU:
 ```bash
 docker compose -f docker/docker-compose_gpu.yml up -d
 ```

-Run with CPU:
+- Run with CPU:

+Run the following command in the project root directory
+
 ```bash
 docker compose -f docker/docker-compose.yml up -d
 ```

@@ -92,7 +102,6 @@ docker compose -f docker/docker-compose.yml up -d

 Access `server-ip:8888` to open the admin console

-Illustrated tutorial: [https://blog.csdn.net/qq1198768105/article/details/147475488](https://blog.csdn.net/qq1198768105/article/details/147475488)

 #### 2. Run from source (mysql, minio, es and other components still need to be started with docker)

@@ -22,7 +22,8 @@ from flask_login import login_required, current_user
 from rag.app.qa import rmPrefix, beAdoc
 from rag.app.tag import label_question
 from rag.nlp import search, rag_tokenizer
-from rag.prompts import keyword_extraction
+# from rag.prompts import keyword_extraction, cross_languages
 from rag.settings import PAGERANK_FLD
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
@ -37,9 +38,9 @@ import xxhash
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("doc_id")
|
@validate_request("doc_id") # 验证请求中必须包含 doc_id 参数
|
||||||
def list_chunk():
|
def list_chunk():
|
||||||
req = request.json
|
req = request.json
|
||||||
doc_id = req["doc_id"]
|
doc_id = req["doc_id"]
|
||||||
|
@ -54,9 +55,7 @@ def list_chunk():
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||||
query = {
|
query = {"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True}
|
||||||
"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
|
|
||||||
}
|
|
||||||
if "available_int" in req:
|
if "available_int" in req:
|
||||||
query["available_int"] = int(req["available_int"])
|
query["available_int"] = int(req["available_int"])
|
||||||
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
||||||
|
@ -64,9 +63,7 @@ def list_chunk():
|
||||||
for id in sres.ids:
|
for id in sres.ids:
|
||||||
d = {
|
d = {
|
||||||
"chunk_id": id,
|
"chunk_id": id,
|
||||||
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[
|
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get("content_with_weight", ""),
|
||||||
id].get(
|
|
||||||
"content_with_weight", ""),
|
|
||||||
"doc_id": sres.field[id]["doc_id"],
|
"doc_id": sres.field[id]["doc_id"],
|
||||||
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
||||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
"important_kwd": sres.field[id].get("important_kwd", []),
|
||||||
|
@ -81,12 +78,11 @@ def list_chunk():
|
||||||
return get_json_result(data=res)
|
return get_json_result(data=res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return get_json_result(data=False, message='No chunk found!',
|
return get_json_result(data=False, message="No chunk found!", code=settings.RetCode.DATA_ERROR)
|
||||||
code=settings.RetCode.DATA_ERROR)
|
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/get', methods=['GET']) # noqa: F821
|
@manager.route("/get", methods=["GET"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def get():
|
def get():
|
||||||
chunk_id = request.args["chunk_id"]
|
chunk_id = request.args["chunk_id"]
|
||||||
|
@ -112,19 +108,16 @@ def get():
|
||||||
return get_json_result(data=chunk)
|
return get_json_result(data=chunk)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("NotFoundError") >= 0:
|
if str(e).find("NotFoundError") >= 0:
|
||||||
return get_json_result(data=False, message='Chunk not found!',
|
return get_json_result(data=False, message="Chunk not found!", code=settings.RetCode.DATA_ERROR)
|
||||||
code=settings.RetCode.DATA_ERROR)
|
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("doc_id", "chunk_id", "content_with_weight")
|
@validate_request("doc_id", "chunk_id", "content_with_weight")
|
||||||
def set():
|
def set():
|
||||||
req = request.json
|
req = request.json
|
||||||
d = {
|
d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
|
||||||
"id": req["chunk_id"],
|
|
||||||
"content_with_weight": req["content_with_weight"]}
|
|
||||||
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
|
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
|
||||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||||
if "important_kwd" in req:
|
if "important_kwd" in req:
|
||||||
|
@ -153,13 +146,9 @@ def set():
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
|
|
||||||
if doc.parser_id == ParserType.QA:
|
if doc.parser_id == ParserType.QA:
|
||||||
arr = [
|
arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t) > 1]
|
||||||
t for t in re.split(
|
|
||||||
r"[\n\t]",
|
|
||||||
req["content_with_weight"]) if len(t) > 1]
|
|
||||||
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
||||||
d = beAdoc(d, q, a, not any(
|
d = beAdoc(d, q, a, not any([rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
|
||||||
|
|
||||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
||||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||||
|
@ -170,7 +159,7 @@ def set():
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/switch', methods=['POST']) # noqa: F821
|
@manager.route("/switch", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("chunk_ids", "available_int", "doc_id")
|
@validate_request("chunk_ids", "available_int", "doc_id")
|
||||||
def switch():
|
def switch():
|
||||||
|
@ -180,20 +169,19 @@ def switch():
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
for cid in req["chunk_ids"]:
|
for cid in req["chunk_ids"]:
|
||||||
if not settings.docStoreConn.update({"id": cid},
|
if not settings.docStoreConn.update({"id": cid}, {"available_int": int(req["available_int"])}, search.index_name(DocumentService.get_tenant_id(req["doc_id"])), doc.kb_id):
|
||||||
{"available_int": int(req["available_int"])},
|
|
||||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
|
||||||
doc.kb_id):
|
|
||||||
return get_data_error_result(message="Index updating failure")
|
return get_data_error_result(message="Index updating failure")
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("chunk_ids", "doc_id")
|
@validate_request("chunk_ids", "doc_id")
|
||||||
def rm():
|
def rm():
|
||||||
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
|
|
||||||
req = request.json
|
req = request.json
|
||||||
try:
|
try:
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
|
@ -204,19 +192,21 @@ def rm():
|
||||||
deleted_chunk_ids = req["chunk_ids"]
|
deleted_chunk_ids = req["chunk_ids"]
|
||||||
chunk_number = len(deleted_chunk_ids)
|
chunk_number = len(deleted_chunk_ids)
|
||||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||||
|
for cid in deleted_chunk_ids:
|
||||||
|
if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||||
|
STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
@manager.route("/create", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("doc_id", "content_with_weight")
|
@validate_request("doc_id", "content_with_weight")
|
||||||
def create():
|
def create():
|
||||||
req = request.json
|
req = request.json
|
||||||
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
|
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
|
||||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
|
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), "content_with_weight": req["content_with_weight"]}
|
||||||
"content_with_weight": req["content_with_weight"]}
|
|
||||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||||
d["important_kwd"] = req.get("important_kwd", [])
|
d["important_kwd"] = req.get("important_kwd", [])
|
||||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", [])))
|
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", [])))
|
||||||
|
@@ -252,14 +242,35 @@ def create():
 d["q_%d_vec" % len(v)] = v.tolist()
 settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)

-DocumentService.increment_chunk_num(
-doc.id, doc.kb_id, c, 1, 0)
+DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
 return get_json_result(data={"chunk_id": chunck_id})
 except Exception as e:
 return server_error_response(e)


-@manager.route('/retrieval_test', methods=['POST'])  # noqa: F821
+"""
+{
+"similarity_threshold": 0.2,
+"vector_similarity_weight": 0.30000000000000004,
+"question": "香港",
+"doc_ids": [],
+"kb_id": "4b071030bc8e43f1bfb8b7831f320d2f",
+"page": 1,
+"size": 10
+},
+{
+"similarity_threshold": 0.2,
+"vector_similarity_weight": 0.30000000000000004,
+"question": "显著优势",
+"doc_ids": [],
+"kb_id": "1848bc54384611f0b33e4e66786d0323",
+"page": 1,
+"size": 10
+}
+"""
+
+
+@manager.route("/retrieval_test", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("kb_id", "question")
 def retrieval_test():
|
||||||
size = int(req.get("size", 30))
|
size = int(req.get("size", 30))
|
||||||
question = req["question"]
|
question = req["question"]
|
||||||
kb_ids = req["kb_id"]
|
kb_ids = req["kb_id"]
|
||||||
|
# 如果kb_ids是字符串,将其转换为列表
|
||||||
if isinstance(kb_ids, str):
|
if isinstance(kb_ids, str):
|
||||||
kb_ids = [kb_ids]
|
kb_ids = [kb_ids]
|
||||||
doc_ids = req.get("doc_ids", [])
|
doc_ids = req.get("doc_ids", [])
|
||||||
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
||||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||||
use_kg = req.get("use_kg", False)
|
|
||||||
top = int(req.get("top_k", 1024))
|
top = int(req.get("top_k", 1024))
|
||||||
|
# langs = req.get("cross_languages", []) # 获取跨语言设定
|
||||||
tenant_ids = []
|
tenant_ids = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 查询当前用户所属的租户
|
||||||
tenants = UserTenantService.query(user_id=current_user.id)
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
|
# 验证知识库权限
|
||||||
for kb_id in kb_ids:
|
for kb_id in kb_ids:
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
if KnowledgebaseService.query(
|
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||||
tenant_id=tenant.tenant_id, id=kb_id):
|
|
||||||
tenant_ids.append(tenant.tenant_id)
|
tenant_ids.append(tenant.tenant_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
# 获取知识库信息
|
||||||
code=settings.RetCode.OPERATING_ERROR)
|
|
||||||
|
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Knowledgebase not found!")
|
return get_data_error_result(message="Knowledgebase not found!")
|
||||||
|
|
||||||
|
# if langs:
|
||||||
|
# question = cross_languages(kb.tenant_id, None, question, langs) # 跨语言处理
|
||||||
|
|
||||||
|
# 加载嵌入模型
|
||||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||||
|
|
||||||
|
# 加载重排序模型(如果指定)
|
||||||
rerank_mdl = None
|
rerank_mdl = None
|
||||||
if req.get("rerank_id"):
|
if req.get("rerank_id"):
|
||||||
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
||||||
|
|
||||||
if req.get("keyword", False):
|
# 对问题进行标签化
|
||||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
# labels = label_question(question, [kb])
|
||||||
question += keyword_extraction(chat_mdl, question)
|
labels = None
|
||||||
|
|
||||||
labels = label_question(question, [kb])
|
# 执行检索操作
|
||||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
ranks = settings.retrievaler.retrieval(
|
||||||
similarity_threshold, vector_similarity_weight, top,
|
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
|
|
||||||
rank_feature=labels
|
|
||||||
)
|
)
|
||||||
if use_kg:
|
|
||||||
ck = settings.kg_retrievaler.retrieval(question,
|
|
||||||
tenant_ids,
|
|
||||||
kb_ids,
|
|
||||||
embd_mdl,
|
|
||||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
|
||||||
if ck["content_with_weight"]:
|
|
||||||
ranks["chunks"].insert(0, ck)
|
|
||||||
|
|
||||||
|
# 移除不必要的向量信息
|
||||||
for c in ranks["chunks"]:
|
for c in ranks["chunks"]:
|
||||||
c.pop("vector", None)
|
c.pop("vector", None)
|
||||||
ranks["labels"] = labels
|
ranks["labels"] = labels
|
||||||
|
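The rewritten `retrieval_test` above drops the old `keyword_extraction` branch and the knowledge-graph path; it currently sets `labels = None` (the `label_question` call is commented out) and passes it through as `rank_feature`. A minimal client-side sketch of exercising the endpoint with the sample payload from the docstring earlier in this file — the host, route prefix and session cookie are assumptions, not part of this commit:

```python
import requests

BASE_URL = "http://127.0.0.1:9380"              # assumed API host/port
COOKIES = {"session": "<your-session-cookie>"}  # the route is @login_required

payload = {
    "similarity_threshold": 0.2,
    "vector_similarity_weight": 0.3,
    "question": "香港",  # sample question from the docstring above
    "doc_ids": [],
    "kb_id": "4b071030bc8e43f1bfb8b7831f320d2f",
    "page": 1,
    "size": 10,
}

# "/v1/chunk/retrieval_test" mirrors the usual blueprint mounting, but is an assumption.
resp = requests.post(f"{BASE_URL}/v1/chunk/retrieval_test", json=payload, cookies=COOKIES)
print(resp.status_code, resp.json())
```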
@ -326,47 +333,5 @@ def retrieval_test():
|
||||||
return get_json_result(data=ranks)
|
return get_json_result(data=ranks)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=settings.RetCode.DATA_ERROR)
|
||||||
code=settings.RetCode.DATA_ERROR)
|
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
|
|
||||||
@login_required
|
|
||||||
def knowledge_graph():
|
|
||||||
doc_id = request.args["doc_id"]
|
|
||||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
|
||||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
|
||||||
req = {
|
|
||||||
"doc_ids": [doc_id],
|
|
||||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
|
||||||
}
|
|
||||||
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
|
|
||||||
obj = {"graph": {}, "mind_map": {}}
|
|
||||||
for id in sres.ids[:2]:
|
|
||||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
|
||||||
try:
|
|
||||||
content_json = json.loads(sres.field[id]["content_with_weight"])
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if ty == 'mind_map':
|
|
||||||
node_dict = {}
|
|
||||||
|
|
||||||
def repeat_deal(content_json, node_dict):
|
|
||||||
if 'id' in content_json:
|
|
||||||
if content_json['id'] in node_dict:
|
|
||||||
node_name = content_json['id']
|
|
||||||
content_json['id'] += f"({node_dict[content_json['id']]})"
|
|
||||||
node_dict[node_name] += 1
|
|
||||||
else:
|
|
||||||
node_dict[content_json['id']] = 1
|
|
||||||
if 'children' in content_json and content_json['children']:
|
|
||||||
for item in content_json['children']:
|
|
||||||
repeat_deal(item, node_dict)
|
|
||||||
|
|
||||||
repeat_deal(content_json, node_dict)
|
|
||||||
|
|
||||||
obj[ty] = content_json
|
|
||||||
|
|
||||||
return get_json_result(data=obj)
|
|
||||||
|
|
|
@@ -4,9 +4,11 @@ import redis
 from minio import Minio
 from dotenv import load_dotenv
 from elasticsearch import Elasticsearch
+from pathlib import Path

 # 加载环境变量
-load_dotenv("../../docker/.env")
+env_path = Path(__file__).parent.parent.parent / "docker" / ".env"
+load_dotenv(env_path)


 # 检测是否在Docker容器中运行
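The change above resolves `docker/.env` relative to this module's location instead of the process working directory, so the configuration is found no matter where the management service is started from. A small hedged sketch of the same idea with an existence check added (the three-levels-up layout is the repository structure this diff assumes):

```python
from pathlib import Path

from dotenv import load_dotenv

# Resolve docker/.env relative to this file, not to os.getcwd().
env_path = Path(__file__).resolve().parent.parent.parent / "docker" / ".env"
if not env_path.exists():
    # Purely illustrative: warn instead of failing silently if the layout differs.
    print(f"[WARN] .env not found at {env_path}; relying on the process environment")
load_dotenv(env_path)
```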
@ -25,3 +25,4 @@ omegaconf==2.3.0
|
||||||
rapid-table==1.0.3
|
rapid-table==1.0.3
|
||||||
openai==1.70.0
|
openai==1.70.0
|
||||||
redis==6.2.0
|
redis==6.2.0
|
||||||
|
tokenizer==3.4.5
|
|
@ -2,12 +2,12 @@ import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
import json
|
import json
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
import mysql.connector
|
import mysql.connector
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import re
|
import re
|
||||||
import requests
|
import requests
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from database import MINIO_CONFIG, DB_CONFIG, get_minio_client, get_es_client
|
from database import MINIO_CONFIG, DB_CONFIG, get_minio_client, get_es_client
|
||||||
|
@@ -17,17 +17,18 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.data.read_api import read_local_office, read_local_images
 from utils import generate_uuid
+from .rag_tokenizer import RagTokenizer


+tknzr = RagTokenizer()


-# 自定义tokenizer和文本处理函数,替代rag.nlp中的功能
 def tokenize_text(text):
-"""将文本分词,替代rag_tokenizer功能"""
-# 简单实现,未来可能需要改成更复杂的分词逻辑
-return text.split()
+return tknzr.tokenize(text)


 def merge_chunks(sections, chunk_token_num=128, delimiter="\n。;!?"):
-"""合并文本块,替代naive_merge功能"""
+"""合并文本块,替代naive_merge功能(预留函数)"""
 if not sections:
 return []

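This is the heart of the fix named in the commit title: `tokenize_text` used to fall back to `text.split()`, which returns Chinese text as a single token (it contains no spaces), so the `*_ltks` fields written to Elasticsearch could never match individual keywords at recall time. Delegating to the trie-based `RagTokenizer` yields space-separated tokens. A small hedged comparison (the exact token boundaries depend on the bundled dictionary, so the tokenizer output in the comment is indicative only):

```python
# Illustrative only; assumes rag_tokenizer.py (added later in this commit) is importable here.
from rag_tokenizer import RagTokenizer

text = "香港具有显著优势"

print(text.split())                   # ['香港具有显著优势'] -- one opaque token, keyword recall misses it
print(RagTokenizer().tokenize(text))  # e.g. '香港 具有 显著 优势' -- space-separated, full-text match works
```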
@ -149,25 +150,7 @@ def _create_task_record(doc_id, chunk_ids_list):
|
||||||
progress, progress_msg, retry_count, digest, chunk_ids, task_type, priority
|
progress, progress_msg, retry_count, digest, chunk_ids, task_type, priority
|
||||||
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||||
"""
|
"""
|
||||||
task_params = [
|
task_params = [task_id, current_timestamp, current_date_only, current_timestamp, current_date_only, doc_id, 0, 1, None, 0.0, 1.0, "MinerU解析完成", 1, digest, chunk_ids_str, "", 0]
|
||||||
task_id,
|
|
||||||
current_timestamp,
|
|
||||||
current_date_only,
|
|
||||||
current_timestamp,
|
|
||||||
current_date_only,
|
|
||||||
doc_id,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
None,
|
|
||||||
0.0,
|
|
||||||
1.0,
|
|
||||||
"MinerU解析完成",
|
|
||||||
1,
|
|
||||||
digest,
|
|
||||||
chunk_ids_str,
|
|
||||||
"",
|
|
||||||
0
|
|
||||||
]
|
|
||||||
cursor.execute(task_insert, task_params)
|
cursor.execute(task_insert, task_params)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
print(f"[Parser-INFO] Task记录创建成功,Task ID: {task_id}")
|
print(f"[Parser-INFO] Task记录创建成功,Task ID: {task_id}")
|
||||||
|
@ -214,13 +197,13 @@ def process_table_content(content_list):
|
||||||
new_content_list = []
|
new_content_list = []
|
||||||
|
|
||||||
for item in content_list:
|
for item in content_list:
|
||||||
if 'table_body' in item and item['table_body']:
|
if "table_body" in item and item["table_body"]:
|
||||||
# 使用BeautifulSoup解析HTML表格
|
# 使用BeautifulSoup解析HTML表格
|
||||||
soup = BeautifulSoup(item['table_body'], 'html.parser')
|
soup = BeautifulSoup(item["table_body"], "html.parser")
|
||||||
table = soup.find('table')
|
table = soup.find("table")
|
||||||
|
|
||||||
if table:
|
if table:
|
||||||
rows = table.find_all('tr')
|
rows = table.find_all("tr")
|
||||||
# 获取表头(第一行)
|
# 获取表头(第一行)
|
||||||
header_row = rows[0] if rows else None
|
header_row = rows[0] if rows else None
|
||||||
|
|
||||||
|
@ -230,7 +213,7 @@ def process_table_content(content_list):
|
||||||
new_item = item.copy()
|
new_item = item.copy()
|
||||||
|
|
||||||
# 创建只包含当前行的表格
|
# 创建只包含当前行的表格
|
||||||
new_table = soup.new_tag('table')
|
new_table = soup.new_tag("table")
|
||||||
|
|
||||||
# 如果有表头,添加表头
|
# 如果有表头,添加表头
|
||||||
if header_row and i > 0:
|
if header_row and i > 0:
|
||||||
|
@ -241,7 +224,7 @@ def process_table_content(content_list):
|
||||||
|
|
||||||
# 创建新的HTML结构
|
# 创建新的HTML结构
|
||||||
new_html = f"<html><body>{str(new_table)}</body></html>"
|
new_html = f"<html><body>{str(new_table)}</body></html>"
|
||||||
new_item['table_body'] = f"\n\n{new_html}\n\n"
|
new_item["table_body"] = f"\n\n{new_html}\n\n"
|
||||||
|
|
||||||
# 添加到新的内容列表
|
# 添加到新的内容列表
|
||||||
new_content_list.append(new_item)
|
new_content_list.append(new_item)
|
||||||
|
@ -252,6 +235,7 @@ def process_table_content(content_list):
|
||||||
|
|
||||||
return new_content_list
|
return new_content_list
|
||||||
|
|
||||||
|
|
||||||
def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
||||||
"""
|
"""
|
||||||
执行文档解析的核心逻辑
|
执行文档解析的核心逻辑
|
||||||
|
@ -305,7 +289,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
||||||
if normalized_base_url.endswith("/v1"):
|
if normalized_base_url.endswith("/v1"):
|
||||||
# 如果 base_url 已经是 http://host/v1 形式
|
# 如果 base_url 已经是 http://host/v1 形式
|
||||||
embedding_url = normalized_base_url + "/" + endpoint_segment
|
embedding_url = normalized_base_url + "/" + endpoint_segment
|
||||||
elif normalized_base_url.endswith('/embeddings'):
|
elif normalized_base_url.endswith("/embeddings"):
|
||||||
# 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API,无需再进行处理)
|
# 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API,无需再进行处理)
|
||||||
embedding_url = normalized_base_url
|
embedding_url = normalized_base_url
|
||||||
else:
|
else:
|
||||||
|
@ -613,7 +597,6 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
||||||
)
|
)
|
||||||
|
|
||||||
# 准备ES文档
|
# 准备ES文档
|
||||||
content_tokens = tokenize_text(content) # 分词
|
|
||||||
current_time_es = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
current_time_es = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
current_timestamp_es = datetime.now().timestamp()
|
current_timestamp_es = datetime.now().timestamp()
|
||||||
|
|
||||||
|
@@ -625,11 +608,11 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
 "doc_id": doc_id,
 "kb_id": kb_id,
 "docnm_kwd": doc_info["name"],
-"title_tks": doc_info["name"],
-"title_sm_tks": doc_info["name"],
+"title_tks": tokenize_text(doc_info["name"]),
+"title_sm_tks": tokenize_text(doc_info["name"]),
 "content_with_weight": content,
-"content_ltks": " ".join(content_tokens),  # 字符串类型
-"content_sm_ltks": " ".join(content_tokens),  # 字符串类型
+"content_ltks": tokenize_text(content),
+"content_sm_ltks": tokenize_text(content),
 "page_num_int": [page_idx + 1],
 "position_int": [[page_idx + 1] + bbox_reordered],  # 格式: [[page, x1, x2, y1, y2]]
 "top_int": [1],
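Because `content_ltks` and `title_tks` are now stored as space-separated token strings, a full-text `match` query against those fields can hit individual keywords, which is what the recall path depends on. A hedged sketch of such a query — the index naming and the Elasticsearch 8.x client call signature are assumptions for illustration; the production query is built inside the retriever, not here:

```python
from database import get_es_client  # helper already imported at the top of this module

es = get_es_client()

# Illustrative query: matches chunks whose tokenized content contains the keyword.
query = {
    "bool": {
        "must": [{"match": {"content_ltks": "香港"}}],
        "filter": [{"term": {"kb_id": "<your-kb-id>"}}],  # placeholder knowledge-base id
    }
}
resp = es.search(index="ragflow_<tenant_id>", query=query, size=5)  # index name is an assumption
for hit in resp["hits"]["hits"]:
    print(hit["_id"], hit["_source"].get("docnm_kwd"))
```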
@ -755,7 +738,6 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
||||||
traceback.print_exc() # 打印详细错误堆栈
|
traceback.print_exc() # 打印详细错误堆栈
|
||||||
# 更新文档状态为失败
|
# 更新文档状态为失败
|
||||||
_update_document_progress(doc_id, status="1", run="0", message=error_message, process_duration=process_duration) # status=1表示完成,run=0表示失败
|
_update_document_progress(doc_id, status="1", run="0", message=error_message, process_duration=process_duration) # status=1表示完成,run=0表示失败
|
||||||
# 不抛出异常,让调用者知道任务已结束(但失败)
|
|
||||||
return {"success": False, "error": error_message}
|
return {"success": False, "error": error_message}
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -0,0 +1,412 @@
|
||||||
|
import logging
|
||||||
|
import copy
|
||||||
|
import datrie
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import string
|
||||||
|
from hanziconv import HanziConv
|
||||||
|
from nltk import word_tokenize
|
||||||
|
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
||||||
|
|
||||||
|
|
||||||
|
# def get_project_base_directory():
|
||||||
|
# return os.path.abspath(
|
||||||
|
# os.path.join(
|
||||||
|
# os.path.dirname(os.path.realpath(__file__)),
|
||||||
|
# os.pardir,
|
||||||
|
# os.pardir,
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
class RagTokenizer:
|
||||||
|
def key_(self, line):
|
||||||
|
return str(line.lower().encode("utf-8"))[2:-1]
|
||||||
|
|
||||||
|
def rkey_(self, line):
|
||||||
|
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
|
||||||
|
|
||||||
|
def loadDict_(self, fnm):
|
||||||
|
print(f"[HUQIE]:Build trie from {fnm}")
|
||||||
|
try:
|
||||||
|
of = open(fnm, "r", encoding="utf-8")
|
||||||
|
while True:
|
||||||
|
line = of.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
line = re.sub(r"[\r\n]+", "", line)
|
||||||
|
line = re.split(r"[ \t]", line)
|
||||||
|
k = self.key_(line[0])
|
||||||
|
F = int(math.log(float(line[1]) / self.DENOMINATOR) + 0.5)
|
||||||
|
if k not in self.trie_ or self.trie_[k][0] < F:
|
||||||
|
self.trie_[self.key_(line[0])] = (F, line[2])
|
||||||
|
self.trie_[self.rkey_(line[0])] = 1
|
||||||
|
|
||||||
|
dict_file_cache = fnm + ".trie"
|
||||||
|
print(f"[HUQIE]:Build trie cache to {dict_file_cache}")
|
||||||
|
self.trie_.save(dict_file_cache)
|
||||||
|
of.close()
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"[HUQIE]:Build trie {fnm} failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.DENOMINATOR = 1000000
|
||||||
|
self.DIR_ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "res", "huqie")
|
||||||
|
|
||||||
|
self.stemmer = PorterStemmer()
|
||||||
|
self.lemmatizer = WordNetLemmatizer()
|
||||||
|
|
||||||
|
self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"
|
||||||
|
|
||||||
|
trie_file_name = self.DIR_ + ".txt.trie"
|
||||||
|
# check if trie file existence
|
||||||
|
if os.path.exists(trie_file_name):
|
||||||
|
try:
|
||||||
|
# load trie from file
|
||||||
|
self.trie_ = datrie.Trie.load(trie_file_name)
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
# fail to load trie from file, build default trie
|
||||||
|
logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
|
||||||
|
self.trie_ = datrie.Trie(string.printable)
|
||||||
|
else:
|
||||||
|
# file not exist, build default trie
|
||||||
|
print(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
|
||||||
|
self.trie_ = datrie.Trie(string.printable)
|
||||||
|
|
||||||
|
# load data from dict file and save to trie file
|
||||||
|
self.loadDict_(self.DIR_ + ".txt")
|
||||||
|
|
||||||
|
def _strQ2B(self, ustring):
|
||||||
|
"""全角转半角,转小写"""
|
||||||
|
rstring = ""
|
||||||
|
for uchar in ustring:
|
||||||
|
inside_code = ord(uchar)
|
||||||
|
if inside_code == 0x3000:
|
||||||
|
inside_code = 0x0020
|
||||||
|
else:
|
||||||
|
inside_code -= 0xFEE0
|
||||||
|
if inside_code < 0x0020 or inside_code > 0x7E: # After the conversion, if it's not a half-width character, return the original character.
|
||||||
|
rstring += uchar
|
||||||
|
else:
|
||||||
|
rstring += chr(inside_code)
|
||||||
|
return rstring
|
||||||
|
|
||||||
|
def _tradi2simp(self, line):
|
||||||
|
"""繁体转简体"""
|
||||||
|
return HanziConv.toSimplified(line)
|
||||||
|
|
||||||
|
def dfs_(self, chars, s, preTks, tkslist):
|
||||||
|
res = s
|
||||||
|
# if s > MAX_L or s>= len(chars):
|
||||||
|
if s >= len(chars):
|
||||||
|
tkslist.append(preTks)
|
||||||
|
return res
|
||||||
|
|
||||||
|
# pruning
|
||||||
|
S = s + 1
|
||||||
|
if s + 2 <= len(chars):
|
||||||
|
t1, t2 = "".join(chars[s : s + 1]), "".join(chars[s : s + 2])
|
||||||
|
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
|
||||||
|
S = s + 2
|
||||||
|
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
|
||||||
|
t1 = preTks[-1][0] + "".join(chars[s : s + 1])
|
||||||
|
if self.trie_.has_keys_with_prefix(self.key_(t1)):
|
||||||
|
S = s + 2
|
||||||
|
|
||||||
|
for e in range(S, len(chars) + 1):
|
||||||
|
t = "".join(chars[s:e])
|
||||||
|
k = self.key_(t)
|
||||||
|
|
||||||
|
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
|
||||||
|
break
|
||||||
|
|
||||||
|
if k in self.trie_:
|
||||||
|
pretks = copy.deepcopy(preTks)
|
||||||
|
if k in self.trie_:
|
||||||
|
pretks.append((t, self.trie_[k]))
|
||||||
|
else:
|
||||||
|
pretks.append((t, (-12, "")))
|
||||||
|
res = max(res, self.dfs_(chars, e, pretks, tkslist))
|
||||||
|
|
||||||
|
if res > s:
|
||||||
|
return res
|
||||||
|
|
||||||
|
t = "".join(chars[s : s + 1])
|
||||||
|
k = self.key_(t)
|
||||||
|
if k in self.trie_:
|
||||||
|
preTks.append((t, self.trie_[k]))
|
||||||
|
else:
|
||||||
|
preTks.append((t, (-12, "")))
|
||||||
|
|
||||||
|
return self.dfs_(chars, s + 1, preTks, tkslist)
|
||||||
|
|
||||||
|
def freq(self, tk):
|
||||||
|
k = self.key_(tk)
|
||||||
|
if k not in self.trie_:
|
||||||
|
return 0
|
||||||
|
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
|
||||||
|
|
||||||
|
def score_(self, tfts):
|
||||||
|
B = 30
|
||||||
|
F, L, tks = 0, 0, []
|
||||||
|
for tk, (freq, tag) in tfts:
|
||||||
|
F += freq
|
||||||
|
L += 0 if len(tk) < 2 else 1
|
||||||
|
tks.append(tk)
|
||||||
|
# F /= len(tks)
|
||||||
|
L /= len(tks)
|
||||||
|
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
|
||||||
|
return tks, B / len(tks) + L + F
|
||||||
|
|
||||||
|
def sortTks_(self, tkslist):
|
||||||
|
res = []
|
||||||
|
for tfts in tkslist:
|
||||||
|
tks, s = self.score_(tfts)
|
||||||
|
res.append((tks, s))
|
||||||
|
return sorted(res, key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
def merge_(self, tks):
|
||||||
|
# if split chars is part of token
|
||||||
|
res = []
|
||||||
|
tks = re.sub(r"[ ]+", " ", tks).split()
|
||||||
|
s = 0
|
||||||
|
while True:
|
||||||
|
if s >= len(tks):
|
||||||
|
break
|
||||||
|
E = s + 1
|
||||||
|
for e in range(s + 2, min(len(tks) + 2, s + 6)):
|
||||||
|
tk = "".join(tks[s:e])
|
||||||
|
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
|
||||||
|
E = e
|
||||||
|
res.append("".join(tks[s:E]))
|
||||||
|
s = E
|
||||||
|
|
||||||
|
return " ".join(res)
|
||||||
|
|
||||||
|
def maxForward_(self, line):
|
||||||
|
res = []
|
||||||
|
s = 0
|
||||||
|
while s < len(line):
|
||||||
|
e = s + 1
|
||||||
|
t = line[s:e]
|
||||||
|
while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
|
||||||
|
e += 1
|
||||||
|
t = line[s:e]
|
||||||
|
|
||||||
|
while e - 1 > s and self.key_(t) not in self.trie_:
|
||||||
|
e -= 1
|
||||||
|
t = line[s:e]
|
||||||
|
|
||||||
|
if self.key_(t) in self.trie_:
|
||||||
|
res.append((t, self.trie_[self.key_(t)]))
|
||||||
|
else:
|
||||||
|
res.append((t, (0, "")))
|
||||||
|
|
||||||
|
s = e
|
||||||
|
|
||||||
|
return self.score_(res)
|
||||||
|
|
||||||
|
def maxBackward_(self, line):
|
||||||
|
res = []
|
||||||
|
s = len(line) - 1
|
||||||
|
while s >= 0:
|
||||||
|
e = s + 1
|
||||||
|
t = line[s:e]
|
||||||
|
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
|
||||||
|
s -= 1
|
||||||
|
t = line[s:e]
|
||||||
|
|
||||||
|
while s + 1 < e and self.key_(t) not in self.trie_:
|
||||||
|
s += 1
|
||||||
|
t = line[s:e]
|
||||||
|
|
||||||
|
if self.key_(t) in self.trie_:
|
||||||
|
res.append((t, self.trie_[self.key_(t)]))
|
||||||
|
else:
|
||||||
|
res.append((t, (0, "")))
|
||||||
|
|
||||||
|
s -= 1
|
||||||
|
|
||||||
|
return self.score_(res[::-1])
|
||||||
|
|
||||||
|
def _split_by_lang(self, line):
|
||||||
|
"""根据语言进行切分"""
|
||||||
|
txt_lang_pairs = []
|
||||||
|
arr = re.split(self.SPLIT_CHAR, line)
|
||||||
|
for a in arr:
|
||||||
|
if not a:
|
||||||
|
continue
|
||||||
|
s = 0
|
||||||
|
e = s + 1
|
||||||
|
zh = is_chinese(a[s])
|
||||||
|
while e < len(a):
|
||||||
|
_zh = is_chinese(a[e])
|
||||||
|
if _zh == zh:
|
||||||
|
e += 1
|
||||||
|
continue
|
||||||
|
txt_lang_pairs.append((a[s:e], zh))
|
||||||
|
s = e
|
||||||
|
e = s + 1
|
||||||
|
zh = _zh
|
||||||
|
if s >= len(a):
|
||||||
|
continue
|
||||||
|
txt_lang_pairs.append((a[s:e], zh))
|
||||||
|
return txt_lang_pairs
|
||||||
|
|
||||||
|
def tokenize(self, line: str) -> str:
|
||||||
|
"""
|
||||||
|
对输入文本进行分词,支持中英文混合处理。
|
||||||
|
|
||||||
|
分词流程:
|
||||||
|
1. 预处理:
|
||||||
|
- 将所有非单词字符(字母、数字、下划线以外的)替换为空格。
|
||||||
|
- 全角字符转半角。
|
||||||
|
- 转换为小写。
|
||||||
|
- 繁体中文转简体中文。
|
||||||
|
2. 按语言切分:
|
||||||
|
- 将预处理后的文本按语言(中文/非中文)分割成多个片段。
|
||||||
|
3. 分段处理:
|
||||||
|
- 对于非中文(通常是英文)片段:
|
||||||
|
- 使用 NLTK 的 `word_tokenize` 进行分词。
|
||||||
|
- 对分词结果进行词干提取 (PorterStemmer) 和词形还原 (WordNetLemmatizer)。
|
||||||
|
- 对于中文片段:
|
||||||
|
- 如果片段过短(长度<2)或为纯粹的英文/数字模式(如 "abc-def", "123.45"),则直接保留该片段。
|
||||||
|
- 否则,采用基于词典的混合分词策略:
|
||||||
|
a. 执行正向最大匹配 (FMM) 和逆向最大匹配 (BMM) 得到两组分词结果 (`tks` 和 `tks1`)。
|
||||||
|
b. 比较 FMM 和 BMM 的结果:
|
||||||
|
i. 找到两者从开头开始最长的相同分词序列,这部分通常是无歧义的,直接加入结果。
|
||||||
|
ii. 对于 FMM 和 BMM 结果不一致的歧义部分(即从第一个不同点开始的子串):
|
||||||
|
- 提取出这段有歧义的原始文本。
|
||||||
|
- 调用 `self.dfs_` (深度优先搜索) 在这段文本上探索所有可能的分词组合。
|
||||||
|
- `self.dfs_` 会利用Trie词典,并由 `self.sortTks_` 对所有组合进行评分和排序。
|
||||||
|
- 选择得分最高的分词方案作为该歧义段落的结果。
|
||||||
|
iii.继续处理 FMM 和 BMM 结果中歧义段落之后的部分,重复步骤 i 和 ii,直到两个序列都处理完毕。
|
||||||
|
c. 如果在比较完所有对应部分后,FMM 或 BMM 仍有剩余(理论上如果实现正确且输入相同,剩余部分也应相同),
|
||||||
|
则对这部分剩余的原始文本同样使用 `self.dfs_` 进行最优分词。
|
||||||
|
4. 后处理:
|
||||||
|
- 将所有处理过的片段(英文词元、中文词元)用空格连接起来。
|
||||||
|
- 调用 `self.merge_` 对连接后的结果进行进一步的合并操作,
|
||||||
|
尝试合并一些可能被错误分割但实际是一个完整词的片段(基于词典检查)。
|
||||||
|
5. 返回最终分词结果字符串(词元间用空格分隔)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
line (str): 待分词的原始输入字符串。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 分词后的字符串,词元之间用空格分隔。
|
||||||
|
"""
|
||||||
|
# 1. 预处理
|
||||||
|
line = re.sub(r"\W+", " ", line) # 将非字母数字下划线替换为空格
|
||||||
|
line = self._strQ2B(line).lower() # 全角转半角,转小写
|
||||||
|
line = self._tradi2simp(line) # 繁体转简体
|
||||||
|
|
||||||
|
# 2. 按语言切分
|
||||||
|
arr = self._split_by_lang(line) # 将文本分割成 (文本片段, 是否为中文) 的列表
|
||||||
|
res = [] # 存储最终分词结果的列表
|
||||||
|
|
||||||
|
# 3. 分段处理
|
||||||
|
for L, lang in arr: # L 是文本片段,lang 是布尔值表示是否为中文
|
||||||
|
if not lang: # 如果不是中文
|
||||||
|
# 使用NLTK进行分词、词干提取和词形还原
|
||||||
|
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
|
||||||
|
continue # 处理下一个片段
|
||||||
|
|
||||||
|
# 如果是中文,但长度小于2或匹配纯英文/数字模式,则直接添加,不进一步切分
|
||||||
|
if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
|
||||||
|
res.append(L)
|
||||||
|
continue # 处理下一个片段
|
||||||
|
|
||||||
|
# 对较长的中文片段执行FMM和BMM
|
||||||
|
tks, s = self.maxForward_(L) # tks: FMM结果列表, s: FMM评分 FMM (Forward Maximum Matching - 正向最大匹配)
|
||||||
|
tks1, s1 = self.maxBackward_(L) # tks1: BMM结果列表, s1: BMM评分 BMM (Backward Maximum Matching - 逆向最大匹配)
|
||||||
|
|
||||||
|
# 初始化用于比较FMM和BMM结果的指针
|
||||||
|
i, j = 0, 0 # i 指向 tks1 (BMM), j 指向 tks (FMM)
|
||||||
|
_i, _j = 0, 0 # _i, _j 记录上一段歧义处理的结束位置
|
||||||
|
|
||||||
|
# 3.b.i. 查找 FMM 和 BMM 从头开始的最长相同前缀
|
||||||
|
same = 0 # 相同词元的数量
|
||||||
|
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||||
|
same += 1
|
||||||
|
if same > 0: # 如果存在相同前缀
|
||||||
|
res.append(" ".join(tks[j : j + same])) # 将FMM中的相同部分加入结果
|
||||||
|
|
||||||
|
# 更新指针到不同部分的开始
|
||||||
|
_i = i + same
|
||||||
|
_j = j + same
|
||||||
|
# 准备开始处理可能存在的歧义部分
|
||||||
|
j = _j + 1 # FMM指针向后移动一位(或多位,取决于下面tk的构造)
|
||||||
|
i = _i + 1 # BMM指针向后移动一位(或多位)
|
||||||
|
|
||||||
|
# 3.b.ii. 迭代处理 FMM 和 BMM 结果中的歧义部分
|
||||||
|
while i < len(tks1) and j < len(tks):
|
||||||
|
# tk1 是 BMM 从上一个同步点 _i 到当前指针 i 形成的字符串
|
||||||
|
# tk 是 FMM 从上一个同步点 _j 到当前指针 j 形成的字符串
|
||||||
|
tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
|
||||||
|
|
||||||
|
if tk1 != tk: # 如果这两个子串不相同,说明FMM和BMM的切分路径出现分叉
|
||||||
|
# 尝试通过移动较短子串的指针来寻找下一个可能的同步点
|
||||||
|
if len(tk1) > len(tk):
|
||||||
|
j += 1
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
continue # 继续外层while循环
|
||||||
|
|
||||||
|
# 如果子串相同,但当前位置的单个词元不同,则这也是一个需要DFS解决的歧义点
|
||||||
|
if tks1[i] != tks[j]: # 注意:这里比较的是tks1[i]和tks[j],而不是tk1和tk的最后一个词
|
||||||
|
i += 1
|
||||||
|
j += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 从_j到j (不包括j处的词) 这段 FMM 产生的文本是歧义的,需要用DFS解决。
|
||||||
|
tkslist = []
|
||||||
|
self.dfs_("".join(tks[_j:j]), 0, [], tkslist) # 对这段FMM子串进行DFS
|
||||||
|
if tkslist: # 确保DFS有结果
|
||||||
|
res.append(" ".join(self.sortTks_(tkslist)[0][0])) # 取最优DFS结果
|
||||||
|
|
||||||
|
# 处理当前这个相同的词元 (tks[j] 或 tks1[i]) 以及之后连续相同的词元
|
||||||
|
same = 1
|
||||||
|
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||||
|
same += 1
|
||||||
|
res.append(" ".join(tks[j : j + same])) # 将FMM中从j开始的连续相同部分加入结果
|
||||||
|
|
||||||
|
# 更新指针到下一个不同部分的开始
|
||||||
|
_i = i + same
|
||||||
|
_j = j + same
|
||||||
|
j = _j + 1
|
||||||
|
i = _i + 1
|
||||||
|
|
||||||
|
# 3.c. 处理 FMM 或 BMM 可能的尾部剩余部分
|
||||||
|
# 如果 _i (BMM的已处理指针) 还没有到达 tks1 的末尾
|
||||||
|
# (并且假设 _j (FMM的已处理指针) 也未到 tks 的末尾,且剩余部分代表相同的原始文本)
|
||||||
|
if _i < len(tks1):
|
||||||
|
# 断言确保FMM的已处理指针也未到末尾
|
||||||
|
assert _j < len(tks)
|
||||||
|
# 断言FMM和BMM的剩余部分代表相同的原始字符串
|
||||||
|
assert "".join(tks1[_i:]) == "".join(tks[_j:])
|
||||||
|
# 对FMM的剩余部分(代表了原始文本的尾部)进行DFS分词
|
||||||
|
tkslist = []
|
||||||
|
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
|
||||||
|
if tkslist: # 确保DFS有结果
|
||||||
|
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||||||
|
|
||||||
|
# 4. 后处理
|
||||||
|
res_str = " ".join(res) # 将所有分词结果用空格连接
|
||||||
|
return self.merge_(res_str) # 返回经过合并处理的最终分词结果
|
||||||
|
|
||||||
|
|
||||||
|
def is_chinese(s):
|
||||||
|
if s >= "\u4e00" and s <= "\u9fa5":
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tknzr = RagTokenizer()
|
||||||
|
tks = tknzr.tokenize("基于动态视觉相机的光流估计研究_孙文义.pdf")
|
||||||
|
print(tks)
|
||||||
|
tks = tknzr.tokenize("图3-1 事件流输入表征。(a)事件帧;(b)时间面;(c)体素网格\n(a)\n(b)\n(c)")
|
||||||
|
print(tks)
|
|
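The long docstring of `tokenize` in the new `rag_tokenizer.py` above describes a dictionary-driven pipeline: forward/backward maximum matching first, then a DFS over the ambiguous spans scored by `score_`. As a toy illustration of the forward-maximum-matching step only (step 3.a of that flow) — this is not the `RagTokenizer` implementation, just a minimal sketch with a hypothetical inline dictionary:

```python
def forward_max_match(text: str, vocab: set, max_len: int = 4) -> list:
    """Greedy forward maximum matching: take the longest dictionary word at each position."""
    tokens, i = [], 0
    while i < len(text):
        for n in range(min(max_len, len(text) - i), 0, -1):
            piece = text[i : i + n]
            if n == 1 or piece in vocab:  # fall back to a single character when no word matches
                tokens.append(piece)
                i += n
                break
    return tokens


# Hypothetical mini dictionary, for demonstration only.
vocab = {"香港", "具有", "显著", "优势"}
print(forward_max_match("香港具有显著优势", vocab))  # ['香港', '具有', '显著', '优势']
```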
@ -7,6 +7,7 @@ import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from utils import generate_uuid
|
from utils import generate_uuid
|
||||||
from database import DB_CONFIG
|
from database import DB_CONFIG
|
||||||
|
|
||||||
# 解析相关模块
|
# 解析相关模块
|
||||||
from .document_parser import perform_parse, _update_document_progress
|
from .document_parser import perform_parse, _update_document_progress
|
||||||
|
|
||||||
|
@ -14,15 +15,15 @@ from .document_parser import perform_parse, _update_document_progress
|
||||||
# 结构: { kb_id: {"status": "running/completed/failed", "total": N, "current": M, "message": "...", "start_time": timestamp} }
|
# 结构: { kb_id: {"status": "running/completed/failed", "total": N, "current": M, "message": "...", "start_time": timestamp} }
|
||||||
SEQUENTIAL_BATCH_TASKS = {}
|
SEQUENTIAL_BATCH_TASKS = {}
|
||||||
|
|
||||||
class KnowledgebaseService:
|
|
||||||
|
|
||||||
|
class KnowledgebaseService:
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_db_connection(cls):
|
def _get_db_connection(cls):
|
||||||
"""创建数据库连接"""
|
"""创建数据库连接"""
|
||||||
return mysql.connector.connect(**DB_CONFIG)
|
return mysql.connector.connect(**DB_CONFIG)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_knowledgebase_list(cls, page=1, size=10, name='', sort_by="create_time", sort_order="desc"):
|
def get_knowledgebase_list(cls, page=1, size=10, name="", sort_by="create_time", sort_order="desc"):
|
||||||
"""获取知识库列表"""
|
"""获取知识库列表"""
|
||||||
conn = cls._get_db_connection()
|
conn = cls._get_db_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
@ -65,33 +66,30 @@ class KnowledgebaseService:
|
||||||
# 处理结果
|
# 处理结果
|
||||||
for result in results:
|
for result in results:
|
||||||
# 处理空描述
|
# 处理空描述
|
||||||
if not result.get('description'):
|
if not result.get("description"):
|
||||||
result['description'] = "暂无描述"
|
result["description"] = "暂无描述"
|
||||||
# 处理时间格式
|
# 处理时间格式
|
||||||
if result.get('create_date'):
|
if result.get("create_date"):
|
||||||
if isinstance(result['create_date'], datetime):
|
if isinstance(result["create_date"], datetime):
|
||||||
result['create_date'] = result['create_date'].strftime('%Y-%m-%d %H:%M:%S')
|
result["create_date"] = result["create_date"].strftime("%Y-%m-%d %H:%M:%S")
|
||||||
elif isinstance(result['create_date'], str):
|
elif isinstance(result["create_date"], str):
|
||||||
try:
|
try:
|
||||||
# 尝试解析已有字符串格式
|
# 尝试解析已有字符串格式
|
||||||
datetime.strptime(result['create_date'], '%Y-%m-%d %H:%M:%S')
|
datetime.strptime(result["create_date"], "%Y-%m-%d %H:%M:%S")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
result['create_date'] = ""
|
result["create_date"] = ""
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
count_query = "SELECT COUNT(*) as total FROM knowledgebase"
|
count_query = "SELECT COUNT(*) as total FROM knowledgebase"
|
||||||
if name:
|
if name:
|
||||||
count_query += " WHERE name LIKE %s"
|
count_query += " WHERE name LIKE %s"
|
||||||
cursor.execute(count_query, params[:1] if name else [])
|
cursor.execute(count_query, params[:1] if name else [])
|
||||||
total = cursor.fetchone()['total']
|
total = cursor.fetchone()["total"]
|
||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return {
|
return {"list": results, "total": total}
|
||||||
'list': results,
|
|
||||||
'total': total
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_knowledgebase_detail(cls, kb_id):
|
def get_knowledgebase_detail(cls, kb_id):
|
||||||
|
@ -115,17 +113,17 @@ class KnowledgebaseService:
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
# 处理空描述
|
# 处理空描述
|
||||||
if not result.get('description'):
|
if not result.get("description"):
|
||||||
result['description'] = "暂无描述"
|
result["description"] = "暂无描述"
|
||||||
# 处理时间格式
|
# 处理时间格式
|
||||||
if result.get('create_date'):
|
if result.get("create_date"):
|
||||||
if isinstance(result['create_date'], datetime):
|
if isinstance(result["create_date"], datetime):
|
||||||
result['create_date'] = result['create_date'].strftime('%Y-%m-%d %H:%M:%S')
|
result["create_date"] = result["create_date"].strftime("%Y-%m-%d %H:%M:%S")
|
||||||
elif isinstance(result['create_date'], str):
|
elif isinstance(result["create_date"], str):
|
||||||
try:
|
try:
|
||||||
datetime.strptime(result['create_date'], '%Y-%m-%d %H:%M:%S')
|
datetime.strptime(result["create_date"], "%Y-%m-%d %H:%M:%S")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
result['create_date'] = ""
|
result["create_date"] = ""
|
||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
@ -157,7 +155,7 @@ class KnowledgebaseService:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查知识库名称是否已存在
|
# 检查知识库名称是否已存在
|
||||||
exists = cls._check_name_exists(data['name'])
|
exists = cls._check_name_exists(data["name"])
|
||||||
if exists:
|
if exists:
|
||||||
raise Exception("知识库名称已存在")
|
raise Exception("知识库名称已存在")
|
||||||
|
|
||||||
|
@ -165,8 +163,8 @@ class KnowledgebaseService:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 使用传入的 creator_id 作为 tenant_id 和 created_by
|
# 使用传入的 creator_id 作为 tenant_id 和 created_by
|
||||||
tenant_id = data.get('creator_id')
|
tenant_id = data.get("creator_id")
|
||||||
created_by = data.get('creator_id')
|
created_by = data.get("creator_id")
|
||||||
|
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
# 如果没有提供 creator_id,则使用默认值
|
# 如果没有提供 creator_id,则使用默认值
|
||||||
|
@ -181,8 +179,8 @@ class KnowledgebaseService:
|
||||||
earliest_user = cursor.fetchone()
|
earliest_user = cursor.fetchone()
|
||||||
|
|
||||||
if earliest_user:
|
if earliest_user:
|
||||||
tenant_id = earliest_user['id']
|
tenant_id = earliest_user["id"]
|
||||||
created_by = earliest_user['id']
|
created_by = earliest_user["id"]
|
||||||
print(f"使用创建时间最早的用户ID作为tenant_id和created_by: {tenant_id}")
|
print(f"使用创建时间最早的用户ID作为tenant_id和created_by: {tenant_id}")
|
||||||
else:
|
else:
|
||||||
# 如果找不到用户,使用默认值
|
# 如果找不到用户,使用默认值
|
||||||
|
@ -196,10 +194,9 @@ class KnowledgebaseService:
|
||||||
else:
|
else:
|
||||||
print(f"使用传入的 creator_id 作为 tenant_id 和 created_by: {tenant_id}")
|
print(f"使用传入的 creator_id 作为 tenant_id 和 created_by: {tenant_id}")
|
||||||
|
|
||||||
|
|
||||||
# --- 获取动态 embd_id ---
|
# --- 获取动态 embd_id ---
|
||||||
dynamic_embd_id = None
|
dynamic_embd_id = None
|
||||||
default_embd_id = 'bge-m3' # Fallback default
|
default_embd_id = "bge-m3" # Fallback default
|
||||||
try:
|
try:
|
||||||
query_embedding_model = """
|
query_embedding_model = """
|
||||||
SELECT llm_name
|
SELECT llm_name
|
||||||
|
@ -211,8 +208,8 @@ class KnowledgebaseService:
|
||||||
cursor.execute(query_embedding_model)
|
cursor.execute(query_embedding_model)
|
||||||
embedding_model = cursor.fetchone()
|
embedding_model = cursor.fetchone()
|
||||||
|
|
||||||
if embedding_model and embedding_model.get('llm_name'):
|
if embedding_model and embedding_model.get("llm_name"):
|
||||||
dynamic_embd_id = embedding_model['llm_name']
|
dynamic_embd_id = embedding_model["llm_name"]
|
||||||
# 对硅基流动平台进行特异性处理
|
# 对硅基流动平台进行特异性处理
|
||||||
if dynamic_embd_id == "netease-youdao/bce-embedding-base_v1":
|
if dynamic_embd_id == "netease-youdao/bce-embedding-base_v1":
|
||||||
dynamic_embd_id = "BAAI/bge-m3"
|
dynamic_embd_id = "BAAI/bge-m3"
|
||||||
|
@ -226,7 +223,7 @@ class KnowledgebaseService:
|
||||||
traceback.print_exc() # Log the full traceback for debugging
|
traceback.print_exc() # Log the full traceback for debugging
|
||||||
|
|
||||||
current_time = datetime.now()
|
current_time = datetime.now()
|
||||||
create_date = current_time.strftime('%Y-%m-%d %H:%M:%S')
|
create_date = current_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
create_time = int(current_time.timestamp() * 1000) # 毫秒级时间戳
|
create_time = int(current_time.timestamp() * 1000) # 毫秒级时间戳
|
||||||
update_date = create_date
|
update_date = create_date
|
||||||
update_time = create_time
|
update_time = create_time
|
||||||
|
@ -249,7 +246,8 @@ class KnowledgebaseService:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 设置默认值
|
# 设置默认值
|
||||||
default_parser_config = json.dumps({
|
default_parser_config = json.dumps(
|
||||||
|
{
|
||||||
"layout_recognize": "MinerU",
|
"layout_recognize": "MinerU",
|
||||||
"chunk_token_num": 512,
|
"chunk_token_num": 512,
|
||||||
"delimiter": "\n!?;。;!?",
|
"delimiter": "\n!?;。;!?",
|
||||||
|
@ -257,11 +255,14 @@ class KnowledgebaseService:
|
||||||
"auto_questions": 0,
|
"auto_questions": 0,
|
||||||
"html4excel": False,
|
"html4excel": False,
|
||||||
"raptor": {"use_raptor": False},
|
"raptor": {"use_raptor": False},
|
||||||
"graphrag": {"use_graphrag": False}
|
"graphrag": {"use_graphrag": False},
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
kb_id = generate_uuid()
|
kb_id = generate_uuid()
|
||||||
cursor.execute(query, (
|
cursor.execute(
|
||||||
|
query,
|
||||||
|
(
|
||||||
kb_id, # id
|
kb_id, # id
|
||||||
create_time, # create_time
|
create_time, # create_time
|
||||||
create_date, # create_date
|
create_date, # create_date
|
||||||
|
@ -269,22 +270,23 @@ class KnowledgebaseService:
|
||||||
update_date, # update_date
|
update_date, # update_date
|
||||||
None, # avatar
|
None, # avatar
|
||||||
tenant_id, # tenant_id
|
tenant_id, # tenant_id
|
||||||
data['name'], # name
|
data["name"], # name
|
||||||
data.get('language', 'Chinese'), # language
|
data.get("language", "Chinese"), # language
|
||||||
data.get('description', ''), # description
|
data.get("description", ""), # description
|
||||||
dynamic_embd_id, # embd_id
|
dynamic_embd_id, # embd_id
|
||||||
data.get('permission', 'me'), # permission
|
data.get("permission", "me"), # permission
|
||||||
created_by, # created_by - 使用内部获取的值
|
created_by, # created_by - 使用内部获取的值
|
||||||
0, # doc_num
|
0, # doc_num
|
||||||
0, # token_num
|
0, # token_num
|
||||||
0, # chunk_num
|
0, # chunk_num
|
||||||
0.7, # similarity_threshold
|
0.7, # similarity_threshold
|
||||||
0.3, # vector_similarity_weight
|
0.3, # vector_similarity_weight
|
||||||
'naive', # parser_id
|
"naive", # parser_id
|
||||||
default_parser_config, # parser_config
|
default_parser_config, # parser_config
|
||||||
0, # pagerank
|
0, # pagerank
|
||||||
'1' # status
|
"1", # status
|
||||||
))
|
),
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
@ -310,8 +312,8 @@ class KnowledgebaseService:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 如果要更新名称,先检查名称是否已存在
|
# 如果要更新名称,先检查名称是否已存在
|
||||||
if data.get('name') and data['name'] != kb['name']:
|
if data.get("name") and data["name"] != kb["name"]:
|
||||||
exists = cls._check_name_exists(data['name'])
|
exists = cls._check_name_exists(data["name"])
|
||||||
if exists:
|
if exists:
|
||||||
raise Exception("知识库名称已存在")
|
raise Exception("知识库名称已存在")
|
||||||
|
|
||||||
|
@ -319,21 +321,21 @@ class KnowledgebaseService:
|
||||||
update_fields = []
|
update_fields = []
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
if data.get('name'):
|
if data.get("name"):
|
||||||
update_fields.append("name = %s")
|
update_fields.append("name = %s")
|
||||||
params.append(data['name'])
|
params.append(data["name"])
|
||||||
|
|
||||||
if 'description' in data:
|
if "description" in data:
|
||||||
update_fields.append("description = %s")
|
update_fields.append("description = %s")
|
||||||
params.append(data['description'])
|
params.append(data["description"])
|
||||||
|
|
||||||
if 'permission' in data:
|
if "permission" in data:
|
||||||
update_fields.append("permission = %s")
|
update_fields.append("permission = %s")
|
||||||
params.append(data['permission'])
|
params.append(data["permission"])
|
||||||
|
|
||||||
# 更新时间
|
# 更新时间
|
||||||
current_time = datetime.now()
|
current_time = datetime.now()
|
||||||
update_date = current_time.strftime('%Y-%m-%d %H:%M:%S')
|
update_date = current_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
update_fields.append("update_date = %s")
|
update_fields.append("update_date = %s")
|
||||||
params.append(update_date)
|
params.append(update_date)
|
||||||
|
|
||||||
|
@ -344,7 +346,7 @@ class KnowledgebaseService:
|
||||||
# 构建并执行更新语句
|
# 构建并执行更新语句
|
||||||
query = f"""
|
query = f"""
|
||||||
UPDATE knowledgebase
|
UPDATE knowledgebase
|
||||||
SET {', '.join(update_fields)}
|
SET {", ".join(update_fields)}
|
||||||
WHERE id = %s
|
WHERE id = %s
|
||||||
"""
|
"""
|
||||||
params.append(kb_id)
|
params.append(kb_id)
|
||||||
|
@ -396,8 +398,7 @@ class KnowledgebaseService:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 检查所有ID是否存在
|
# 检查所有ID是否存在
|
||||||
check_query = "SELECT id FROM knowledgebase WHERE id IN (%s)" % \
|
check_query = "SELECT id FROM knowledgebase WHERE id IN (%s)" % ",".join(["%s"] * len(kb_ids))
|
||||||
','.join(['%s'] * len(kb_ids))
|
|
||||||
cursor.execute(check_query, kb_ids)
|
cursor.execute(check_query, kb_ids)
|
||||||
existing_ids = [row[0] for row in cursor.fetchall()]
|
existing_ids = [row[0] for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
@ -406,8 +407,7 @@ class KnowledgebaseService:
|
||||||
raise Exception(f"以下知识库不存在: {', '.join(missing_ids)}")
|
raise Exception(f"以下知识库不存在: {', '.join(missing_ids)}")
|
||||||
|
|
||||||
# 执行批量删除
|
# 执行批量删除
|
||||||
delete_query = "DELETE FROM knowledgebase WHERE id IN (%s)" % \
|
delete_query = "DELETE FROM knowledgebase WHERE id IN (%s)" % ",".join(["%s"] * len(kb_ids))
|
||||||
','.join(['%s'] * len(kb_ids))
|
|
||||||
cursor.execute(delete_query, kb_ids)
|
cursor.execute(delete_query, kb_ids)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
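The IN-clause construction above mixes Python `%`-formatting with DB-API placeholders, which is easy to misread. A minimal, self-contained illustration (the ids are made up):

```python
kb_ids = ["kb_a", "kb_b", "kb_c"]
placeholders = ",".join(["%s"] * len(kb_ids))          # "%s,%s,%s"
delete_query = "DELETE FROM knowledgebase WHERE id IN (%s)" % placeholders
print(delete_query)  # DELETE FROM knowledgebase WHERE id IN (%s,%s,%s)
# cursor.execute(delete_query, kb_ids) then lets the driver bind each id safely.
```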
|
|
||||||
|
@ -420,7 +420,7 @@ class KnowledgebaseService:
|
||||||
raise Exception(f"批量删除知识库失败: {str(e)}")
|
raise Exception(f"批量删除知识库失败: {str(e)}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_knowledgebase_documents(cls, kb_id, page=1, size=10, name='', sort_by="create_time", sort_order="desc"):
|
def get_knowledgebase_documents(cls, kb_id, page=1, size=10, name="", sort_by="create_time", sort_order="desc"):
|
||||||
"""获取知识库下的文档列表"""
|
"""获取知识库下的文档列表"""
|
||||||
try:
|
try:
|
||||||
conn = cls._get_db_connection()
|
conn = cls._get_db_connection()
|
||||||
|
@ -473,8 +473,8 @@ class KnowledgebaseService:
|
||||||
|
|
||||||
# 处理日期时间格式
|
# 处理日期时间格式
|
||||||
for result in results:
|
for result in results:
|
||||||
if result.get('create_date'):
|
if result.get("create_date"):
|
||||||
result['create_date'] = result['create_date'].strftime('%Y-%m-%d %H:%M:%S')
|
result["create_date"] = result["create_date"].strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
count_query = "SELECT COUNT(*) as total FROM document WHERE kb_id = %s"
|
count_query = "SELECT COUNT(*) as total FROM document WHERE kb_id = %s"
|
||||||
|
@ -484,15 +484,12 @@ class KnowledgebaseService:
|
||||||
count_params.append(f"%{name}%")
|
count_params.append(f"%{name}%")
|
||||||
|
|
||||||
cursor.execute(count_query, count_params)
|
cursor.execute(count_query, count_params)
|
||||||
total = cursor.fetchone()['total']
|
total = cursor.fetchone()["total"]
|
||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return {
|
return {"list": results, "total": total}
|
||||||
'list': results,
|
|
||||||
'total': total
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取知识库文档列表失败: {str(e)}")
|
print(f"获取知识库文档列表失败: {str(e)}")
|
||||||
|
@ -519,10 +516,10 @@ class KnowledgebaseService:
|
||||||
earliest_user = cursor.fetchone()
|
earliest_user = cursor.fetchone()
|
||||||
|
|
||||||
if earliest_user:
|
if earliest_user:
|
||||||
created_by = earliest_user['id']
|
created_by = earliest_user["id"]
|
||||||
print(f"使用创建时间最早的用户ID: {created_by}")
|
print(f"使用创建时间最早的用户ID: {created_by}")
|
||||||
else:
|
else:
|
||||||
created_by = 'system'
|
created_by = "system"
|
||||||
print("未找到用户, 使用默认用户ID: system")
|
print("未找到用户, 使用默认用户ID: system")
|
||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
@ -543,7 +540,7 @@ class KnowledgebaseService:
|
||||||
SELECT id, name, location, size, type
|
SELECT id, name, location, size, type
|
||||||
FROM file
|
FROM file
|
||||||
WHERE id IN (%s)
|
WHERE id IN (%s)
|
||||||
""" % ','.join(['%s'] * len(file_ids))
|
""" % ",".join(["%s"] * len(file_ids))
|
||||||
|
|
||||||
print(f"[DEBUG] 执行文件查询SQL: {file_query}")
|
print(f"[DEBUG] 执行文件查询SQL: {file_query}")
|
||||||
print(f"[DEBUG] 查询参数: {file_ids}")
|
print(f"[DEBUG] 查询参数: {file_ids}")
|
||||||
|
@ -592,20 +589,18 @@ class KnowledgebaseService:
|
||||||
|
|
||||||
# 设置默认值
|
# 设置默认值
|
||||||
default_parser_id = "naive"
|
default_parser_id = "naive"
|
||||||
default_parser_config = json.dumps({
|
default_parser_config = json.dumps(
|
||||||
|
{
|
||||||
"layout_recognize": "MinerU",
|
"layout_recognize": "MinerU",
|
||||||
"chunk_token_num": 512,
|
"chunk_token_num": 512,
|
||||||
"delimiter": "\n!?;。;!?",
|
"delimiter": "\n!?;。;!?",
|
||||||
"auto_keywords": 0,
|
"auto_keywords": 0,
|
||||||
"auto_questions": 0,
|
"auto_questions": 0,
|
||||||
"html4excel": False,
|
"html4excel": False,
|
||||||
"raptor": {
|
"raptor": {"use_raptor": False},
|
||||||
"use_raptor": False
|
"graphrag": {"use_graphrag": False},
|
||||||
},
|
|
||||||
"graphrag": {
|
|
||||||
"use_graphrag": False
|
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
default_source_type = "local"
|
default_source_type = "local"
|
||||||
|
|
||||||
# 插入document表
|
# 插入document表
|
||||||
|
@ -626,11 +621,30 @@ class KnowledgebaseService:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
doc_params = [
|
doc_params = [
|
||||||
doc_id, create_time, current_date, create_time, current_date, # ID和时间
|
doc_id,
|
||||||
None, kb_id, default_parser_id, default_parser_config, default_source_type, # thumbnail到source_type
|
create_time,
|
||||||
file_type, created_by, file_name, file_location, file_size, # type到size
|
current_date,
|
||||||
0, 0, 0.0, None, None, # token_num到process_begin_at
|
create_time,
|
||||||
0.0, None, '0', '1' # process_duation到status
|
current_date, # ID和时间
|
||||||
|
None,
|
||||||
|
kb_id,
|
||||||
|
default_parser_id,
|
||||||
|
default_parser_config,
|
||||||
|
default_source_type, # thumbnail到source_type
|
||||||
|
file_type,
|
||||||
|
created_by,
|
||||||
|
file_name,
|
||||||
|
file_location,
|
||||||
|
file_size, # type到size
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0.0,
|
||||||
|
None,
|
||||||
|
None, # token_num到process_begin_at
|
||||||
|
0.0,
|
||||||
|
None,
|
||||||
|
"0",
|
||||||
|
"1", # process_duation到status
|
||||||
]
|
]
|
||||||
|
|
||||||
cursor.execute(doc_query, doc_params)
|
cursor.execute(doc_query, doc_params)
|
||||||
|
@ -647,10 +661,7 @@ class KnowledgebaseService:
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
f2d_params = [
|
f2d_params = [f2d_id, create_time, current_date, create_time, current_date, file_id, doc_id]
|
||||||
f2d_id, create_time, current_date, create_time, current_date,
|
|
||||||
file_id, doc_id
|
|
||||||
]
|
|
||||||
|
|
||||||
cursor.execute(f2d_query, f2d_params)
|
cursor.execute(f2d_query, f2d_params)
|
||||||
|
|
||||||
|
@ -673,14 +684,13 @@ class KnowledgebaseService:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return {
|
return {"added_count": added_count}
|
||||||
"added_count": added_count
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[ERROR] 添加文档失败: {str(e)}")
|
print(f"[ERROR] 添加文档失败: {str(e)}")
|
||||||
print(f"[ERROR] 错误类型: {type(e)}")
|
print(f"[ERROR] 错误类型: {type(e)}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
print(f"[ERROR] 堆栈信息: {traceback.format_exc()}")
|
print(f"[ERROR] 堆栈信息: {traceback.format_exc()}")
|
||||||
raise Exception(f"添加文档到知识库失败: {str(e)}")
|
raise Exception(f"添加文档到知识库失败: {str(e)}")
|
||||||
|
|
||||||
|
@ -757,7 +767,7 @@ class KnowledgebaseService:
|
||||||
f2d_result = cursor.fetchone()
|
f2d_result = cursor.fetchone()
|
||||||
if not f2d_result:
|
if not f2d_result:
|
||||||
raise Exception("无法找到文件到文档的映射关系")
|
raise Exception("无法找到文件到文档的映射关系")
|
||||||
file_id = f2d_result['file_id']
|
file_id = f2d_result["file_id"]
|
||||||
|
|
||||||
file_query = "SELECT parent_id FROM file WHERE id = %s"
|
file_query = "SELECT parent_id FROM file WHERE id = %s"
|
||||||
cursor.execute(file_query, (file_id,))
|
cursor.execute(file_query, (file_id,))
|
||||||
|
@ -770,7 +780,7 @@ class KnowledgebaseService:
|
||||||
conn = None # 确保连接已关闭
|
conn = None # 确保连接已关闭
|
||||||
|
|
||||||
# 2. 更新文档状态为处理中 (使用 parser 模块的函数)
|
# 2. 更新文档状态为处理中 (使用 parser 模块的函数)
|
||||||
_update_document_progress(doc_id, status='2', run='1', progress=0.0, message='开始解析')
|
_update_document_progress(doc_id, status="2", run="1", progress=0.0, message="开始解析")
|
||||||
|
|
||||||
# 3. 调用后台解析函数
|
# 3. 调用后台解析函数
|
||||||
embedding_config = cls.get_system_embedding_config()
|
embedding_config = cls.get_system_embedding_config()
|
||||||
|
@ -783,7 +793,7 @@ class KnowledgebaseService:
|
||||||
print(f"文档解析启动或执行过程中出错 (Doc ID: {doc_id}): {str(e)}")
|
print(f"文档解析启动或执行过程中出错 (Doc ID: {doc_id}): {str(e)}")
|
||||||
# 确保在异常时更新状态为失败
|
# 确保在异常时更新状态为失败
|
||||||
try:
|
try:
|
||||||
_update_document_progress(doc_id, status='1', run='0', message=f"解析失败: {str(e)}")
|
_update_document_progress(doc_id, status="1", run="0", message=f"解析失败: {str(e)}")
|
||||||
except Exception as update_err:
|
except Exception as update_err:
|
||||||
print(f"更新文档失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}")
|
print(f"更新文档失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}")
|
||||||
# raise Exception(f"文档解析失败: {str(e)}")
|
# raise Exception(f"文档解析失败: {str(e)}")
|
||||||
|
@ -808,13 +818,12 @@ class KnowledgebaseService:
|
||||||
return {
|
return {
|
||||||
"task_id": doc_id, # 使用 doc_id 作为任务标识符
|
"task_id": doc_id, # 使用 doc_id 作为任务标识符
|
||||||
"status": "processing",
|
"status": "processing",
|
||||||
"message": "文档解析任务已提交到后台处理"
|
"message": "文档解析任务已提交到后台处理",
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"启动异步解析任务失败 (Doc ID: {doc_id}): {str(e)}")
|
print(f"启动异步解析任务失败 (Doc ID: {doc_id}): {str(e)}")
|
||||||
# 可以在这里尝试更新文档状态为失败
|
|
||||||
try:
|
try:
|
||||||
_update_document_progress(doc_id, status='1', run='0', message=f"启动解析失败: {str(e)}")
|
_update_document_progress(doc_id, status="1", run="0", message=f"启动解析失败: {str(e)}")
|
||||||
except Exception as update_err:
|
except Exception as update_err:
|
||||||
print(f"更新文档启动失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}")
|
print(f"更新文档启动失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}")
|
||||||
raise Exception(f"启动异步解析任务失败: {str(e)}")
|
raise Exception(f"启动异步解析任务失败: {str(e)}")
|
||||||
|
@ -904,26 +913,26 @@ class KnowledgebaseService:
|
||||||
|
|
||||||
payload = {"input": ["Test connection"], "model": model_name}
|
payload = {"input": ["Test connection"], "model": model_name}
|
||||||
|
|
||||||
if not base_url.startswith(('http://', 'https://')):
|
if not base_url.startswith(("http://", "https://")):
|
||||||
base_url = 'http://' + base_url
|
base_url = "http://" + base_url
|
||||||
if not base_url.endswith('/'):
|
if not base_url.endswith("/"):
|
||||||
base_url += '/'
|
base_url += "/"
|
||||||
|
|
||||||
# --- URL 拼接优化 ---
|
# --- URL 拼接优化 ---
|
||||||
endpoint_segment = "embeddings"
|
endpoint_segment = "embeddings"
|
||||||
full_endpoint_path = "v1/embeddings"
|
full_endpoint_path = "v1/embeddings"
|
||||||
# 移除末尾斜杠以方便判断
|
# 移除末尾斜杠以方便判断
|
||||||
normalized_base_url = base_url.rstrip('/')
|
normalized_base_url = base_url.rstrip("/")
|
||||||
|
|
||||||
if normalized_base_url.endswith('/v1'):
|
if normalized_base_url.endswith("/v1"):
|
||||||
# 如果 base_url 已经是 http://host/v1 形式
|
# 如果 base_url 已经是 http://host/v1 形式
|
||||||
current_test_url = normalized_base_url + '/' + endpoint_segment
|
current_test_url = normalized_base_url + "/" + endpoint_segment
|
||||||
elif normalized_base_url.endswith('/embeddings'):
|
elif normalized_base_url.endswith("/embeddings"):
|
||||||
# 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API,无需再进行处理)
|
# 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API,无需再进行处理)
|
||||||
current_test_url = normalized_base_url
|
current_test_url = normalized_base_url
|
||||||
else:
|
else:
|
||||||
# 如果 base_url 是 http://host 或 http://host/api 形式
|
# 如果 base_url 是 http://host 或 http://host/api 形式
|
||||||
current_test_url = normalized_base_url + '/' + full_endpoint_path
|
current_test_url = normalized_base_url + "/" + full_endpoint_path
|
||||||
|
|
||||||
# --- 结束 URL 拼接优化 ---
|
# --- 结束 URL 拼接优化 ---
|
||||||
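The three branches commented above can be captured in a small standalone helper; this is only a sketch of the joining rules, and the function name is illustrative rather than part of the service:

```python
def build_embeddings_url(base_url: str) -> str:
    # Normalise the base URL, then decide whether "/v1/embeddings" (or just
    # "/embeddings") still needs to be appended -- mirroring the three cases
    # handled in the connection test above.
    if not base_url.startswith(("http://", "https://")):
        base_url = "http://" + base_url
    base = base_url.rstrip("/")
    if base.endswith("/v1"):
        return base + "/embeddings"
    if base.endswith("/embeddings"):
        return base
    return base + "/v1/embeddings"

assert build_embeddings_url("localhost:8000") == "http://localhost:8000/v1/embeddings"
assert build_embeddings_url("https://api.siliconflow.cn/v1/embeddings") == "https://api.siliconflow.cn/v1/embeddings"
```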
print(f"尝试请求 URL: {current_test_url}")
|
print(f"尝试请求 URL: {current_test_url}")
|
||||||
|
@ -933,8 +942,9 @@ class KnowledgebaseService:
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
res_json = response.json()
|
res_json = response.json()
|
||||||
if ("data" in res_json and isinstance(res_json["data"], list) and len(res_json["data"]) > 0 and "embedding" in res_json["data"][0] and len(res_json["data"][0]["embedding"]) > 0) or \
|
if (
|
||||||
(isinstance(res_json, list) and len(res_json) > 0 and isinstance(res_json[0], list) and len(res_json[0]) > 0):
|
"data" in res_json and isinstance(res_json["data"], list) and len(res_json["data"]) > 0 and "embedding" in res_json["data"][0] and len(res_json["data"][0]["embedding"]) > 0
|
||||||
|
) or (isinstance(res_json, list) and len(res_json) > 0 and isinstance(res_json[0], list) and len(res_json[0]) > 0):
|
||||||
print(f"连接测试成功: {current_test_url}")
|
print(f"连接测试成功: {current_test_url}")
|
||||||
return True, "连接成功"
|
return True, "连接成功"
|
||||||
else:
|
else:
|
||||||
|
@ -971,13 +981,9 @@ class KnowledgebaseService:
|
||||||
|
|
||||||
if not earliest_user:
|
if not earliest_user:
|
||||||
# 如果没有用户,返回空配置
|
# 如果没有用户,返回空配置
|
||||||
return {
|
return {"llm_name": "", "api_key": "", "api_base": ""}
|
||||||
"llm_name": "",
|
|
||||||
"api_key": "",
|
|
||||||
"api_base": ""
|
|
||||||
}
|
|
||||||
|
|
||||||
earliest_user_id = earliest_user['id']
|
earliest_user_id = earliest_user["id"]
|
||||||
|
|
||||||
# 2. 根据最早用户ID查询 tenant_llm 表中 model_type 为 embedding 的配置
|
# 2. 根据最早用户ID查询 tenant_llm 表中 model_type 为 embedding 的配置
|
||||||
query_embedding_config = """
|
query_embedding_config = """
|
||||||
|
@ -995,8 +1001,8 @@ class KnowledgebaseService:
|
||||||
api_key = config.get("api_key", "")
|
api_key = config.get("api_key", "")
|
||||||
api_base = config.get("api_base", "")
|
api_base = config.get("api_base", "")
|
||||||
# 对模型名称进行处理 (可选,根据需要保留或移除)
|
# 对模型名称进行处理 (可选,根据需要保留或移除)
|
||||||
if llm_name and '___' in llm_name:
|
if llm_name and "___" in llm_name:
|
||||||
llm_name = llm_name.split('___')[0]
|
llm_name = llm_name.split("___")[0]
|
||||||
|
|
||||||
# (对硅基流动平台进行特异性处理)
|
# (对硅基流动平台进行特异性处理)
|
||||||
if llm_name == "netease-youdao/bce-embedding-base_v1":
|
if llm_name == "netease-youdao/bce-embedding-base_v1":
|
||||||
|
@ -1007,18 +1013,10 @@ class KnowledgebaseService:
|
||||||
api_base = "https://api.siliconflow.cn/v1/embeddings"
|
api_base = "https://api.siliconflow.cn/v1/embeddings"
|
||||||
|
|
||||||
# 如果有配置,返回
|
# 如果有配置,返回
|
||||||
return {
|
return {"llm_name": llm_name, "api_key": api_key, "api_base": api_base}
|
||||||
"llm_name": llm_name,
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_base": api_base
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
# 如果最早的用户没有 embedding 配置,返回空
|
# 如果最早的用户没有 embedding 配置,返回空
|
||||||
return {
|
return {"llm_name": "", "api_key": "", "api_base": ""}
|
||||||
"llm_name": "",
|
|
||||||
"api_key": "",
|
|
||||||
"api_base": ""
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取系统 Embedding 配置时出错: {e}")
|
print(f"获取系统 Embedding 配置时出错: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
@ -1040,11 +1038,7 @@ class KnowledgebaseService:
|
||||||
|
|
||||||
print(f"开始设置系统 Embedding 配置: {llm_name}, {api_base}, {api_key}")
|
print(f"开始设置系统 Embedding 配置: {llm_name}, {api_base}, {api_key}")
|
||||||
# 执行连接测试
|
# 执行连接测试
|
||||||
is_connected, message = cls._test_embedding_connection(
|
is_connected, message = cls._test_embedding_connection(base_url=api_base, model_name=llm_name, api_key=api_key)
|
||||||
base_url=api_base,
|
|
||||||
model_name=llm_name,
|
|
||||||
api_key=api_key
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_connected:
|
if not is_connected:
|
||||||
# 返回具体的测试失败原因给调用者(路由层)处理
|
# 返回具体的测试失败原因给调用者(路由层)处理
|
||||||
|
@ -1152,8 +1146,8 @@ class KnowledgebaseService:
|
||||||
|
|
||||||
# 按顺序解析每个文档
|
# 按顺序解析每个文档
|
||||||
for i, doc in enumerate(documents_to_parse):
|
for i, doc in enumerate(documents_to_parse):
|
||||||
doc_id = doc['id']
|
doc_id = doc["id"]
|
||||||
doc_name = doc['name']
|
doc_name = doc["name"]
|
||||||
|
|
||||||
# 更新当前进度
|
# 更新当前进度
|
||||||
task_info["current"] = i + 1
|
task_info["current"] = i + 1
|
||||||
|
@ -1172,14 +1166,13 @@ class KnowledgebaseService:
|
||||||
failed_count += 1
|
failed_count += 1
|
||||||
error_msg = result.get("message", "未知错误") if result else "未知错误"
|
error_msg = result.get("message", "未知错误") if result else "未知错误"
|
||||||
print(f"[Seq Batch] KB {kb_id}: Document {doc_id} parsing failed: {error_msg}")
|
print(f"[Seq Batch] KB {kb_id}: Document {doc_id} parsing failed: {error_msg}")
|
||||||
# 即使单个失败,也继续处理下一个
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
failed_count += 1
|
failed_count += 1
|
||||||
print(f"[Seq Batch ERROR] KB {kb_id}: Error calling parse_document for {doc_id}: {str(e)}")
|
print(f"[Seq Batch ERROR] KB {kb_id}: Error calling parse_document for {doc_id}: {str(e)}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
# 尝试更新文档状态为失败,以防 parse_document 内部未处理
|
# 更新文档状态为失败
|
||||||
try:
|
try:
|
||||||
_update_document_progress(doc_id, status='1', run='0', progress=0.0, message=f"批量任务中解析失败: {str(e)[:255]}")
|
_update_document_progress(doc_id, status="1", run="0", progress=0.0, message=f"批量任务中解析失败: {str(e)[:255]}")
|
||||||
except Exception as update_err:
|
except Exception as update_err:
|
||||||
print(f"[Service-ERROR] 更新文档 {doc_id} 失败状态时出错: {str(update_err)}")
|
print(f"[Service-ERROR] 更新文档 {doc_id} 失败状态时出错: {str(update_err)}")
|
||||||
|
|
||||||
|
@ -1189,7 +1182,7 @@ class KnowledgebaseService:
|
||||||
final_message = f"批量顺序解析完成。总计 {total_count} 个,成功 {parsed_count} 个,失败 {failed_count} 个。耗时 {duration} 秒。"
|
final_message = f"批量顺序解析完成。总计 {total_count} 个,成功 {parsed_count} 个,失败 {failed_count} 个。耗时 {duration} 秒。"
|
||||||
task_info["status"] = "completed"
|
task_info["status"] = "completed"
|
||||||
task_info["message"] = final_message
|
task_info["message"] = final_message
|
||||||
task_info["current"] = total_count # 确保 current 等于 total
|
task_info["current"] = total_count
|
||||||
SEQUENTIAL_BATCH_TASKS[kb_id] = task_info
|
SEQUENTIAL_BATCH_TASKS[kb_id] = task_info
|
||||||
print(f"[Seq Batch] KB {kb_id}: {final_message}")
|
print(f"[Seq Batch] KB {kb_id}: {final_message}")
|
||||||
|
|
||||||
|
@ -1217,13 +1210,7 @@ class KnowledgebaseService:
|
||||||
|
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
SEQUENTIAL_BATCH_TASKS[kb_id] = {
|
SEQUENTIAL_BATCH_TASKS[kb_id] = {"status": "starting", "total": 0, "current": 0, "message": "任务准备启动...", "start_time": start_time}
|
||||||
"status": "starting",
|
|
||||||
"total": 0,
|
|
||||||
"current": 0,
|
|
||||||
"message": "任务准备启动...",
|
|
||||||
"start_time": start_time
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 启动后台线程执行顺序解析逻辑
|
# 启动后台线程执行顺序解析逻辑
|
||||||
|
@ -1239,13 +1226,7 @@ class KnowledgebaseService:
|
||||||
print(f"[Seq Batch ERROR] KB {kb_id}: {error_message}")
|
print(f"[Seq Batch ERROR] KB {kb_id}: {error_message}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
# 更新任务状态为失败
|
# 更新任务状态为失败
|
||||||
SEQUENTIAL_BATCH_TASKS[kb_id] = {
|
SEQUENTIAL_BATCH_TASKS[kb_id] = {"status": "failed", "total": 0, "current": 0, "message": error_message, "start_time": start_time}
|
||||||
"status": "failed",
|
|
||||||
"total": 0,
|
|
||||||
"current": 0,
|
|
||||||
"message": error_message,
|
|
||||||
"start_time": start_time
|
|
||||||
}
|
|
||||||
return {"success": False, "message": error_message}
|
return {"success": False, "message": error_message}
|
||||||
|
|
||||||
# 获取顺序批量解析进度
|
# 获取顺序批量解析进度
|
||||||
|
@ -1294,10 +1275,7 @@ class KnowledgebaseService:
|
||||||
doc["status"] = doc.get("status", "0")
|
doc["status"] = doc.get("status", "0")
|
||||||
doc["run"] = doc.get("run", "0")
|
doc["run"] = doc.get("run", "0")
|
||||||
|
|
||||||
|
return {"documents": documents_status}
|
||||||
return {
|
|
||||||
"documents": documents_status
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取知识库 {kb_id} 文档进度失败: {str(e)}")
|
print(f"获取知识库 {kb_id} 文档进度失败: {str(e)}")
|
||||||
|
|
|
@ -47,10 +47,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
eng = lang.lower() == "english"
|
eng = lang.lower() == "english"
|
||||||
res = []
|
res = []
|
||||||
doc = {
|
doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
|
||||||
"docnm_kwd": filename,
|
|
||||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
|
||||||
}
|
|
||||||
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
||||||
callback(0.1, "Start to parse.")
|
callback(0.1, "Start to parse.")
|
||||||
excel_parser = Excel()
|
excel_parser = Excel()
|
||||||
|
@ -83,11 +80,9 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
||||||
content = ""
|
content = ""
|
||||||
i += 1
|
i += 1
|
||||||
if len(res) % 999 == 0:
|
if len(res) % 999 == 0:
|
||||||
callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + (
|
callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
|
||||||
|
|
||||||
callback(0.6, ("Extract TAG: {}".format(len(res)) + (
|
callback(0.6, ("Extract TAG: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -110,40 +105,61 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
||||||
res.append(beAdoc(deepcopy(doc), content, row[1], eng, i))
|
res.append(beAdoc(deepcopy(doc), content, row[1], eng, i))
|
||||||
content = ""
|
content = ""
|
||||||
if len(res) % 999 == 0:
|
if len(res) % 999 == 0:
|
||||||
callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + (
|
callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
|
||||||
|
|
||||||
callback(0.6, ("Extract TAG : {}".format(len(res)) + (
|
callback(0.6, ("Extract TAG : {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError("Excel, csv(txt) format files are supported.")
|
||||||
"Excel, csv(txt) format files are supported.")
|
|
||||||
|
|
||||||
|
|
||||||
def label_question(question, kbs):
|
def label_question(question, kbs):
|
||||||
|
"""
|
||||||
|
标记问题的标签。
|
||||||
|
|
||||||
|
该函数通过给定的问题和知识库列表,对问题进行标签标记。它首先确定哪些知识库配置了标签,
|
||||||
|
然后从缓存中获取这些标签,必要时从设置中检索标签。最后,使用这些标签对问题进行标记。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
question (str): 需要标记的问题。
|
||||||
|
kbs (list): 知识库对象列表,用于标签标记。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list: 与问题相关的标签列表。
|
||||||
|
"""
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
|
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
|
||||||
from api import settings
|
from api import settings
|
||||||
|
|
||||||
|
# 初始化标签和标签知识库ID列表
|
||||||
tags = None
|
tags = None
|
||||||
tag_kb_ids = []
|
tag_kb_ids = []
|
||||||
|
|
||||||
|
# 遍历知识库,收集所有标签知识库ID
|
||||||
for kb in kbs:
|
for kb in kbs:
|
||||||
if kb.parser_config.get("tag_kb_ids"):
|
if kb.parser_config.get("tag_kb_ids"):
|
||||||
tag_kb_ids.extend(kb.parser_config["tag_kb_ids"])
|
tag_kb_ids.extend(kb.parser_config["tag_kb_ids"])
|
||||||
|
|
||||||
|
# 如果存在标签知识库ID,则进一步处理
|
||||||
if tag_kb_ids:
|
if tag_kb_ids:
|
||||||
|
# 尝试从缓存中获取所有标签
|
||||||
all_tags = get_tags_from_cache(tag_kb_ids)
|
all_tags = get_tags_from_cache(tag_kb_ids)
|
||||||
|
|
||||||
|
# 如果缓存中没有标签,从设置中检索标签,并设置缓存
|
||||||
if not all_tags:
|
if not all_tags:
|
||||||
all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
|
all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
|
||||||
set_tags_to_cache(all_tags, tag_kb_ids)
|
set_tags_to_cache(all_tags, tag_kb_ids)
|
||||||
else:
|
else:
|
||||||
|
# 如果缓存中获取到标签,将其解析为JSON格式
|
||||||
all_tags = json.loads(all_tags)
|
all_tags = json.loads(all_tags)
|
||||||
|
|
||||||
|
# 根据标签知识库ID获取对应的标签知识库
|
||||||
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
|
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
|
||||||
tags = settings.retrievaler.tag_query(question,
|
|
||||||
list(set([kb.tenant_id for kb in tag_kbs])),
|
# 使用设置中的检索器对问题进行标签标记
|
||||||
tag_kb_ids,
|
tags = settings.retrievaler.tag_query(question, list(set([kb.tenant_id for kb in tag_kbs])), tag_kb_ids, all_tags, kb.parser_config.get("topn_tags", 3))
|
||||||
all_tags,
|
|
||||||
kb.parser_config.get("topn_tags", 3)
|
# 返回标记的标签
|
||||||
)
|
|
||||||
return tags
|
return tags
|
||||||
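The cache-then-store flow described in the docstring is a standard cache-aside pattern. A schematic sketch with stand-in callables (these names are not the project's real signatures):

```python
import json

def get_or_build_tags(tag_kb_ids, cache_get, store_get_all, cache_set):
    # cache_get / store_get_all / cache_set stand in for get_tags_from_cache,
    # settings.retrievaler.all_tags_in_portion and set_tags_to_cache.
    cached = cache_get(tag_kb_ids)
    if cached:
        return json.loads(cached)          # cache hit: stored as a JSON string
    all_tags = store_get_all(tag_kb_ids)   # cache miss: pull from the doc store
    cache_set(all_tags, tag_kb_ids)        # backfill the cache for next time
    return all_tags
```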
|
|
||||||
|
|
||||||
|
@ -152,4 +168,5 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
def dummy(prog=None, msg=""):
|
def dummy(prog=None, msg=""):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
179
rag/nlp/query.py
|
@ -53,6 +53,16 @@ class FulltextQueryer:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def rmWWW(txt):
|
def rmWWW(txt):
|
||||||
|
"""
|
||||||
|
移除文本中的WWW(WHAT、WHO、WHERE等疑问词)。
|
||||||
|
|
||||||
|
本函数通过一系列正则表达式模式来识别并替换文本中的疑问词,以简化文本或为后续处理做准备。
|
||||||
|
参数:
|
||||||
|
- txt: 待处理的文本字符串。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 处理后的文本字符串,如果所有疑问词都被移除且文本为空,则返回原始文本。
|
||||||
|
"""
|
||||||
patts = [
|
patts = [
|
||||||
(
|
(
|
||||||
r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
||||||
|
@ -61,7 +71,8 @@ class FulltextQueryer:
|
||||||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||||
(
|
(
|
||||||
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
|
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
|
||||||
" ")
|
" ",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
otxt = txt
|
otxt = txt
|
||||||
for r, p in patts:
|
for r, p in patts:
|
||||||
|
@ -70,28 +81,53 @@ class FulltextQueryer:
|
||||||
txt = otxt
|
txt = otxt
|
||||||
return txt
|
return txt
|
||||||
|
|
||||||
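A reduced sketch of the behaviour the docstring describes; the real pattern list is far longer, the point here is only the fall-back to the original text when stripping would leave nothing:

```python
import re

def rm_question_words(txt: str) -> str:
    patts = [
        (r"(请问|什么样的|是不是|哪些|如何|什么|吗|呢|吧)", ""),        # trimmed pattern list
        (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
    ]
    otxt = txt
    for r, p in patts:
        txt = re.sub(r, p, txt, flags=re.IGNORECASE)
    if not txt:          # everything was interrogative words: keep the original
        txt = otxt
    return txt

print(rm_question_words("什么呢"))  # -> "什么呢" (stripping left nothing, original kept)
```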
def question(self, txt, tbl="qa", min_match: float = 0.6):
|
@staticmethod
|
||||||
|
def add_space_between_eng_zh(txt):
|
||||||
"""
|
"""
|
||||||
处理用户问题并生成全文检索表达式
|
在英文和中文之间添加空格。
|
||||||
|
|
||||||
|
该函数通过正则表达式匹配文本中英文和中文相邻的情况,并在它们之间插入空格。
|
||||||
|
这样做可以改善文本的可读性,特别是在混合使用英文和中文时。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
txt: 原始问题文本
|
txt (str): 需要处理的文本字符串。
|
||||||
tbl: 查询表名(默认"qa")
|
|
||||||
min_match: 最小匹配阈值(默认0.6)
|
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
MatchTextExpr: 全文检索表达式对象
|
str: 处理后的文本字符串,其中英文和中文之间添加了空格。
|
||||||
list: 提取的关键词列表
|
|
||||||
"""
|
"""
|
||||||
# 1. 文本预处理:去除特殊字符、繁体转简体、全角转半角、转小写
|
# (ENG/ENG+NUM) + ZH
|
||||||
|
txt = re.sub(r"([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)", r"\1 \2", txt)
|
||||||
|
# ENG + ZH
|
||||||
|
txt = re.sub(r"([A-Za-z])([\u4e00-\u9fa5]+)", r"\1 \2", txt)
|
||||||
|
# ZH + (ENG/ENG+NUM)
|
||||||
|
txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)", r"\1 \2", txt)
|
||||||
|
txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z])", r"\1 \2", txt)
|
||||||
|
return txt
|
||||||
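This helper is the core parsing change of the commit. A self-contained copy of the four substitutions, runnable on its own (the demo string is arbitrary):

```python
import re

def add_space_between_eng_zh(txt: str) -> str:
    # Put a space wherever Latin letters (optionally followed by digits)
    # directly touch CJK characters, so mixed "ENG中文" runs are no longer
    # swallowed as a single token.
    txt = re.sub(r"([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)", r"\1 \2", txt)  # ENG+NUM -> ZH
    txt = re.sub(r"([A-Za-z])([\u4e00-\u9fa5]+)", r"\1 \2", txt)         # ENG -> ZH
    txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)", r"\1 \2", txt)  # ZH -> ENG+NUM
    txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z])", r"\1 \2", txt)         # ZH -> ENG
    return txt

print(add_space_between_eng_zh("查询ES中的Unity3D文档"))
# -> 查询 ES 中的 Unity3D 文档
```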
|
|
||||||
|
def question(self, txt, tbl="qa", min_match: float = 0.6):
|
||||||
|
"""
|
||||||
|
根据输入的文本生成查询表达式,用于在数据库中匹配相关问题。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- txt (str): 输入的文本。
|
||||||
|
- tbl (str): 数据表名,默认为"qa"。
|
||||||
|
- min_match (float): 最小匹配度,默认为0.6。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- MatchTextExpr: 生成的查询表达式对象。
|
||||||
|
- keywords (list): 提取的关键词列表。
|
||||||
|
"""
|
||||||
|
txt = FulltextQueryer.add_space_between_eng_zh(txt) # 在英文和中文之间添加空格
|
||||||
|
# 使用正则表达式替换特殊字符为单个空格,并将文本转换为简体中文和小写
|
||||||
txt = re.sub(
|
txt = re.sub(
|
||||||
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
||||||
" ",
|
" ",
|
||||||
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
||||||
).strip()
|
).strip()
|
||||||
txt = FulltextQueryer.rmWWW(txt) # 去除停用词
|
otxt = txt
|
||||||
|
txt = FulltextQueryer.rmWWW(txt)
|
||||||
|
|
||||||
# 2. 非中文文本处理
|
# 如果文本不是中文,则进行英文处理
|
||||||
if not self.isChinese(txt):
|
if not self.isChinese(txt):
|
||||||
txt = FulltextQueryer.rmWWW(txt)
|
txt = FulltextQueryer.rmWWW(txt)
|
||||||
tks = rag_tokenizer.tokenize(txt).split()
|
tks = rag_tokenizer.tokenize(txt).split()
|
||||||
|
@ -106,11 +142,10 @@ class FulltextQueryer:
|
||||||
syn = self.syn.lookup(tk)
|
syn = self.syn.lookup(tk)
|
||||||
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
|
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
|
||||||
keywords.extend(syn)
|
keywords.extend(syn)
|
||||||
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
|
syn = ['"{}"^{:.4f}'.format(s, w / 4.0) for s in syn if s.strip()]
|
||||||
syns.append(" ".join(syn))
|
syns.append(" ".join(syn))
|
||||||
|
|
||||||
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
|
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if tk and not re.match(r"[.^+\(\)-]", tk)]
|
||||||
tk and not re.match(r"[.^+\(\)-]", tk)]
|
|
||||||
for i in range(1, len(tks_w)):
|
for i in range(1, len(tks_w)):
|
||||||
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
|
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
|
||||||
if not left or not right:
|
if not left or not right:
|
||||||
|
@ -126,48 +161,53 @@ class FulltextQueryer:
|
||||||
if not q:
|
if not q:
|
||||||
q.append(txt)
|
q.append(txt)
|
||||||
query = " ".join(q)
|
query = " ".join(q)
|
||||||
return MatchTextExpr(
|
return MatchTextExpr(self.query_fields, query, 100), keywords
|
||||||
self.query_fields, query, 100
|
|
||||||
), keywords
|
|
||||||
|
|
||||||
def need_fine_grained_tokenize(tk):
|
def need_fine_grained_tokenize(tk):
|
||||||
"""
|
"""
|
||||||
判断是否需要细粒度分词
|
判断是否需要对词进行细粒度分词。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
tk: 待判断的词条
|
- tk (str): 待判断的词。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
bool: True表示需要细粒度分词
|
- bool: 是否需要进行细粒度分词。
|
||||||
"""
|
"""
|
||||||
|
# 长度小于3的词不处理
|
||||||
if len(tk) < 3:
|
if len(tk) < 3:
|
||||||
return False
|
return False
|
||||||
|
# 匹配特定模式的词不处理(如数字、字母、符号组合)
|
||||||
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
|
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
txt = FulltextQueryer.rmWWW(txt) # 二次去除停用词
|
txt = FulltextQueryer.rmWWW(txt)
|
||||||
qs, keywords = [], [] # 初始化查询表达式和关键词列表
|
qs, keywords = [], []
|
||||||
# 3. 中文文本处理(最多处理256个词)
|
# 遍历文本分割后的前256个片段(防止处理过长文本)
|
||||||
for tt in self.tw.split(txt)[:256]: # .split():
|
for tt in self.tw.split(txt)[:256]: # 注:这个split似乎是对英文设计,中文不起作用
|
||||||
if not tt:
|
if not tt:
|
||||||
continue
|
continue
|
||||||
# 3.1 基础关键词收集
|
# 将当前片段加入关键词列表
|
||||||
keywords.append(tt)
|
keywords.append(tt)
|
||||||
twts = self.tw.weights([tt]) # 获取词权重
|
# 获取当前片段的权重
|
||||||
syns = self.syn.lookup(tt) # 查询同义词
|
twts = self.tw.weights([tt])
|
||||||
# 3.2 同义词扩展(最多扩展到32个关键词)
|
# 查找同义词
|
||||||
|
syns = self.syn.lookup(tt)
|
||||||
|
# 如果有同义词且关键词数量未超过32,将同义词加入关键词列表
|
||||||
if syns and len(keywords) < 32:
|
if syns and len(keywords) < 32:
|
||||||
keywords.extend(syns)
|
keywords.extend(syns)
|
||||||
|
# 调试日志:输出权重信息
|
||||||
logging.debug(json.dumps(twts, ensure_ascii=False))
|
logging.debug(json.dumps(twts, ensure_ascii=False))
|
||||||
|
# 初始化查询条件列表
|
||||||
tms = []
|
tms = []
|
||||||
# 3.3 处理每个词及其权重
|
# 按权重降序排序处理每个token
|
||||||
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
||||||
# 3.3.1 细粒度分词处理
|
# 如果需要细粒度分词,则进行分词处理
|
||||||
sm = (
|
sm = rag_tokenizer.fine_grained_tokenize(tk).split() if need_fine_grained_tokenize(tk) else []
|
||||||
rag_tokenizer.fine_grained_tokenize(tk).split()
|
# 对每个分词结果进行清洗:
|
||||||
if need_fine_grained_tokenize(tk)
|
# 1. 去除标点符号和特殊字符
|
||||||
else []
|
# 2. 使用subSpecialChar进一步处理
|
||||||
)
|
# 3. 过滤掉长度<=1的词
|
||||||
# 3.3.2 清洗分词结果
|
|
||||||
sm = [
|
sm = [
|
||||||
re.sub(
|
re.sub(
|
||||||
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
||||||
|
@ -178,59 +218,65 @@ class FulltextQueryer:
|
||||||
]
|
]
|
||||||
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
||||||
sm = [m for m in sm if len(m) > 1]
|
sm = [m for m in sm if len(m) > 1]
|
||||||
# 3.3.3 收集关键词(不超过32个)
|
|
||||||
if len(keywords) < 32:
|
|
||||||
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
|
||||||
keywords.extend(sm)
|
|
||||||
|
|
||||||
# 3.3.4 同义词处理
|
# 如果关键词数量未达上限,添加处理后的token和分词结果
|
||||||
|
if len(keywords) < 32:
|
||||||
|
keywords.append(re.sub(r"[ \\\"']+", "", tk)) # 去除转义字符
|
||||||
|
keywords.extend(sm) # 添加分词结果
|
||||||
|
# 获取当前token的同义词并进行处理
|
||||||
tk_syns = self.syn.lookup(tk)
|
tk_syns = self.syn.lookup(tk)
|
||||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||||
|
# 添加有效同义词到关键词列表
|
||||||
if len(keywords) < 32:
|
if len(keywords) < 32:
|
||||||
keywords.extend([s for s in tk_syns if s])
|
keywords.extend([s for s in tk_syns if s])
|
||||||
|
# 对同义词进行分词处理,并为包含空格的同义词添加引号
|
||||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
tk_syns = [f'"{s}"' if s.find(" ") > 0 else s for s in tk_syns]
|
||||||
# 关键词数量限制
|
|
||||||
|
# 关键词数量达到上限则停止处理
|
||||||
if len(keywords) >= 32:
|
if len(keywords) >= 32:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 3.3.5 构建查询表达式
|
# 处理当前token用于构建查询条件:
|
||||||
|
# 1. 特殊字符处理
|
||||||
|
# 2. 为包含空格的token添加引号
|
||||||
|
# 3. 如果有同义词,构建OR条件并降低权重
|
||||||
|
# 4. 如果有分词结果,添加OR条件
|
||||||
tk = FulltextQueryer.subSpecialChar(tk)
|
tk = FulltextQueryer.subSpecialChar(tk)
|
||||||
if tk.find(" ") > 0:
|
if tk.find(" ") > 0:
|
||||||
tk = '"%s"' % tk # 处理短语查询
|
tk = '"%s"' % tk
|
||||||
if tk_syns:
|
if tk_syns:
|
||||||
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns) # 添加同义词查询
|
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
|
||||||
if sm:
|
if sm:
|
||||||
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm)) # 添加细粒度分词查询
|
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
|
||||||
if tk.strip():
|
if tk.strip():
|
||||||
tms.append((tk, w)) # 保存带权重的查询表达式
|
tms.append((tk, w))
|
||||||
|
|
||||||
# 3.4 合并当前词的查询表达式
|
# 将处理后的查询条件按权重组合成字符串
|
||||||
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
||||||
|
|
||||||
# 3.5 添加相邻词组合查询(提升短语匹配权重)
|
# 如果有多个权重项,添加短语搜索条件(提高相邻词匹配的权重)
|
||||||
if len(twts) > 1:
|
if len(twts) > 1:
|
||||||
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
|
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
|
||||||
|
|
||||||
# 3.6 处理同义词查询表达式
|
# 处理同义词的查询条件
|
||||||
syns = " OR ".join(
|
syns = " OR ".join(['"%s"' % rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s)) for s in syns])
|
||||||
[
|
# 组合主查询条件和同义词条件
|
||||||
'"%s"'
|
|
||||||
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
|
|
||||||
for s in syns
|
|
||||||
]
|
|
||||||
)
|
|
||||||
if syns and tms:
|
if syns and tms:
|
||||||
tms = f"({tms})^5 OR ({syns})^0.7"
|
tms = f"({tms})^5 OR ({syns})^0.7"
|
||||||
|
# 将最终查询条件加入列表
|
||||||
|
qs.append(tms)
|
||||||
|
|
||||||
qs.append(tms) # 添加到最终查询列表
|
# 处理所有查询条件
|
||||||
|
|
||||||
# 4. 生成最终查询表达式
|
|
||||||
if qs:
|
if qs:
|
||||||
|
# 组合所有查询条件为OR关系
|
||||||
query = " OR ".join([f"({t})" for t in qs if t])
|
query = " OR ".join([f"({t})" for t in qs if t])
|
||||||
return MatchTextExpr(
|
# 如果查询条件为空,使用原始文本
|
||||||
self.query_fields, query, 100, {"minimum_should_match": min_match}
|
if not query:
|
||||||
), keywords
|
query = otxt
|
||||||
|
# 返回匹配文本表达式和关键词
|
||||||
|
return MatchTextExpr(self.query_fields, query, 100, {"minimum_should_match": min_match}), keywords
|
||||||
|
# 如果没有生成查询条件,只返回关键词
|
||||||
return None, keywords
|
return None, keywords
|
||||||
|
|
||||||
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
|
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
|
||||||
|
@ -282,7 +328,7 @@ class FulltextQueryer:
|
||||||
tk_syns = self.syn.lookup(tk)
|
tk_syns = self.syn.lookup(tk)
|
||||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
tk_syns = [f'"{s}"' if s.find(" ") > 0 else s for s in tk_syns]
|
||||||
tk = FulltextQueryer.subSpecialChar(tk)
|
tk = FulltextQueryer.subSpecialChar(tk)
|
||||||
if tk.find(" ") > 0:
|
if tk.find(" ") > 0:
|
||||||
tk = '"%s"' % tk
|
tk = '"%s"' % tk
|
||||||
|
@ -291,5 +337,4 @@ class FulltextQueryer:
|
||||||
if tk:
|
if tk:
|
||||||
keywords.append(f"{tk}^{w}")
|
keywords.append(f"{tk}^{w}")
|
||||||
|
|
||||||
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
|
return MatchTextExpr(self.query_fields, " ".join(keywords), 100, {"minimum_should_match": min(3, len(keywords) / 10)})
|
||||||
{"minimum_should_match": min(3, len(keywords) / 10)})
|
|
||||||
|
|
|
@ -22,9 +22,12 @@ import os
|
||||||
import re
|
import re
|
||||||
import string
|
import string
|
||||||
import sys
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
from hanziconv import HanziConv
|
from hanziconv import HanziConv
|
||||||
from nltk import word_tokenize
|
from nltk import word_tokenize
|
||||||
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,7 +41,7 @@ class RagTokenizer:
|
||||||
def loadDict_(self, fnm):
|
def loadDict_(self, fnm):
|
||||||
logging.info(f"[HUQIE]:Build trie from {fnm}")
|
logging.info(f"[HUQIE]:Build trie from {fnm}")
|
||||||
try:
|
try:
|
||||||
of = open(fnm, "r", encoding='utf-8')
|
of = open(fnm, "r", encoding="utf-8")
|
||||||
while True:
|
while True:
|
||||||
line = of.readline()
|
line = of.readline()
|
||||||
if not line:
|
if not line:
|
||||||
|
@ -46,7 +49,7 @@ class RagTokenizer:
|
||||||
line = re.sub(r"[\r\n]+", "", line)
|
line = re.sub(r"[\r\n]+", "", line)
|
||||||
line = re.split(r"[ \t]", line)
|
line = re.split(r"[ \t]", line)
|
||||||
k = self.key_(line[0])
|
k = self.key_(line[0])
|
||||||
F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
|
F = int(math.log(float(line[1]) / self.DENOMINATOR) + 0.5)
|
||||||
if k not in self.trie_ or self.trie_[k][0] < F:
|
if k not in self.trie_ or self.trie_[k][0] < F:
|
||||||
self.trie_[self.key_(line[0])] = (F, line[2])
|
self.trie_[self.key_(line[0])] = (F, line[2])
|
||||||
self.trie_[self.rkey_(line[0])] = 1
|
self.trie_[self.rkey_(line[0])] = 1
|
||||||
|
@ -106,8 +109,8 @@ class RagTokenizer:
|
||||||
if inside_code == 0x3000:
|
if inside_code == 0x3000:
|
||||||
inside_code = 0x0020
|
inside_code = 0x0020
|
||||||
else:
|
else:
|
||||||
inside_code -= 0xfee0
|
inside_code -= 0xFEE0
|
||||||
if inside_code < 0x0020 or inside_code > 0x7e: # After the conversion, if it's not a half-width character, return the original character.
|
if inside_code < 0x0020 or inside_code > 0x7E: # After the conversion, if it's not a half-width character, return the original character.
|
||||||
rstring += uchar
|
rstring += uchar
|
||||||
else:
|
else:
|
||||||
rstring += chr(inside_code)
|
rstring += chr(inside_code)
|
||||||
|
@ -127,11 +130,9 @@ class RagTokenizer:
|
||||||
S = s + 1
|
S = s + 1
|
||||||
if s + 2 <= len(chars):
|
if s + 2 <= len(chars):
|
||||||
t1, t2 = "".join(chars[s : s + 1]), "".join(chars[s : s + 2])
|
t1, t2 = "".join(chars[s : s + 1]), "".join(chars[s : s + 2])
|
||||||
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
|
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
|
||||||
self.key_(t2)):
|
|
||||||
S = s + 2
|
S = s + 2
|
||||||
if len(preTks) > 2 and len(
|
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
|
||||||
preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
|
|
||||||
t1 = preTks[-1][0] + "".join(chars[s : s + 1])
|
t1 = preTks[-1][0] + "".join(chars[s : s + 1])
|
||||||
if self.trie_.has_keys_with_prefix(self.key_(t1)):
|
if self.trie_.has_keys_with_prefix(self.key_(t1)):
|
||||||
S = s + 2
|
S = s + 2
|
||||||
|
@ -149,7 +150,7 @@ class RagTokenizer:
|
||||||
if k in self.trie_:
|
if k in self.trie_:
|
||||||
pretks.append((t, self.trie_[k]))
|
pretks.append((t, self.trie_[k]))
|
||||||
else:
|
else:
|
||||||
pretks.append((t, (-12, '')))
|
pretks.append((t, (-12, "")))
|
||||||
res = max(res, self.dfs_(chars, e, pretks, tkslist))
|
res = max(res, self.dfs_(chars, e, pretks, tkslist))
|
||||||
|
|
||||||
if res > s:
|
if res > s:
|
||||||
|
@ -160,7 +161,7 @@ class RagTokenizer:
|
||||||
if k in self.trie_:
|
if k in self.trie_:
|
||||||
preTks.append((t, self.trie_[k]))
|
preTks.append((t, self.trie_[k]))
|
||||||
else:
|
else:
|
||||||
preTks.append((t, (-12, '')))
|
preTks.append((t, (-12, "")))
|
||||||
|
|
||||||
return self.dfs_(chars, s + 1, preTks, tkslist)
|
return self.dfs_(chars, s + 1, preTks, tkslist)
|
||||||
|
|
||||||
|
@ -219,8 +220,7 @@ class RagTokenizer:
|
||||||
while s < len(line):
|
while s < len(line):
|
||||||
e = s + 1
|
e = s + 1
|
||||||
t = line[s:e]
|
t = line[s:e]
|
||||||
while e < len(line) and self.trie_.has_keys_with_prefix(
|
while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
|
||||||
self.key_(t)):
|
|
||||||
e += 1
|
e += 1
|
||||||
t = line[s:e]
|
t = line[s:e]
|
||||||
|
|
||||||
|
@ -231,7 +231,7 @@ class RagTokenizer:
|
||||||
if self.key_(t) in self.trie_:
|
if self.key_(t) in self.trie_:
|
||||||
res.append((t, self.trie_[self.key_(t)]))
|
res.append((t, self.trie_[self.key_(t)]))
|
||||||
else:
|
else:
|
||||||
res.append((t, (0, '')))
|
res.append((t, (0, "")))
|
||||||
|
|
||||||
s = e
|
s = e
|
||||||
|
|
||||||
|
@ -254,7 +254,7 @@ class RagTokenizer:
|
||||||
if self.key_(t) in self.trie_:
|
if self.key_(t) in self.trie_:
|
||||||
res.append((t, self.trie_[self.key_(t)]))
|
res.append((t, self.trie_[self.key_(t)]))
|
||||||
else:
|
else:
|
||||||
res.append((t, (0, '')))
|
res.append((t, (0, "")))
|
||||||
|
|
||||||
s -= 1
|
s -= 1
|
||||||
|
|
||||||
|
@ -297,8 +297,7 @@ class RagTokenizer:
|
||||||
if not lang:
|
if not lang:
|
||||||
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
|
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
|
||||||
continue
|
continue
|
||||||
if len(L) < 2 or re.match(
|
if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
|
||||||
r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
|
|
||||||
res.append(L)
|
res.append(L)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -359,31 +358,64 @@ class RagTokenizer:
|
||||||
return self.merge_(res)
|
return self.merge_(res)
|
||||||
|
|
||||||
def fine_grained_tokenize(self, tks):
|
def fine_grained_tokenize(self, tks):
|
||||||
|
"""
|
||||||
|
细粒度分词方法,根据文本特征(中英文比例、数字符号等)动态选择分词策略
|
||||||
|
|
||||||
|
参数:
|
||||||
|
tks (str): 待分词的文本字符串
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 分词后的结果(用空格连接的词序列)
|
||||||
|
|
||||||
|
处理逻辑:
|
||||||
|
1. 先按空格初步切分文本
|
||||||
|
2. 根据中文占比决定是否启用细粒度分词
|
||||||
|
3. 对特殊格式(短词、纯数字等)直接保留原样
|
||||||
|
4. 对长词或复杂词使用DFS回溯算法寻找最优切分
|
||||||
|
5. 对英文词进行额外校验和规范化处理
|
||||||
|
"""
|
||||||
|
# 初始切分:按空格分割输入文本
|
||||||
tks = tks.split()
|
tks = tks.split()
|
||||||
|
# 计算中文词占比(判断是否主要包含中文内容)
|
||||||
zh_num = len([1 for c in tks if c and is_chinese(c[0])])
|
zh_num = len([1 for c in tks if c and is_chinese(c[0])])
|
||||||
|
# 如果中文占比低于20%,则按简单规则处理(主要处理英文混合文本)
|
||||||
if zh_num < len(tks) * 0.2:
|
if zh_num < len(tks) * 0.2:
|
||||||
res = []
|
res = []
|
||||||
for tk in tks:
|
for tk in tks:
|
||||||
res.extend(tk.split("/"))
|
res.extend(tk.split("/"))
|
||||||
return " ".join(res)
|
return " ".join(res)
|
||||||
|
|
||||||
|
# 中文或复杂文本处理流程
|
||||||
res = []
|
res = []
|
||||||
for tk in tks:
|
for tk in tks:
|
||||||
|
# 规则1:跳过短词(长度<3)或纯数字/符号组合(如"3.14")
|
||||||
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
|
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
|
||||||
res.append(tk)
|
res.append(tk)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 初始化候选分词列表
|
||||||
tkslist = []
|
tkslist = []
|
||||||
|
|
||||||
|
# 规则2:超长词(长度>10)直接保留不切分
|
||||||
if len(tk) > 10:
|
if len(tk) > 10:
|
||||||
tkslist.append(tk)
|
tkslist.append(tk)
|
||||||
else:
|
else:
|
||||||
|
# 使用DFS回溯算法寻找所有可能的分词组合
|
||||||
self.dfs_(tk, 0, [], tkslist)
|
self.dfs_(tk, 0, [], tkslist)
|
||||||
|
|
||||||
|
# 规则3:若无有效切分方案则保留原词
|
||||||
if len(tkslist) < 2:
|
if len(tkslist) < 2:
|
||||||
res.append(tk)
|
res.append(tk)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 从候选方案中选择最优切分(通过sortTks_排序)
|
||||||
stk = self.sortTks_(tkslist)[1][0]
|
stk = self.sortTks_(tkslist)[1][0]
|
||||||
|
|
||||||
|
# 规则4:若切分结果与原词长度相同则视为无效切分
|
||||||
if len(stk) == len(tk):
|
if len(stk) == len(tk):
|
||||||
stk = tk
|
stk = tk
|
||||||
else:
|
else:
|
||||||
|
# 英文特殊处理:检查子词长度是否合法
|
||||||
if re.match(r"[a-z\.-]+$", tk):
|
if re.match(r"[a-z\.-]+$", tk):
|
||||||
for t in stk:
|
for t in stk:
|
||||||
if len(t) < 3:
|
if len(t) < 3:
|
||||||
|
@ -393,29 +425,28 @@ class RagTokenizer:
|
||||||
stk = " ".join(stk)
|
stk = " ".join(stk)
|
||||||
else:
|
else:
|
||||||
stk = " ".join(stk)
|
stk = " ".join(stk)
|
||||||
|
# 中文词直接拼接结果
|
||||||
res.append(stk)
|
res.append(stk)
|
||||||
|
|
||||||
return " ".join(self.english_normalize_(res))
|
return " ".join(self.english_normalize_(res))
|
||||||
|
|
||||||
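The 20 % gate mentioned in `fine_grained_tokenize`'s docstring, shown in isolation (the helper names here are illustrative):

```python
def _is_chinese_char(c: str) -> bool:
    return "\u4e00" <= c <= "\u9fa5"

def skip_fine_grained(tokens: list[str]) -> bool:
    # If fewer than 20% of the space-separated tokens start with a Chinese
    # character, the DFS-based re-splitting is skipped and tokens are only
    # split on "/" -- mirroring the branch at the top of fine_grained_tokenize.
    zh_num = sum(1 for t in tokens if t and _is_chinese_char(t[0]))
    return zh_num < len(tokens) * 0.2

print(skip_fine_grained("Unity3D developer c++ sql tableau 测试".split()))  # True (1 of 6 is Chinese)
```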
|
|
||||||
def is_chinese(s):
|
def is_chinese(s):
|
||||||
if s >= u'\u4e00' and s <= u'\u9fa5':
|
if s >= "\u4e00" and s <= "\u9fa5":
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_number(s):
|
def is_number(s):
|
||||||
if s >= u'\u0030' and s <= u'\u0039':
|
if s >= "\u0030" and s <= "\u0039":
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_alphabet(s):
|
def is_alphabet(s):
|
||||||
if (s >= u'\u0041' and s <= u'\u005a') or (
|
if (s >= "\u0041" and s <= "\u005a") or (s >= "\u0061" and s <= "\u007a"):
|
||||||
s >= u'\u0061' and s <= u'\u007a'):
|
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
@ -424,8 +455,7 @@ def is_alphabet(s):
|
||||||
def naiveQie(txt):
|
def naiveQie(txt):
|
||||||
tks = []
|
tks = []
|
||||||
for t in txt.split():
|
for t in txt.split():
|
||||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
|
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t):
|
||||||
) and re.match(r".*[a-zA-Z]$", t):
|
|
||||||
tks.append(" ")
|
tks.append(" ")
|
||||||
tks.append(t)
|
tks.append(t)
|
||||||
return tks
|
return tks
|
||||||
|
@ -441,43 +471,41 @@ addUserDict = tokenizer.addUserDict
|
||||||
tradi2simp = tokenizer._tradi2simp
|
tradi2simp = tokenizer._tradi2simp
|
||||||
strQ2B = tokenizer._strQ2B
|
strQ2B = tokenizer._strQ2B
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
tknzr = RagTokenizer(debug=True)
|
tknzr = RagTokenizer(debug=True)
|
||||||
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
|
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
|
||||||
|
tks = tknzr.tokenize("哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
|
||||||
|
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||||
|
print(tks)
|
||||||
tks = tknzr.tokenize(
|
tks = tknzr.tokenize(
|
||||||
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
|
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。"
|
||||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
)
|
||||||
tks = tknzr.tokenize(
|
print(tks)
|
||||||
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
|
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
# tks = tknzr.tokenize("多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
|
||||||
tks = tknzr.tokenize(
|
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||||
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
|
# tks = tknzr.tokenize("实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
|
||||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||||
tks = tknzr.tokenize(
|
# tks = tknzr.tokenize("虽然我不怎么玩")
|
||||||
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
|
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
# tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
|
||||||
tks = tknzr.tokenize("虽然我不怎么玩")
|
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
# tks = tknzr.tokenize("涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
|
||||||
tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
|
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
# tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
|
||||||
tks = tknzr.tokenize(
|
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||||
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
|
# tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
|
||||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||||
tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
|
# tks = tknzr.tokenize("数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
|
||||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||||
tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
|
# if len(sys.argv) < 2:
|
||||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
# sys.exit()
|
||||||
tks = tknzr.tokenize(
|
# tknzr.DEBUG = False
|
||||||
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
|
# tknzr.loadUserDict(sys.argv[1])
|
||||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
# of = open(sys.argv[2], "r")
|
||||||
if len(sys.argv) < 2:
|
# while True:
|
||||||
sys.exit()
|
# line = of.readline()
|
||||||
tknzr.DEBUG = False
|
# if not line:
|
||||||
tknzr.loadUserDict(sys.argv[1])
|
# break
|
||||||
of = open(sys.argv[2], "r")
|
# logging.info(tknzr.tokenize(line))
|
||||||
while True:
|
# of.close()
|
||||||
line = of.readline()
|
|
||||||
if not line:
|
|
||||||
break
|
|
||||||
logging.info(tknzr.tokenize(line))
|
|
||||||
of.close()
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#
|
#
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||||
|
@ -24,7 +25,8 @@ import numpy as np
|
||||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
|
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
|
||||||
|
|
||||||
|
|
||||||
def index_name(uid): return f"ragflow_{uid}"
|
def index_name(uid):
|
||||||
|
return f"ragflow_{uid}"
|
||||||
|
|
||||||
|
|
||||||
class Dealer:
|
class Dealer:
|
||||||
|
@ -47,11 +49,10 @@ class Dealer:
|
||||||
qv, _ = emb_mdl.encode_queries(txt)
|
qv, _ = emb_mdl.encode_queries(txt)
|
||||||
shape = np.array(qv).shape
|
shape = np.array(qv).shape
|
||||||
if len(shape) > 1:
|
if len(shape) > 1:
|
||||||
raise Exception(
|
raise Exception(f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
|
||||||
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
|
|
||||||
embedding_data = [float(v) for v in qv]
|
embedding_data = [float(v) for v in qv]
|
||||||
vector_column_name = f"q_{len(embedding_data)}_vec"
|
vector_column_name = f"q_{len(embedding_data)}_vec"
|
||||||
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
|
return MatchDenseExpr(vector_column_name, embedding_data, "float", "cosine", topk, {"similarity": similarity})
|
||||||
|
|
||||||
def get_filters(self, req):
|
def get_filters(self, req):
|
||||||
condition = dict()
|
condition = dict()
|
||||||
|
@ -64,12 +65,7 @@ class Dealer:
|
||||||
condition[key] = req[key]
|
condition[key] = req[key]
|
||||||
return condition
|
return condition
|
||||||
|
|
||||||
def search(self, req, idx_names: str | list[str],
|
def search(self, req, idx_names: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False, rank_feature: dict | None = None):
|
||||||
kb_ids: list[str],
|
|
||||||
emb_mdl=None,
|
|
||||||
highlight=False,
|
|
||||||
rank_feature: dict | None = None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
执行混合检索(全文检索+向量检索)
|
执行混合检索(全文检索+向量检索)
|
||||||
|
|
@@ -108,18 +104,37 @@ class Dealer:
         offset, limit = pg * ps, ps
 
         # 3. 设置返回字段(默认包含文档名、内容等核心字段)
-        src = req.get("fields",
-                      ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
-                       "doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd",
-                       "question_kwd", "question_tks",
-                       "available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD])
-        kwds = set([])
+        src = req.get(
+            "fields",
+            [
+                "docnm_kwd",
+                "content_ltks",
+                "kb_id",
+                "img_id",
+                "title_tks",
+                "important_kwd",
+                "position_int",
+                "doc_id",
+                "page_num_int",
+                "top_int",
+                "create_timestamp_flt",
+                "knowledge_graph_kwd",
+                "question_kwd",
+                "question_tks",
+                "available_int",
+                "content_with_weight",
+                PAGERANK_FLD,
+                TAG_FLD,
+            ],
+        )
+        kwds = set([])  # 初始化关键词集合
 
         # 4. 处理查询问题
-        qst = req.get("question", "")
-        q_vec = []
+        qst = req.get("question", "")  # 获取查询问题文本
+        print(f"收到前端问题:{qst}")
+        q_vec = []  # 初始化查询向量(如需向量检索)
         if not qst:
-            # 4.1 无查询文本时的处理(按文档排序)
+            # 4.1 若查询文本为空,执行默认排序检索(通常用于无搜索条件浏览)(注:前端测试检索时会禁止空文本的提交)
             if req.get("sort"):
                 orderBy.asc("page_num_int")
                 orderBy.asc("top_int")
@@ -128,16 +143,16 @@ class Dealer:
             total = self.dataStore.getTotal(res)
             logging.debug("Dealer.search TOTAL: {}".format(total))
         else:
-            # 4.2 有查询文本时的处理
-            highlightFields = ["content_ltks", "title_tks"] if highlight else []
+            # 4.2 若存在查询文本,进入全文/混合检索流程
+            highlightFields = ["content_ltks", "title_tks"] if highlight else []  # highlight当前会一直为False,不起作用
 
             # 4.2.1 生成全文检索表达式和关键词
             matchText, keywords = self.qryr.question(qst, min_match=0.3)
+            print(f"matchText.matching_text: {matchText.matching_text}")
+            print(f"keywords: {keywords}\n")
             if emb_mdl is None:
-                # 4.2.2 纯全文检索模式
+                # 4.2.2 纯全文检索模式 (未提供向量模型,正常情况不会进入)
                 matchExprs = [matchText]
-                res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
-                                            idx_names, kb_ids, rank_feature=rank_feature)
+                res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
                 total = self.dataStore.getTotal(res)
                 logging.debug("Dealer.search TOTAL: {}".format(total))
             else:
@@ -145,27 +160,41 @@ class Dealer:
                 # 生成查询向量
                 matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
                 q_vec = matchDense.embedding_data
+                # 在返回字段中加入查询向量字段
                 src.append(f"q_{len(q_vec)}_vec")
-                # 设置混合检索权重(全文5% + 向量95%)
+                # 创建融合表达式:设置向量匹配为95%,全文为5%(可以调整权重)
                 fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
+                # 构建混合查询表达式
                 matchExprs = [matchText, matchDense, fusionExpr]
 
                 # 执行混合检索
-                res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
-                                            idx_names, kb_ids, rank_feature=rank_feature)
+                res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
                 total = self.dataStore.getTotal(res)
                 logging.debug("Dealer.search TOTAL: {}".format(total))
 
-                # If result is empty, try again with lower min_match
+                print(f"共查询到: {total} 条信息")
+                # print(f"查询信息结果: {res}\n")
+
+                # 若未找到结果,则尝试降低匹配门槛后重试
                 if total == 0:
-                    matchText, _ = self.qryr.question(qst, min_match=0.1)
-                    filters.pop("doc_ids", None)
-                    matchDense.extra_options["similarity"] = 0.17
-                    res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
-                                                orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
-                    total = self.dataStore.getTotal(res)
-                    logging.debug("Dealer.search 2 TOTAL: {}".format(total))
+                    if filters.get("doc_id"):
+                        # 有特定文档ID时执行无条件查询
+                        res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
+                        total = self.dataStore.getTotal(res)
+                        print(f"针对选中文档,共查询到: {total} 条信息")
+                        # print(f"查询信息结果: {res}\n")
+                    else:
+                        # 否则调整全文和向量匹配参数再次搜索
+                        matchText, _ = self.qryr.question(qst, min_match=0.1)
+                        filters.pop("doc_id", None)
+                        matchDense.extra_options["similarity"] = 0.17
+                        res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
+                        total = self.dataStore.getTotal(res)
+                        logging.debug("Dealer.search 2 TOTAL: {}".format(total))
+                        print(f"再次查询,共查询到: {total} 条信息")
+                        # print(f"查询信息结果: {res}\n")
 
+            # 4.3 处理关键词(对关键词进行更细粒度的切词)
             for k in keywords:
                 kwds.add(k)
                 for kk in rag_tokenizer.fine_grained_tokenize(k).split():
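The `FusionExpr("weighted_sum", ..., {"weights": "0.05, 0.95"})` above blends the full-text and vector scores 5%/95%. A toy illustration of that weighting (the real fusion runs inside the document store; the scores here are assumed to be normalized to [0, 1]):

```python
# Minimal sketch of "weighted_sum" fusion with weights "0.05, 0.95":
# the final score is dominated by vector similarity, with a small
# contribution from the full-text (BM25-style) score.
def weighted_sum(text_score: float, vector_score: float,
                 w_text: float = 0.05, w_vec: float = 0.95) -> float:
    return w_text * text_score + w_vec * vector_score

print(weighted_sum(0.8, 0.2))  # 0.23 -> a strong text match barely helps
print(weighted_sum(0.2, 0.8))  # 0.77 -> vector similarity dominates
```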
@@ -175,27 +204,23 @@ class Dealer:
                     continue
                 kwds.add(kk)
 
+        # 5. 提取检索结果中的ID、字段、聚合和高亮信息
         logging.debug(f"TOTAL: {total}")
-        ids = self.dataStore.getChunkIds(res)
-        keywords = list(kwds)
-        highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
-        aggs = self.dataStore.getAggregation(res, "docnm_kwd")
-        return self.SearchResult(
-            total=total,
-            ids=ids,
-            query_vector=q_vec,
-            aggregation=aggs,
-            highlight=highlight,
-            field=self.dataStore.getFields(res, src),
-            keywords=keywords
-        )
+        ids = self.dataStore.getChunkIds(res)  # 提取匹配chunk的ID
+        keywords = list(kwds)  # 转为列表格式返回
+        highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")  # 获取高亮内容
+        aggs = self.dataStore.getAggregation(res, "docnm_kwd")  # 执行基于文档名的聚合分析
+        print(f"ids:{ids}")
+        print(f"keywords:{keywords}")
+        print(f"highlight:{highlight}")
+        print(f"aggs:{aggs}")
+        return self.SearchResult(total=total, ids=ids, query_vector=q_vec, aggregation=aggs, highlight=highlight, field=self.dataStore.getFields(res, src), keywords=keywords)
 
     @staticmethod
     def trans2floats(txt):
         return [float(t) for t in txt.split("\t")]
 
-    def insert_citations(self, answer, chunks, chunk_v,
-                         embd_mdl, tkweight=0.1, vtweight=0.9):
+    def insert_citations(self, answer, chunks, chunk_v, embd_mdl, tkweight=0.1, vtweight=0.9):
         assert len(chunks) == len(chunk_v)
         if not chunks:
             return answer, set([])
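The keyword expansion above is what the recall-time keyword check relies on: each analyzer keyword is kept, and its finer-grained sub-tokens (length ≥ 2) are added to `kwds` as well, so either form can match later. A hedged sketch of that expansion; `fake_fine_grained_tokenize()` is an assumed stand-in for `rag_tokenizer.fine_grained_tokenize()`:

```python
# Illustrative keyword expansion: whole keywords plus their finer-grained
# sub-tokens (length >= 2) end up in one set used for keyword matching.
def fake_fine_grained_tokenize(token: str) -> str:
    table = {"深度学习": "深度 学习", "知识库": "知识 库"}  # assumed splits
    return table.get(token, token)

kwds = set()
for k in ["深度学习", "知识库"]:
    kwds.add(k)
    for kk in fake_fine_grained_tokenize(k).split():
        if len(kk) < 2 or kk in kwds:
            continue
        kwds.add(kk)

print(sorted(kwds))  # whole keywords plus sub-tokens 深度 / 学习 / 知识
```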
@@ -213,10 +238,7 @@ class Dealer:
                         i += 1
                     pieces_.append("".join(pieces[st:i]) + "\n")
                 else:
-                    pieces_.extend(
-                        re.split(
-                            r"([^\|][;。?!!\n]|[a-z][.?;!][ \n])",
-                            pieces[i]))
+                    pieces_.extend(re.split(r"([^\|][;。?!!\n]|[a-z][.?;!][ \n])", pieces[i]))
                     i += 1
             pieces = pieces_
         else:
@@ -242,27 +264,19 @@ class Dealer:
                 chunk_v[i] = [0.0] * len(ans_v[0])
                 logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
 
-        assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
-            len(ans_v[0]), len(chunk_v[0]))
+        assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[0]))
 
-        chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
-                      for ck in chunks]
+        chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split() for ck in chunks]
         cites = {}
         thr = 0.63
         while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
             for i, a in enumerate(pieces_):
-                sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
-                                                                chunk_v,
-                                                                rag_tokenizer.tokenize(
-                                                                    self.qryr.rmWWW(pieces_[i])).split(),
-                                                                chunks_tks,
-                                                                tkweight, vtweight)
+                sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], chunk_v, rag_tokenizer.tokenize(self.qryr.rmWWW(pieces_[i])).split(), chunks_tks, tkweight, vtweight)
                 mx = np.max(sim) * 0.99
                 logging.debug("{} SIM: {}".format(pieces_[i], mx))
                 if mx < thr:
                     continue
-                cites[idx[i]] = list(
-                    set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
+                cites[idx[i]] = list(set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
             thr *= 0.8
 
         res = ""
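`insert_citations` accepts a chunk as a citation when a hybrid score (token weight 0.1, vector weight 0.9) clears a threshold that starts at 0.63 and is relaxed by a factor of 0.8 when nothing matches. A self-contained sketch of that scoring idea; the real `hybrid_similarity` lives in the query analyzer (`self.qryr`) and differs in detail:

```python
import numpy as np

# Toy hybrid score: 0.1 * token overlap + 0.9 * cosine similarity,
# compared against the 0.63 acceptance threshold used above.
def cosine(a, b):
    a, b = np.array(a, dtype=float), np.array(b, dtype=float)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9))

def hybrid(ans_vec, chunk_vec, ans_tokens, chunk_tokens, tkweight=0.1, vtweight=0.9):
    overlap = len(set(ans_tokens) & set(chunk_tokens)) / max(1, len(set(ans_tokens)))
    return tkweight * overlap + vtweight * cosine(ans_vec, chunk_vec)

score = hybrid([1.0, 0.0], [0.9, 0.1], ["rag", "检索"], ["rag", "流程"])
thr = 0.63
print(round(score, 3), score >= thr)  # cited only if the threshold is cleared
```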
@@ -305,12 +319,9 @@ class Dealer:
                 rank_fea.append(0)
             else:
                 rank_fea.append(nor / np.sqrt(denor) / q_denor)
-        return np.array(rank_fea)*10. + pageranks
+        return np.array(rank_fea) * 10.0 + pageranks
 
-    def rerank(self, sres, query, tkweight=0.3,
-               vtweight=0.7, cfield="content_ltks",
-               rank_feature: dict | None = None
-               ):
+    def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", rank_feature: dict | None = None):
         _, keywords = self.qryr.question(query)
         vector_size = len(sres.query_vector)
         vector_column = f"q_{vector_size}_vec"
@@ -339,16 +350,11 @@ class Dealer:
         ## For rank feature(tag_fea) scores.
         rank_fea = self._rank_feature_scores(rank_feature, sres)
 
-        sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
-                                                        ins_embd,
-                                                        keywords,
-                                                        ins_tw, tkweight, vtweight)
+        sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector, ins_embd, keywords, ins_tw, tkweight, vtweight)
 
         return sim + rank_fea, tksim, vtsim
 
-    def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
-                        vtweight=0.7, cfield="content_ltks",
-                        rank_feature: dict | None = None):
+    def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", rank_feature: dict | None = None):
         _, keywords = self.qryr.question(query)
 
         for i in sres.ids:
@@ -370,15 +376,25 @@ class Dealer:
         return tkweight * (np.array(tksim) + rank_fea) + vtweight * vtsim, tksim, vtsim
 
     def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
-        return self.qryr.hybrid_similarity(ans_embd,
-                                           ins_embd,
-                                           rag_tokenizer.tokenize(ans).split(),
-                                           rag_tokenizer.tokenize(inst).split())
+        return self.qryr.hybrid_similarity(ans_embd, ins_embd, rag_tokenizer.tokenize(ans).split(), rag_tokenizer.tokenize(inst).split())
 
-    def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
-                  vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
-                  rerank_mdl=None, highlight=False,
-                  rank_feature: dict | None = {PAGERANK_FLD: 10}):
+    def retrieval(
+        self,
+        question,
+        embd_mdl,
+        tenant_ids,
+        kb_ids,
+        page,
+        page_size,
+        similarity_threshold=0.2,
+        vector_similarity_weight=0.3,
+        top=1024,
+        doc_ids=None,
+        aggs=True,
+        rerank_mdl=None,
+        highlight=False,
+        rank_feature: dict | None = {PAGERANK_FLD: 10},
+    ):
         """
         执行检索操作,根据问题查询相关文档片段
 
@@ -406,59 +422,49 @@ class Dealer:
         if not question:
             return ranks
         # 设置重排序页面限制
-        RERANK_PAGE_LIMIT = 3
+        RERANK_LIMIT = 64
+        RERANK_LIMIT = int(RERANK_LIMIT // page_size + ((RERANK_LIMIT % page_size) / (page_size * 1.0) + 0.5)) * page_size if page_size > 1 else 1
+        if RERANK_LIMIT < 1:
+            RERANK_LIMIT = 1
         # 构建检索请求参数
-        req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size * RERANK_PAGE_LIMIT, 128),
-               "question": question, "vector": True, "topk": top,
-               "similarity": similarity_threshold,
-               "available_int": 1}
-        # 如果页码超过重排序限制,直接请求指定页的数据
-        if page > RERANK_PAGE_LIMIT:
-            req["page"] = page
-            req["size"] = page_size
+        req = {
+            "kb_ids": kb_ids,
+            "doc_ids": doc_ids,
+            "page": math.ceil(page_size * page / RERANK_LIMIT),
+            "size": RERANK_LIMIT,
+            "question": question,
+            "vector": True,
+            "topk": top,
+            "similarity": similarity_threshold,
+            "available_int": 1,
+        }
 
         # 处理租户ID格式
         if isinstance(tenant_ids, str):
             tenant_ids = tenant_ids.split(",")
 
         # 执行搜索操作
-        sres = self.search(req, [index_name(tid) for tid in tenant_ids],
-                           kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
-        ranks["total"] = sres.total
+        sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
 
-        # 根据页码决定是否需要重排序
-        if page <= RERANK_PAGE_LIMIT:
-            # 前几页需要重排序以提高结果质量
-            if rerank_mdl and sres.total > 0:
-                # 使用重排序模型进行重排序
-                sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
-                                                       sres, question, 1 - vector_similarity_weight,
-                                                       vector_similarity_weight,
-                                                       rank_feature=rank_feature)
-            else:
-                # 使用默认方法进行重排序
-                sim, tsim, vsim = self.rerank(
-                    sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
-                    rank_feature=rank_feature)
-            # 根据相似度降序排序,并选择当前页的结果
-            idx = np.argsort(sim * -1)[(page - 1) * page_size : page * page_size]
-        else:
-            # 后续页面不需要重排序,直接使用搜索结果
-            sim = tsim = vsim = [1] * len(sres.ids)
-            idx = list(range(len(sres.ids)))
+        if rerank_mdl and sres.total > 0:
+            sim, tsim, vsim = self.rerank_by_model(rerank_mdl, sres, question, 1 - vector_similarity_weight, vector_similarity_weight, rank_feature=rank_feature)
+        else:
+            sim, tsim, vsim = self.rerank(sres, question, 1 - vector_similarity_weight, vector_similarity_weight, rank_feature=rank_feature)
+        # Already paginated in search function
+        idx = np.argsort(sim * -1)[(page - 1) * page_size : page * page_size]
 
-        # 获取向量维度和列名
         dim = len(sres.query_vector)
         vector_column = f"q_{dim}_vec"
         zero_vector = [0.0] * dim
-
-        # 处理每个检索结果
+        if doc_ids:
+            similarity_threshold = 0
+            page_size = 30
+        sim_np = np.array(sim)
+        filtered_count = (sim_np >= similarity_threshold).sum()
+        ranks["total"] = int(filtered_count)  # Convert from np.int64 to Python int otherwise JSON serializable error
         for i in idx:
-            # 过滤低于阈值的结果
             if sim[i] < similarity_threshold:
                 break
-            # 控制返回结果数量
             if len(ranks["chunks"]) >= page_size:
                 if aggs:
                     continue
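A worked example of the new `RERANK_LIMIT` arithmetic: 64 candidate chunks are rounded to the nearest multiple of `page_size`, and the backend page is chosen so the requested UI page falls inside that candidate window. The values below are assumed for illustration:

```python
import math

# Mirrors the RERANK_LIMIT rounding and backend-page selection above.
def rerank_limit(page_size: int, base: int = 64) -> int:
    if page_size <= 1:
        return 1
    limit = int(base // page_size + ((base % page_size) / (page_size * 1.0) + 0.5)) * page_size
    return max(limit, 1)

for page_size in (8, 10, 30):
    limit = rerank_limit(page_size)
    backend_page = math.ceil(page_size * 1 / limit)  # UI page 1
    print(page_size, limit, backend_page)
# page_size=8  -> limit 64, backend page 1
# page_size=10 -> limit 60, backend page 1
# page_size=30 -> limit 60, backend page 1
```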
@@ -468,7 +474,6 @@ class Dealer:
             dnm = chunk.get("docnm_kwd", "")
             did = chunk.get("doc_id", "")
             position_int = chunk.get("position_int", [])
-            # 构建结果字典
             d = {
                 "chunk_id": id,
                 "content_ltks": chunk["content_ltks"],
@@ -483,9 +488,8 @@ class Dealer:
                 "term_similarity": tsim[i],
                 "vector": chunk.get(vector_column, zero_vector),
                 "positions": position_int,
+                "doc_type_kwd": chunk.get("doc_type_kwd", ""),
             }
-
-            # 处理高亮内容
             if highlight and sres.highlight:
                 if id in sres.highlight:
                     d["highlight"] = rmSpace(sres.highlight[id])
@@ -495,12 +499,7 @@ class Dealer:
             if dnm not in ranks["doc_aggs"]:
                 ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
             ranks["doc_aggs"][dnm]["count"] += 1
-        # 将文档聚合信息转换为列表格式,并按计数降序排序
-        ranks["doc_aggs"] = [{"doc_name": k,
-                              "doc_id": v["doc_id"],
-                              "count": v["count"]} for k,
-                             v in sorted(ranks["doc_aggs"].items(),
-                                         key=lambda x: x[1]["count"] * -1)]
+        ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k, v in sorted(ranks["doc_aggs"].items(), key=lambda x: x[1]["count"] * -1)]
         ranks["chunks"] = ranks["chunks"][:page_size]
 
         return ranks
@@ -509,16 +508,12 @@ class Dealer:
         tbl = self.dataStore.sql(sql, fetch_size, format)
         return tbl
 
-    def chunk_list(self, doc_id: str, tenant_id: str,
-                   kb_ids: list[str], max_count=1024,
-                   offset=0,
-                   fields=["docnm_kwd", "content_with_weight", "img_id"]):
+    def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, offset=0, fields=["docnm_kwd", "content_with_weight", "img_id"]):
         condition = {"doc_id": doc_id}
         res = []
         bs = 128
         for p in range(offset, max_count, bs):
-            es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
-                                           kb_ids)
+            es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id), kb_ids)
             dict_chunks = self.dataStore.getFields(es_res, fields)
             for id, doc in dict_chunks.items():
                 doc["id"] = id
@@ -548,8 +543,7 @@ class Dealer:
         if not aggs:
             return False
         cnt = np.sum([c for _, c in aggs])
-        tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
-                         key=lambda x: x[1] * -1)[:topn_tags]
+        tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], key=lambda x: x[1] * -1)[:topn_tags]
         doc[TAG_FLD] = {a: c for a, c in tag_fea if c > 0}
         return True
 
@@ -564,6 +558,5 @@ class Dealer:
         if not aggs:
             return {}
         cnt = np.sum([c for _, c in aggs])
-        tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
-                         key=lambda x: x[1] * -1)[:topn_tags]
-        return {a: max(1, c) for a, c in tag_fea}
+        tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], key=lambda x: x[1] * -1)[:topn_tags]
+        return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
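The tag-feature score used in the two hunks above is `0.1 * (c + 1) / (cnt + S) / all_tags[tag]`, rounded and truncated to the top N. A small worked example with assumed counts:

```python
import numpy as np

# Worked example of the tag-feature score; every count and prior below is
# assumed. Tags that are rare corpus-wide (small all_tags value) get boosted
# relative to very common ones.
S = 1000
all_tags = {"finance": 0.02, "hr": 0.0005}   # assumed corpus-level tag priors
aggs = [("finance", 40), ("hr", 10)]         # assumed tag counts for one result set
cnt = np.sum([c for _, c in aggs])

tag_fea = sorted(
    [(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
    key=lambda x: x[1] * -1,
)[:2]
print(tag_fea)  # the rarer "hr" tag ends up with the larger score
```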
@@ -25,20 +25,18 @@ from api.utils.file_utils import get_project_base_directory
 
 class Dealer:
     def __init__(self, redis=None):
-
         self.lookup_num = 100000000
         self.load_tm = time.time() - 1000000
         self.dictionary = None
         path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
         try:
-            self.dictionary = json.load(open(path, 'r'))
+            self.dictionary = json.load(open(path, "r"))
         except Exception:
             logging.warning("Missing synonym.json")
             self.dictionary = {}
 
         if not redis:
-            logging.warning(
-                "Realtime synonym is disabled, since no redis connection.")
+            logging.warning("Realtime synonym is disabled, since no redis connection.")
         if not len(self.dictionary.keys()):
             logging.warning("Fail to load synonym")
 
@@ -67,18 +65,36 @@ class Dealer:
             logging.error("Fail to load synonym!" + str(e))
 
     def lookup(self, tk, topn=8):
+        """
+        查找输入词条(tk)的同义词,支持英文和中文混合处理
+
+        参数:
+            tk (str): 待查询的词条(如"happy"或"苹果")
+            topn (int): 最多返回的同义词数量,默认为8
+
+        返回:
+            list: 同义词列表,可能为空(无同义词时)
+
+        处理逻辑:
+            1. 英文单词:使用WordNet语义网络查询
+            2. 中文/其他:从预加载的自定义词典查询
+        """
+        # 英文单词处理分支
         if re.match(r"[a-z]+$", tk):
             res = list(set([re.sub("_", " ", syn.name().split(".")[0]) for syn in wordnet.synsets(tk)]) - set([tk]))
             return [t for t in res if t]
 
+        # 中文/其他词条处理
         self.lookup_num += 1
-        self.load()
+        self.load()  # 自定义词典
+        # 从字典获取同义词,默认返回空列表
         res = self.dictionary.get(re.sub(r"[ \t]+", " ", tk.lower()), [])
+        # 兼容处理:如果字典值是字符串,转为单元素列表
        if isinstance(res, str):
             res = [res]
         return res[:topn]
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     dl = Dealer()
     print(dl.dictionary)
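A usage sketch of the two lookup branches documented above (this appears to be the synonym `Dealer` in rag/nlp): WordNet for lowercase English tokens, the JSON dictionary for everything else. The `custom_dict` entries are invented; the real data comes from rag/res/synonym.json, and the WordNet branch requires `nltk` with the `wordnet` corpus installed:

```python
import re
from nltk.corpus import wordnet

# English branch: collect synset base names, drop the query token itself.
def english_synonyms(tk: str, topn: int = 8) -> list[str]:
    res = set(re.sub("_", " ", syn.name().split(".")[0]) for syn in wordnet.synsets(tk)) - {tk}
    return [t for t in res if t][:topn]

# Dictionary branch: assumed entries standing in for rag/res/synonym.json.
custom_dict = {"苹果": ["apple", "iphone"]}

def chinese_synonyms(tk: str, topn: int = 8) -> list[str]:
    res = custom_dict.get(re.sub(r"[ \t]+", " ", tk.lower()), [])
    return ([res] if isinstance(res, str) else res)[:topn]

print(english_synonyms("happy"))
print(chinese_synonyms("苹果"))
```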
@@ -26,7 +26,9 @@ from api.utils.file_utils import get_project_base_directory
 
 class Dealer:
     def __init__(self):
-        self.stop_words = set(["请问",
+        self.stop_words = set(
+            [
+                "请问",
                 "您",
                 "你",
                 "我",
@@ -56,7 +58,9 @@ class Dealer:
                 "哪个",
                 "哪些",
                 "啥",
-                "相关"])
+                "相关",
+            ]
+        )
 
         def load_dict(fnm):
             res = {}
@@ -90,19 +94,15 @@ class Dealer:
             logging.warning("Load term.freq FAIL!")
 
     def pretoken(self, txt, num=False, stpwd=True):
-        patt = [
-            r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
-        ]
-        rewt = [
-        ]
+        patt = [r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"]
+        rewt = []
         for p, r in rewt:
             txt = re.sub(p, r, txt)
 
         res = []
         for t in rag_tokenizer.tokenize(txt).split():
             tk = t
-            if (stpwd and tk in self.stop_words) or (
-                    re.match(r"[0-9]$", tk) and not num):
+            if (stpwd and tk in self.stop_words) or (re.match(r"[0-9]$", tk) and not num):
                 continue
             for p in patt:
                 if re.match(p, t):
@@ -114,19 +114,18 @@ class Dealer:
         return res
 
     def tokenMerge(self, tks):
-        def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
+        def oneTerm(t):
+            return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
 
         res, i = [], 0
         while i < len(tks):
             j = i
-            if i == 0 and oneTerm(tks[i]) and len(
-                    tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])):  # 多 工位
+            if i == 0 and oneTerm(tks[i]) and len(tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])):  # 多 工位
                 res.append(" ".join(tks[0:2]))
                 i = 2
                 continue
 
-            while j < len(
-                    tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
+            while j < len(tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
                 j += 1
             if j - i > 1:
                 if j - i < 5:
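A simplified sketch of the `tokenMerge` idea: runs of very short tokens (single characters, or 1–2 character alphanumerics) are merged into one phrase before weighting. The stop-word check and the 多/工位 special case from the real method are omitted here:

```python
import re

# Simplified token merging: consecutive "one-term" tokens become one phrase.
def one_term(t: str) -> bool:
    return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t) is not None

def token_merge(tks: list[str]) -> list[str]:
    res, i = [], 0
    while i < len(tks):
        j = i
        while j < len(tks) and one_term(tks[j]):
            j += 1
        if j - i > 1:
            res.append(" ".join(tks[i:j]))
            i = j
        else:
            res.append(tks[i])
            i += 1
    return res

print(token_merge(["多", "工", "位", "数控", "机床"]))  # ['多 工 位', '数控', '机床']
```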
@@ -159,9 +158,7 @@ class Dealer:
         """
         tks = []
         for t in re.sub(r"[ \t]+", " ", txt).split():
-            if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
-                    re.match(r".*[a-zA-Z]$", t) and tks and \
-                    self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
+            if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t) and tks and self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
                 tks[-1] = tks[-1] + " " + t
             else:
                 tks.append(t)
@@ -180,8 +177,7 @@ class Dealer:
                 return 0.01
             if not self.ne or t not in self.ne:
                 return 1
-            m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
-                 "firstnm": 1}
+            m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, "firstnm": 1}
             return m[self.ne[t]]
 
         def postag(t):
@@ -208,7 +204,7 @@ class Dealer:
             if not s and len(t) >= 4:
                 s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
                 if len(s) > 1:
-                    s = np.min([freq(tt) for tt in s]) / 6.
+                    s = np.min([freq(tt) for tt in s]) / 6.0
                 else:
                     s = 0
 
@@ -224,18 +220,18 @@ class Dealer:
             elif len(t) >= 4:
                 s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
                 if len(s) > 1:
-                    return max(3, np.min([df(tt) for tt in s]) / 6.)
+                    return max(3, np.min([df(tt) for tt in s]) / 6.0)
 
             return 3
 
-        def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
+        def idf(s, N):
+            return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
 
         tw = []
         if not preprocess:
             idf1 = np.array([idf(freq(t), 10000000) for t in tks])
             idf2 = np.array([idf(df(t), 1000000000) for t in tks])
-            wts = (0.3 * idf1 + 0.7 * idf2) * \
-                np.array([ner(t) * postag(t) for t in tks])
+            wts = (0.3 * idf1 + 0.7 * idf2) * np.array([ner(t) * postag(t) for t in tks])
             wts = [s for s in wts]
             tw = list(zip(tks, wts))
         else:
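A worked example of the weighting above: `idf` smooths an inverse document frequency, and the final term weight blends two idf estimates (0.3/0.7) scaled by the NER and POS factors. All counts below are assumed for illustration:

```python
import math

# idf() and the 0.3/0.7 blend copied from the code above; the corpus counts
# and the ner/postag factors are made up.
def idf(s, N):
    return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

freq_t, df_t = 5000, 20000   # assumed term frequency / document frequency
ner_t, postag_t = 3, 1.0     # e.g. a "corp" entity with a neutral POS factor

idf1 = idf(freq_t, 10000000)
idf2 = idf(df_t, 1000000000)
weight = (0.3 * idf1 + 0.7 * idf2) * ner_t * postag_t
print(round(idf1, 3), round(idf2, 3), round(weight, 3))
```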
@@ -243,8 +239,7 @@ class Dealer:
                 tt = self.tokenMerge(self.pretoken(tk, True))
                 idf1 = np.array([idf(freq(t), 10000000) for t in tt])
                 idf2 = np.array([idf(df(t), 1000000000) for t in tt])
-                wts = (0.3 * idf1 + 0.7 * idf2) * \
-                    np.array([ner(t) * postag(t) for t in tt])
+                wts = (0.3 * idf1 + 0.7 * idf2) * np.array([ner(t) * postag(t) for t in tt])
                 wts = [s for s in wts]
                 tw.extend(zip(tt, wts))
 
@@ -42,6 +42,10 @@ def chunks_format(reference):
             "image_id": get_value(chunk, "image_id", "img_id"),
             "positions": get_value(chunk, "positions", "position_int"),
             "url": chunk.get("url"),
+            "similarity": chunk.get("similarity"),
+            "vector_similarity": chunk.get("vector_similarity"),
+            "term_similarity": chunk.get("term_similarity"),
+            "doc_type": chunk.get("doc_type_kwd"),
         }
         for chunk in reference.get("chunks", [])
     ]
@@ -145,15 +149,17 @@ def kb_prompt(kbinfos, max_tokens):
 
     doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
     for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
-        doc2chunks[ck["docnm_kwd"]]["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + f"ID: {i}\n" + ck["content_with_weight"])
+        cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
+        cnt += ck["content_with_weight"]
+        doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
         doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
 
     knowledges = []
     for nm, cks_meta in doc2chunks.items():
-        txt = f"\nDocument: {nm} \n"
+        txt = f"\n文档: {nm} \n"
         for k, v in cks_meta["meta"].items():
             txt += f"{k}: {v}\n"
-        txt += "Relevant fragments as following:\n"
+        txt += "相关片段如下:\n"
         for i, chunk in enumerate(cks_meta["chunks"], 1):
             txt += f"{chunk}\n"
         knowledges.append(txt)
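A sketch of the prompt block the reworked `kb_prompt` now emits: each chunk is prefixed with `---\nID: <i>\n` (plus an optional URL line) and grouped under a `文档:` header. The sample chunks below are invented and the metadata handling is omitted:

```python
from collections import defaultdict

# Assumed sample chunks; the real ones come from the retrieval result.
chunks = [
    {"docnm_kwd": "manual.pdf", "content_with_weight": "安装步骤……", "doc_id": "d1"},
    {"docnm_kwd": "manual.pdf", "content_with_weight": "常见问题……", "doc_id": "d1", "url": "http://example.com"},
]

doc2chunks = defaultdict(list)
for i, ck in enumerate(chunks):
    cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
    cnt += ck["content_with_weight"]
    doc2chunks[ck["docnm_kwd"]].append(cnt)

for nm, cks in doc2chunks.items():
    txt = f"\n文档: {nm} \n相关片段如下:\n"
    for chunk in cks:
        txt += f"{chunk}\n"
    print(txt)
```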
@@ -388,3 +394,57 @@ Output:
     except Exception as e:
         logging.exception(f"JSON parsing error: {result} -> {e}")
         raise e
+
+
+def vision_llm_describe_prompt(page=None) -> str:
+    prompt_en = """
+INSTRUCTION:
+Transcribe the content from the provided PDF page image into clean Markdown format.
+- Only output the content transcribed from the image.
+- Do NOT output this instruction or any other explanation.
+- If the content is missing or you do not understand the input, return an empty string.
+
+RULES:
+1. Do NOT generate examples, demonstrations, or templates.
+2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
+3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
+4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
+5. Do NOT explain Markdown or mention that you are using Markdown.
+6. Do NOT wrap the output in ```markdown or ``` blocks.
+7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
+8. Preserve the original language, information, and order exactly as shown in the image.
+"""
+
+    if page is not None:
+        prompt_en += f"\nAt the end of the transcription, add the page divider: `--- Page {page} ---`."
+
+    prompt_en += """
+FAILURE HANDLING:
+ - If you do not detect valid content in the image, return an empty string.
+"""
+    return prompt_en
+
+
+def vision_llm_figure_describe_prompt() -> str:
+    prompt = """
+You are an expert visual data analyst. Analyze the image and provide a comprehensive description of its content. Focus on identifying the type of visual data representation (e.g., bar chart, pie chart, line graph, table, flowchart), its structure, and any text captions or labels included in the image.
+
+Tasks:
+1. Describe the overall structure of the visual representation. Specify if it is a chart, graph, table, or diagram.
+2. Identify and extract any axes, legends, titles, or labels present in the image. Provide the exact text where available.
+3. Extract the data points from the visual elements (e.g., bar heights, line graph coordinates, pie chart segments, table rows and columns).
+4. Analyze and explain any trends, comparisons, or patterns shown in the data.
+5. Capture any annotations, captions, or footnotes, and explain their relevance to the image.
+6. Only include details that are explicitly present in the image. If an element (e.g., axis, legend, or caption) does not exist or is not visible, do not mention it.
+
+Output format (include only sections relevant to the image content):
+- Visual Type: [Type]
+- Title: [Title text, if available]
+- Axes / Legends / Labels: [Details, if available]
+- Data Points: [Extracted data]
+- Trends / Insights: [Analysis and interpretation]
+- Captions / Annotations: [Text and relevance, if available]
+
+Ensure high accuracy, clarity, and completeness in your analysis, and includes only the information present in the image. Avoid unnecessary statements about missing elements.
+"""
+    return prompt
rag/res/synonym.json (10540 lines changed): file diff suppressed because it is too large.