Fix: add parser-side tokenization and fix keyword matching failing during retrieval (issue #133)

Commit 401e3d81c4

@@ -54,3 +54,4 @@ docker/models
management/web/types/auto
web/node_modules/.cache/logger/umi.log
management/models--slanet_plus
node_modules/.cache/logger/umi.log

README.md (53 changed lines)
@@ -76,23 +76,32 @@ ollama pull bge-m3:latest

#### 1. Run with Docker Compose

-Run from the project root.
-
-Run with GPU:
-```bash
-docker compose -f docker/docker-compose_gpu.yml up -d
-```
-
-Run with CPU:
-```bash
-docker compose -f docker/docker-compose.yml up -d
-```
+- Run with GPU (the first GPU must have at least 6 GB of free VRAM):
+
+1. Install nvidia-container-runtime on the host so that Docker mounts the GPU devices and drivers automatically:
+
+```bash
+sudo apt install -y nvidia-container-runtime
+```
+
+2. From the project root, run:
+
+```bash
+docker compose -f docker/docker-compose_gpu.yml up -d
+```
+
+- Run with CPU:
+
+From the project root, run:
+
+```bash
+docker compose -f docker/docker-compose.yml up -d
+```

Access `server-ip:80` to open the front-end interface.

Access `server-ip:8888` to open the back-end management interface.

Illustrated tutorial: [https://blog.csdn.net/qq1198768105/article/details/147475488](https://blog.csdn.net/qq1198768105/article/details/147475488)

#### 2. Run from source (MySQL, MinIO, Elasticsearch and other components still need to be started with Docker)

@@ -100,29 +109,29 @@ docker compose -f docker/docker-compose.yml up -d

- Start the backend: go to `management/server` and run:

```bash
python app.py
```

- Start the frontend: go to `management\web` and run:

```bash
pnpm dev
```

2. Start the front-facing interactive system:

- Start the backend: from the project root, run:

```bash
python -m api.ragflow_server
```

- Start the frontend: go to `web` and run:

```bash
pnpm dev
```

> [!NOTE]
> For source deployments: if MinerU is used for back-end parsing, download the model files as described in the MinerU documentation, install LibreOffice, and set the required environment variables so that file types other than PDF are supported.
|
||||
|
|
|
@ -22,7 +22,8 @@ from flask_login import login_required, current_user
|
|||
from rag.app.qa import rmPrefix, beAdoc
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from rag.prompts import keyword_extraction
|
||||
|
||||
# from rag.prompts import keyword_extraction, cross_languages
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils import rmSpace
|
||||
from api.db import LLMType, ParserType
|
||||
|
@ -37,9 +38,9 @@ import xxhash
|
|||
import re
|
||||
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
@validate_request("doc_id") # 验证请求中必须包含 doc_id 参数
|
||||
def list_chunk():
|
||||
req = request.json
|
||||
doc_id = req["doc_id"]
|
||||
|
@ -54,9 +55,7 @@ def list_chunk():
|
|||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
query = {
|
||||
"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
|
||||
}
|
||||
query = {"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True}
|
||||
if "available_int" in req:
|
||||
query["available_int"] = int(req["available_int"])
|
||||
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
||||
|
@ -64,9 +63,7 @@ def list_chunk():
|
|||
for id in sres.ids:
|
||||
d = {
|
||||
"chunk_id": id,
|
||||
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[
|
||||
id].get(
|
||||
"content_with_weight", ""),
|
||||
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get("content_with_weight", ""),
|
||||
"doc_id": sres.field[id]["doc_id"],
|
||||
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
||||
|
@ -81,12 +78,11 @@ def list_chunk():
|
|||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found!',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
return get_json_result(data=False, message="No chunk found!", code=settings.RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/get', methods=['GET']) # noqa: F821
|
||||
@manager.route("/get", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get():
|
||||
chunk_id = request.args["chunk_id"]
|
||||
|
@ -112,19 +108,16 @@ def get():
|
|||
return get_json_result(data=chunk)
|
||||
except Exception as e:
|
||||
if str(e).find("NotFoundError") >= 0:
|
||||
return get_json_result(data=False, message='Chunk not found!',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
return get_json_result(data=False, message="Chunk not found!", code=settings.RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "chunk_id", "content_with_weight")
|
||||
def set():
|
||||
req = request.json
|
||||
d = {
|
||||
"id": req["chunk_id"],
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
if "important_kwd" in req:
|
||||
|
@ -153,13 +146,9 @@ def set():
|
|||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [
|
||||
t for t in re.split(
|
||||
r"[\n\t]",
|
||||
req["content_with_weight"]) if len(t) > 1]
|
||||
arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t) > 1]
|
||||
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
||||
d = beAdoc(d, q, a, not any(
|
||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||
d = beAdoc(d, q, a, not any([rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
|
@ -170,7 +159,7 @@ def set():
|
|||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/switch', methods=['POST']) # noqa: F821
|
||||
@manager.route("/switch", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "available_int", "doc_id")
|
||||
def switch():
|
||||
|
@ -180,20 +169,19 @@ def switch():
|
|||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
for cid in req["chunk_ids"]:
|
||||
if not settings.docStoreConn.update({"id": cid},
|
||||
{"available_int": int(req["available_int"])},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
if not settings.docStoreConn.update({"id": cid}, {"available_int": int(req["available_int"])}, search.index_name(DocumentService.get_tenant_id(req["doc_id"])), doc.kb_id):
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "doc_id")
|
||||
def rm():
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
req = request.json
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
|
@ -204,19 +192,21 @@ def rm():
|
|||
deleted_chunk_ids = req["chunk_ids"]
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
for cid in deleted_chunk_ids:
|
||||
if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||
STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
||||
@manager.route("/create", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "content_with_weight")
|
||||
def create():
|
||||
req = request.json
|
||||
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), "content_with_weight": req["content_with_weight"]}
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
d["important_kwd"] = req.get("important_kwd", [])
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", [])))
|
||||
|
@ -252,14 +242,35 @@ def create():
|
|||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
|
||||
return get_json_result(data={"chunk_id": chunck_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/retrieval_test', methods=['POST']) # noqa: F821
|
||||
"""
|
||||
{
|
||||
"similarity_threshold": 0.2,
|
||||
"vector_similarity_weight": 0.30000000000000004,
|
||||
"question": "香港",
|
||||
"doc_ids": [],
|
||||
"kb_id": "4b071030bc8e43f1bfb8b7831f320d2f",
|
||||
"page": 1,
|
||||
"size": 10
|
||||
},
|
||||
{
|
||||
"similarity_threshold": 0.2,
|
||||
"vector_similarity_weight": 0.30000000000000004,
|
||||
"question": "显著优势",
|
||||
"doc_ids": [],
|
||||
"kb_id": "1848bc54384611f0b33e4e66786d0323",
|
||||
"page": 1,
|
||||
"size": 10
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
@manager.route("/retrieval_test", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval_test():
|
||||
|
@ -268,57 +279,53 @@ def retrieval_test():
|
|||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
kb_ids = req["kb_id"]
|
||||
# 如果kb_ids是字符串,将其转换为列表
|
||||
if isinstance(kb_ids, str):
|
||||
kb_ids = [kb_ids]
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
use_kg = req.get("use_kg", False)
|
||||
top = int(req.get("top_k", 1024))
|
||||
# langs = req.get("cross_languages", []) # 获取跨语言设定
|
||||
tenant_ids = []
|
||||
|
||||
try:
|
||||
# 查询当前用户所属的租户
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
# 验证知识库权限
|
||||
for kb_id in kb_ids:
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(
|
||||
tenant_id=tenant.tenant_id, id=kb_id):
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
tenant_ids.append(tenant.tenant_id)
|
||||
break
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
# 获取知识库信息
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
|
||||
# if langs:
|
||||
# question = cross_languages(kb.tenant_id, None, question, langs) # 跨语言处理
|
||||
|
||||
# 加载嵌入模型
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
# 加载重排序模型(如果指定)
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
# 对问题进行标签化
|
||||
# labels = label_question(question, [kb])
|
||||
labels = None
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
|
||||
rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(question,
|
||||
tenant_ids,
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
# 执行检索操作
|
||||
ranks = settings.retrievaler.retrieval(
|
||||
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
)
|
||||
|
||||
# 移除不必要的向量信息
|
||||
for c in ranks["chunks"]:
|
||||
c.pop("vector", None)
|
||||
ranks["labels"] = labels
|
||||
|
@ -326,47 +333,5 @@ def retrieval_test():
|
|||
return get_json_result(data=ranks)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=settings.RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def knowledge_graph():
|
||||
doc_id = request.args["doc_id"]
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
req = {
|
||||
"doc_ids": [doc_id],
|
||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
||||
}
|
||||
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
for id in sres.ids[:2]:
|
||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||
try:
|
||||
content_json = json.loads(sres.field[id]["content_with_weight"])
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if ty == 'mind_map':
|
||||
node_dict = {}
|
||||
|
||||
def repeat_deal(content_json, node_dict):
|
||||
if 'id' in content_json:
|
||||
if content_json['id'] in node_dict:
|
||||
node_name = content_json['id']
|
||||
content_json['id'] += f"({node_dict[content_json['id']]})"
|
||||
node_dict[node_name] += 1
|
||||
else:
|
||||
node_dict[content_json['id']] = 1
|
||||
if 'children' in content_json and content_json['children']:
|
||||
for item in content_json['children']:
|
||||
repeat_deal(item, node_dict)
|
||||
|
||||
repeat_deal(content_json, node_dict)
|
||||
|
||||
obj[ty] = content_json
|
||||
|
||||
return get_json_result(data=obj)
|
||||
|
|
|
@@ -4,9 +4,11 @@ import redis
 from minio import Minio
 from dotenv import load_dotenv
 from elasticsearch import Elasticsearch
+from pathlib import Path

 # Load environment variables
-load_dotenv("../../docker/.env")
+env_path = Path(__file__).parent.parent.parent / "docker" / ".env"
+load_dotenv(env_path)


 # Detect whether we are running inside a Docker container
|
||||
|
|
|
@@ -24,4 +24,5 @@ pyclipper==1.3.0.post6
 omegaconf==2.3.0
 rapid-table==1.0.3
 openai==1.70.0
-redis==6.2.0
+redis==6.2.0
+tokenizer==3.4.5
|
|
@ -2,12 +2,12 @@ import os
|
|||
import tempfile
|
||||
import shutil
|
||||
import json
|
||||
from bs4 import BeautifulSoup
|
||||
import mysql.connector
|
||||
import time
|
||||
import traceback
|
||||
import re
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from io import BytesIO
|
||||
from datetime import datetime
|
||||
from database import MINIO_CONFIG, DB_CONFIG, get_minio_client, get_es_client
|
||||
|
@ -17,17 +17,18 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
|||
from magic_pdf.config.enums import SupportedPdfParseMethod
|
||||
from magic_pdf.data.read_api import read_local_office, read_local_images
|
||||
from utils import generate_uuid
|
||||
from .rag_tokenizer import RagTokenizer
|
||||
|
||||
|
||||
tknzr = RagTokenizer()
|
||||
|
||||
|
||||
# 自定义tokenizer和文本处理函数,替代rag.nlp中的功能
|
||||
def tokenize_text(text):
|
||||
"""将文本分词,替代rag_tokenizer功能"""
|
||||
# 简单实现,未来可能需要改成更复杂的分词逻辑
|
||||
return text.split()
|
||||
return tknzr.tokenize(text)
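
The switch from `text.split()` to `tknzr.tokenize(text)` is the core of this fix: Chinese text contains no spaces, so the old implementation indexed each chunk as one oversized token and keyword matching had nothing to hit at retrieval time. A minimal standalone sketch of the failure mode (plain Python; the sample sentence and the suggested tokenizer output are illustrative, not taken from the repo):

```python
content = "香港是一个国际金融中心"

# Old behaviour: Chinese has no whitespace, so split() yields a single giant token
# and a keyword query for "香港" cannot match the indexed field.
old_tokens = content.split()
print(old_tokens)  # ['香港是一个国际金融中心']

# New behaviour (sketch): RagTokenizer.tokenize() returns a space-separated token
# string, e.g. "香港 是 一个 国际 金融 中心" (exact segmentation depends on the
# dictionary), so content_ltks / content_sm_ltks become matchable by keyword.
```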
|
||||
|
||||
|
||||
def merge_chunks(sections, chunk_token_num=128, delimiter="\n。;!?"):
|
||||
"""合并文本块,替代naive_merge功能"""
|
||||
"""合并文本块,替代naive_merge功能(预留函数)"""
|
||||
if not sections:
|
||||
return []
|
||||
|
||||
|
@ -149,25 +150,7 @@ def _create_task_record(doc_id, chunk_ids_list):
|
|||
progress, progress_msg, retry_count, digest, chunk_ids, task_type, priority
|
||||
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
"""
|
||||
task_params = [
|
||||
task_id,
|
||||
current_timestamp,
|
||||
current_date_only,
|
||||
current_timestamp,
|
||||
current_date_only,
|
||||
doc_id,
|
||||
0,
|
||||
1,
|
||||
None,
|
||||
0.0,
|
||||
1.0,
|
||||
"MinerU解析完成",
|
||||
1,
|
||||
digest,
|
||||
chunk_ids_str,
|
||||
"",
|
||||
0
|
||||
]
|
||||
task_params = [task_id, current_timestamp, current_date_only, current_timestamp, current_date_only, doc_id, 0, 1, None, 0.0, 1.0, "MinerU解析完成", 1, digest, chunk_ids_str, "", 0]
|
||||
cursor.execute(task_insert, task_params)
|
||||
conn.commit()
|
||||
print(f"[Parser-INFO] Task记录创建成功,Task ID: {task_id}")
|
||||
|
@ -204,54 +187,55 @@ def get_bbox_from_block(block):
|
|||
def process_table_content(content_list):
|
||||
"""
|
||||
处理表格内容,将每一行分开存储
|
||||
|
||||
|
||||
Args:
|
||||
content_list: 原始内容列表
|
||||
|
||||
|
||||
Returns:
|
||||
处理后的内容列表
|
||||
"""
|
||||
new_content_list = []
|
||||
|
||||
|
||||
for item in content_list:
|
||||
if 'table_body' in item and item['table_body']:
|
||||
if "table_body" in item and item["table_body"]:
|
||||
# 使用BeautifulSoup解析HTML表格
|
||||
soup = BeautifulSoup(item['table_body'], 'html.parser')
|
||||
table = soup.find('table')
|
||||
|
||||
soup = BeautifulSoup(item["table_body"], "html.parser")
|
||||
table = soup.find("table")
|
||||
|
||||
if table:
|
||||
rows = table.find_all('tr')
|
||||
rows = table.find_all("tr")
|
||||
# 获取表头(第一行)
|
||||
header_row = rows[0] if rows else None
|
||||
|
||||
|
||||
# 处理每一行,从第二行开始(跳过表头)
|
||||
for i, row in enumerate(rows):
|
||||
# 创建新的内容项
|
||||
new_item = item.copy()
|
||||
|
||||
|
||||
# 创建只包含当前行的表格
|
||||
new_table = soup.new_tag('table')
|
||||
|
||||
new_table = soup.new_tag("table")
|
||||
|
||||
# 如果有表头,添加表头
|
||||
if header_row and i > 0:
|
||||
new_table.append(header_row)
|
||||
|
||||
|
||||
# 添加当前行
|
||||
new_table.append(row)
|
||||
|
||||
|
||||
# 创建新的HTML结构
|
||||
new_html = f"<html><body>{str(new_table)}</body></html>"
|
||||
new_item['table_body'] = f"\n\n{new_html}\n\n"
|
||||
|
||||
new_item["table_body"] = f"\n\n{new_html}\n\n"
|
||||
|
||||
# 添加到新的内容列表
|
||||
new_content_list.append(new_item)
|
||||
else:
|
||||
new_content_list.append(item)
|
||||
else:
|
||||
new_content_list.append(item)
|
||||
|
||||
|
||||
return new_content_list
|
||||
|
||||
|
||||
def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
||||
"""
|
||||
执行文档解析的核心逻辑
|
||||
|
@ -305,7 +289,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
|||
if normalized_base_url.endswith("/v1"):
|
||||
# 如果 base_url 已经是 http://host/v1 形式
|
||||
embedding_url = normalized_base_url + "/" + endpoint_segment
|
||||
elif normalized_base_url.endswith('/embeddings'):
|
||||
elif normalized_base_url.endswith("/embeddings"):
|
||||
# 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API,无需再进行处理)
|
||||
embedding_url = normalized_base_url
|
||||
else:
|
||||
|
@ -403,7 +387,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
|||
middle_content = pipe_result.get_middle_json()
|
||||
middle_json_content = json.loads(middle_content)
|
||||
# 对excel文件单独进行处理
|
||||
elif file_type.endswith("excel") :
|
||||
elif file_type.endswith("excel"):
|
||||
update_progress(0.3, "使用MinerU解析器")
|
||||
# 创建临时文件保存文件内容
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
@ -441,7 +425,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
|||
# 使用MinerU处理
|
||||
ds = read_local_images(temp_file_path)[0]
|
||||
infer_result = ds.apply(doc_analyze, ocr=True)
|
||||
|
||||
|
||||
update_progress(0.3, "分析PDF类型")
|
||||
is_ocr = ds.classify() == SupportedPdfParseMethod.OCR
|
||||
mode_msg = "OCR模式" if is_ocr else "文本模式"
|
||||
|
@ -613,7 +597,6 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
|||
)
|
||||
|
||||
# 准备ES文档
|
||||
content_tokens = tokenize_text(content) # 分词
|
||||
current_time_es = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
current_timestamp_es = datetime.now().timestamp()
|
||||
|
||||
|
@ -625,11 +608,11 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
|||
"doc_id": doc_id,
|
||||
"kb_id": kb_id,
|
||||
"docnm_kwd": doc_info["name"],
|
||||
"title_tks": doc_info["name"],
|
||||
"title_sm_tks": doc_info["name"],
|
||||
"title_tks": tokenize_text(doc_info["name"]),
|
||||
"title_sm_tks": tokenize_text(doc_info["name"]),
|
||||
"content_with_weight": content,
|
||||
"content_ltks": " ".join(content_tokens), # 字符串类型
|
||||
"content_sm_ltks": " ".join(content_tokens), # 字符串类型
|
||||
"content_ltks": tokenize_text(content),
|
||||
"content_sm_ltks": tokenize_text(content),
|
||||
"page_num_int": [page_idx + 1],
|
||||
"position_int": [[page_idx + 1] + bbox_reordered], # 格式: [[page, x1, x2, y1, y2]]
|
||||
"top_int": [1],
|
||||
|
@ -755,7 +738,6 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
|
|||
traceback.print_exc() # 打印详细错误堆栈
|
||||
# 更新文档状态为失败
|
||||
_update_document_progress(doc_id, status="1", run="0", message=error_message, process_duration=process_duration) # status=1表示完成,run=0表示失败
|
||||
# 不抛出异常,让调用者知道任务已结束(但失败)
|
||||
return {"success": False, "error": error_message}
|
||||
|
||||
finally:
|
||||
|
|
|
@ -0,0 +1,412 @@
|
|||
import logging
|
||||
import copy
|
||||
import datrie
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
from hanziconv import HanziConv
|
||||
from nltk import word_tokenize
|
||||
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
||||
|
||||
|
||||
# def get_project_base_directory():
|
||||
# return os.path.abspath(
|
||||
# os.path.join(
|
||||
# os.path.dirname(os.path.realpath(__file__)),
|
||||
# os.pardir,
|
||||
# os.pardir,
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
class RagTokenizer:
|
||||
def key_(self, line):
|
||||
return str(line.lower().encode("utf-8"))[2:-1]
|
||||
|
||||
def rkey_(self, line):
|
||||
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
|
||||
|
||||
def loadDict_(self, fnm):
|
||||
print(f"[HUQIE]:Build trie from {fnm}")
|
||||
try:
|
||||
of = open(fnm, "r", encoding="utf-8")
|
||||
while True:
|
||||
line = of.readline()
|
||||
if not line:
|
||||
break
|
||||
line = re.sub(r"[\r\n]+", "", line)
|
||||
line = re.split(r"[ \t]", line)
|
||||
k = self.key_(line[0])
|
||||
F = int(math.log(float(line[1]) / self.DENOMINATOR) + 0.5)
|
||||
if k not in self.trie_ or self.trie_[k][0] < F:
|
||||
self.trie_[self.key_(line[0])] = (F, line[2])
|
||||
self.trie_[self.rkey_(line[0])] = 1
|
||||
|
||||
dict_file_cache = fnm + ".trie"
|
||||
print(f"[HUQIE]:Build trie cache to {dict_file_cache}")
|
||||
self.trie_.save(dict_file_cache)
|
||||
of.close()
|
||||
except Exception:
|
||||
logging.exception(f"[HUQIE]:Build trie {fnm} failed")
|
||||
|
||||
def __init__(self):
|
||||
self.DENOMINATOR = 1000000
|
||||
self.DIR_ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "res", "huqie")
|
||||
|
||||
self.stemmer = PorterStemmer()
|
||||
self.lemmatizer = WordNetLemmatizer()
|
||||
|
||||
self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"
|
||||
|
||||
trie_file_name = self.DIR_ + ".txt.trie"
|
||||
# check if trie file existence
|
||||
if os.path.exists(trie_file_name):
|
||||
try:
|
||||
# load trie from file
|
||||
self.trie_ = datrie.Trie.load(trie_file_name)
|
||||
return
|
||||
except Exception:
|
||||
# fail to load trie from file, build default trie
|
||||
logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
else:
|
||||
# file not exist, build default trie
|
||||
print(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
|
||||
# load data from dict file and save to trie file
|
||||
self.loadDict_(self.DIR_ + ".txt")
|
||||
|
||||
def _strQ2B(self, ustring):
|
||||
"""全角转半角,转小写"""
|
||||
rstring = ""
|
||||
for uchar in ustring:
|
||||
inside_code = ord(uchar)
|
||||
if inside_code == 0x3000:
|
||||
inside_code = 0x0020
|
||||
else:
|
||||
inside_code -= 0xFEE0
|
||||
if inside_code < 0x0020 or inside_code > 0x7E: # After the conversion, if it's not a half-width character, return the original character.
|
||||
rstring += uchar
|
||||
else:
|
||||
rstring += chr(inside_code)
|
||||
return rstring
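
The `-0xFEE0` shift above maps full-width ASCII forms onto their half-width counterparts, with the ideographic space handled separately. A standalone replica of that rule, for illustration only (the sample string is made up):

```python
def str_q2b(ustring: str) -> str:
    # Ideographic space (U+3000) -> ASCII space; other full-width forms are shifted
    # down by 0xFEE0 when the result lands in printable ASCII, otherwise the
    # original character is kept unchanged.
    out = []
    for ch in ustring:
        code = ord(ch)
        if code == 0x3000:
            code = 0x0020
        else:
            code -= 0xFEE0
        out.append(ch if code < 0x0020 or code > 0x7E else chr(code))
    return "".join(out)

print(str_q2b("ＡＢＣ　１２３"))  # -> "ABC 123"
```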
|
||||
|
||||
def _tradi2simp(self, line):
|
||||
"""繁体转简体"""
|
||||
return HanziConv.toSimplified(line)
|
||||
|
||||
def dfs_(self, chars, s, preTks, tkslist):
|
||||
res = s
|
||||
# if s > MAX_L or s>= len(chars):
|
||||
if s >= len(chars):
|
||||
tkslist.append(preTks)
|
||||
return res
|
||||
|
||||
# pruning
|
||||
S = s + 1
|
||||
if s + 2 <= len(chars):
|
||||
t1, t2 = "".join(chars[s : s + 1]), "".join(chars[s : s + 2])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
|
||||
S = s + 2
|
||||
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
|
||||
t1 = preTks[-1][0] + "".join(chars[s : s + 1])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)):
|
||||
S = s + 2
|
||||
|
||||
for e in range(S, len(chars) + 1):
|
||||
t = "".join(chars[s:e])
|
||||
k = self.key_(t)
|
||||
|
||||
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
|
||||
break
|
||||
|
||||
if k in self.trie_:
|
||||
pretks = copy.deepcopy(preTks)
|
||||
if k in self.trie_:
|
||||
pretks.append((t, self.trie_[k]))
|
||||
else:
|
||||
pretks.append((t, (-12, "")))
|
||||
res = max(res, self.dfs_(chars, e, pretks, tkslist))
|
||||
|
||||
if res > s:
|
||||
return res
|
||||
|
||||
t = "".join(chars[s : s + 1])
|
||||
k = self.key_(t)
|
||||
if k in self.trie_:
|
||||
preTks.append((t, self.trie_[k]))
|
||||
else:
|
||||
preTks.append((t, (-12, "")))
|
||||
|
||||
return self.dfs_(chars, s + 1, preTks, tkslist)
|
||||
|
||||
def freq(self, tk):
|
||||
k = self.key_(tk)
|
||||
if k not in self.trie_:
|
||||
return 0
|
||||
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
|
||||
|
||||
def score_(self, tfts):
|
||||
B = 30
|
||||
F, L, tks = 0, 0, []
|
||||
for tk, (freq, tag) in tfts:
|
||||
F += freq
|
||||
L += 0 if len(tk) < 2 else 1
|
||||
tks.append(tk)
|
||||
# F /= len(tks)
|
||||
L /= len(tks)
|
||||
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
|
||||
return tks, B / len(tks) + L + F
|
||||
|
||||
def sortTks_(self, tkslist):
|
||||
res = []
|
||||
for tfts in tkslist:
|
||||
tks, s = self.score_(tfts)
|
||||
res.append((tks, s))
|
||||
return sorted(res, key=lambda x: x[1], reverse=True)
|
||||
|
||||
def merge_(self, tks):
|
||||
# if split chars is part of token
|
||||
res = []
|
||||
tks = re.sub(r"[ ]+", " ", tks).split()
|
||||
s = 0
|
||||
while True:
|
||||
if s >= len(tks):
|
||||
break
|
||||
E = s + 1
|
||||
for e in range(s + 2, min(len(tks) + 2, s + 6)):
|
||||
tk = "".join(tks[s:e])
|
||||
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
|
||||
E = e
|
||||
res.append("".join(tks[s:E]))
|
||||
s = E
|
||||
|
||||
return " ".join(res)
|
||||
|
||||
def maxForward_(self, line):
|
||||
res = []
|
||||
s = 0
|
||||
while s < len(line):
|
||||
e = s + 1
|
||||
t = line[s:e]
|
||||
while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
|
||||
e += 1
|
||||
t = line[s:e]
|
||||
|
||||
while e - 1 > s and self.key_(t) not in self.trie_:
|
||||
e -= 1
|
||||
t = line[s:e]
|
||||
|
||||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, "")))
|
||||
|
||||
s = e
|
||||
|
||||
return self.score_(res)
|
||||
|
||||
def maxBackward_(self, line):
|
||||
res = []
|
||||
s = len(line) - 1
|
||||
while s >= 0:
|
||||
e = s + 1
|
||||
t = line[s:e]
|
||||
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
|
||||
s -= 1
|
||||
t = line[s:e]
|
||||
|
||||
while s + 1 < e and self.key_(t) not in self.trie_:
|
||||
s += 1
|
||||
t = line[s:e]
|
||||
|
||||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, "")))
|
||||
|
||||
s -= 1
|
||||
|
||||
return self.score_(res[::-1])
|
||||
|
||||
def _split_by_lang(self, line):
|
||||
"""根据语言进行切分"""
|
||||
txt_lang_pairs = []
|
||||
arr = re.split(self.SPLIT_CHAR, line)
|
||||
for a in arr:
|
||||
if not a:
|
||||
continue
|
||||
s = 0
|
||||
e = s + 1
|
||||
zh = is_chinese(a[s])
|
||||
while e < len(a):
|
||||
_zh = is_chinese(a[e])
|
||||
if _zh == zh:
|
||||
e += 1
|
||||
continue
|
||||
txt_lang_pairs.append((a[s:e], zh))
|
||||
s = e
|
||||
e = s + 1
|
||||
zh = _zh
|
||||
if s >= len(a):
|
||||
continue
|
||||
txt_lang_pairs.append((a[s:e], zh))
|
||||
return txt_lang_pairs
|
||||
|
||||
def tokenize(self, line: str) -> str:
|
||||
"""
|
||||
对输入文本进行分词,支持中英文混合处理。
|
||||
|
||||
分词流程:
|
||||
1. 预处理:
|
||||
- 将所有非单词字符(字母、数字、下划线以外的)替换为空格。
|
||||
- 全角字符转半角。
|
||||
- 转换为小写。
|
||||
- 繁体中文转简体中文。
|
||||
2. 按语言切分:
|
||||
- 将预处理后的文本按语言(中文/非中文)分割成多个片段。
|
||||
3. 分段处理:
|
||||
- 对于非中文(通常是英文)片段:
|
||||
- 使用 NLTK 的 `word_tokenize` 进行分词。
|
||||
- 对分词结果进行词干提取 (PorterStemmer) 和词形还原 (WordNetLemmatizer)。
|
||||
- 对于中文片段:
|
||||
- 如果片段过短(长度<2)或为纯粹的英文/数字模式(如 "abc-def", "123.45"),则直接保留该片段。
|
||||
- 否则,采用基于词典的混合分词策略:
|
||||
a. 执行正向最大匹配 (FMM) 和逆向最大匹配 (BMM) 得到两组分词结果 (`tks` 和 `tks1`)。
|
||||
b. 比较 FMM 和 BMM 的结果:
|
||||
i. 找到两者从开头开始最长的相同分词序列,这部分通常是无歧义的,直接加入结果。
|
||||
ii. 对于 FMM 和 BMM 结果不一致的歧义部分(即从第一个不同点开始的子串):
|
||||
- 提取出这段有歧义的原始文本。
|
||||
- 调用 `self.dfs_` (深度优先搜索) 在这段文本上探索所有可能的分词组合。
|
||||
- `self.dfs_` 会利用Trie词典,并由 `self.sortTks_` 对所有组合进行评分和排序。
|
||||
- 选择得分最高的分词方案作为该歧义段落的结果。
|
||||
iii.继续处理 FMM 和 BMM 结果中歧义段落之后的部分,重复步骤 i 和 ii,直到两个序列都处理完毕。
|
||||
c. 如果在比较完所有对应部分后,FMM 或 BMM 仍有剩余(理论上如果实现正确且输入相同,剩余部分也应相同),
|
||||
则对这部分剩余的原始文本同样使用 `self.dfs_` 进行最优分词。
|
||||
4. 后处理:
|
||||
- 将所有处理过的片段(英文词元、中文词元)用空格连接起来。
|
||||
- 调用 `self.merge_` 对连接后的结果进行进一步的合并操作,
|
||||
尝试合并一些可能被错误分割但实际是一个完整词的片段(基于词典检查)。
|
||||
5. 返回最终分词结果字符串(词元间用空格分隔)。
|
||||
|
||||
Args:
|
||||
line (str): 待分词的原始输入字符串。
|
||||
|
||||
Returns:
|
||||
str: 分词后的字符串,词元之间用空格分隔。
|
||||
"""
|
||||
# 1. 预处理
|
||||
line = re.sub(r"\W+", " ", line) # 将非字母数字下划线替换为空格
|
||||
line = self._strQ2B(line).lower() # 全角转半角,转小写
|
||||
line = self._tradi2simp(line) # 繁体转简体
|
||||
|
||||
# 2. 按语言切分
|
||||
arr = self._split_by_lang(line) # 将文本分割成 (文本片段, 是否为中文) 的列表
|
||||
res = [] # 存储最终分词结果的列表
|
||||
|
||||
# 3. 分段处理
|
||||
for L, lang in arr: # L 是文本片段,lang 是布尔值表示是否为中文
|
||||
if not lang: # 如果不是中文
|
||||
# 使用NLTK进行分词、词干提取和词形还原
|
||||
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
|
||||
continue # 处理下一个片段
|
||||
|
||||
# 如果是中文,但长度小于2或匹配纯英文/数字模式,则直接添加,不进一步切分
|
||||
if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
|
||||
res.append(L)
|
||||
continue # 处理下一个片段
|
||||
|
||||
# 对较长的中文片段执行FMM和BMM
|
||||
tks, s = self.maxForward_(L) # tks: FMM结果列表, s: FMM评分 FMM (Forward Maximum Matching - 正向最大匹配)
|
||||
tks1, s1 = self.maxBackward_(L) # tks1: BMM结果列表, s1: BMM评分 BMM (Backward Maximum Matching - 逆向最大匹配)
|
||||
|
||||
# 初始化用于比较FMM和BMM结果的指针
|
||||
i, j = 0, 0 # i 指向 tks1 (BMM), j 指向 tks (FMM)
|
||||
_i, _j = 0, 0 # _i, _j 记录上一段歧义处理的结束位置
|
||||
|
||||
# 3.b.i. 查找 FMM 和 BMM 从头开始的最长相同前缀
|
||||
same = 0 # 相同词元的数量
|
||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
same += 1
|
||||
if same > 0: # 如果存在相同前缀
|
||||
res.append(" ".join(tks[j : j + same])) # 将FMM中的相同部分加入结果
|
||||
|
||||
# 更新指针到不同部分的开始
|
||||
_i = i + same
|
||||
_j = j + same
|
||||
# 准备开始处理可能存在的歧义部分
|
||||
j = _j + 1 # FMM指针向后移动一位(或多位,取决于下面tk的构造)
|
||||
i = _i + 1 # BMM指针向后移动一位(或多位)
|
||||
|
||||
# 3.b.ii. 迭代处理 FMM 和 BMM 结果中的歧义部分
|
||||
while i < len(tks1) and j < len(tks):
|
||||
# tk1 是 BMM 从上一个同步点 _i 到当前指针 i 形成的字符串
|
||||
# tk 是 FMM 从上一个同步点 _j 到当前指针 j 形成的字符串
|
||||
tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
|
||||
|
||||
if tk1 != tk: # 如果这两个子串不相同,说明FMM和BMM的切分路径出现分叉
|
||||
# 尝试通过移动较短子串的指针来寻找下一个可能的同步点
|
||||
if len(tk1) > len(tk):
|
||||
j += 1
|
||||
else:
|
||||
i += 1
|
||||
continue # 继续外层while循环
|
||||
|
||||
# 如果子串相同,但当前位置的单个词元不同,则这也是一个需要DFS解决的歧义点
|
||||
if tks1[i] != tks[j]: # 注意:这里比较的是tks1[i]和tks[j],而不是tk1和tk的最后一个词
|
||||
i += 1
|
||||
j += 1
|
||||
continue
|
||||
|
||||
# 从_j到j (不包括j处的词) 这段 FMM 产生的文本是歧义的,需要用DFS解决。
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:j]), 0, [], tkslist) # 对这段FMM子串进行DFS
|
||||
if tkslist: # 确保DFS有结果
|
||||
res.append(" ".join(self.sortTks_(tkslist)[0][0])) # 取最优DFS结果
|
||||
|
||||
# 处理当前这个相同的词元 (tks[j] 或 tks1[i]) 以及之后连续相同的词元
|
||||
same = 1
|
||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
same += 1
|
||||
res.append(" ".join(tks[j : j + same])) # 将FMM中从j开始的连续相同部分加入结果
|
||||
|
||||
# 更新指针到下一个不同部分的开始
|
||||
_i = i + same
|
||||
_j = j + same
|
||||
j = _j + 1
|
||||
i = _i + 1
|
||||
|
||||
# 3.c. 处理 FMM 或 BMM 可能的尾部剩余部分
|
||||
# 如果 _i (BMM的已处理指针) 还没有到达 tks1 的末尾
|
||||
# (并且假设 _j (FMM的已处理指针) 也未到 tks 的末尾,且剩余部分代表相同的原始文本)
|
||||
if _i < len(tks1):
|
||||
# 断言确保FMM的已处理指针也未到末尾
|
||||
assert _j < len(tks)
|
||||
# 断言FMM和BMM的剩余部分代表相同的原始字符串
|
||||
assert "".join(tks1[_i:]) == "".join(tks[_j:])
|
||||
# 对FMM的剩余部分(代表了原始文本的尾部)进行DFS分词
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
|
||||
if tkslist: # 确保DFS有结果
|
||||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||||
|
||||
# 4. 后处理
|
||||
res_str = " ".join(res) # 将所有分词结果用空格连接
|
||||
return self.merge_(res_str) # 返回经过合并处理的最终分词结果
|
||||
|
||||
|
||||
def is_chinese(s):
|
||||
if s >= "\u4e00" and s <= "\u9fa5":
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tknzr = RagTokenizer()
|
||||
tks = tknzr.tokenize("基于动态视觉相机的光流估计研究_孙文义.pdf")
|
||||
print(tks)
|
||||
tks = tknzr.tokenize("图3-1 事件流输入表征。(a)事件帧;(b)时间面;(c)体素网格\n(a)\n(b)\n(c)")
|
||||
print(tks)
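
The `tokenize()` docstring above relies on running both forward (FMM) and backward (BMM) maximum matching and only falling back to `dfs_` where the two passes disagree. A toy sketch of such a disagreement, with a made-up vocabulary and sentence (not repo code; the real tokenizer uses the huqie trie and `sortTks_` scoring instead of a plain set):

```python
# Toy illustration of the FMM/BMM conflict that tokenize() resolves with dfs_().
VOCAB = {"研究", "研究生", "生命", "的", "起源"}

def fmm(text):
    """Forward maximum matching: greedily take the longest dictionary word from the left."""
    out, i = [], 0
    while i < len(text):
        for j in range(len(text), i, -1):
            if text[i:j] in VOCAB or j == i + 1:
                out.append(text[i:j])
                i = j
                break
    return out

def bmm(text):
    """Backward maximum matching: greedily take the longest dictionary word from the right."""
    out, j = [], len(text)
    while j > 0:
        for i in range(0, j):
            if text[i:j] in VOCAB or i == j - 1:
                out.append(text[i:j])
                j = i
                break
    return out[::-1]

s = "研究生命的起源"
print(fmm(s))  # ['研究生', '命', '的', '起源']
print(bmm(s))  # ['研究', '生命', '的', '起源']
# The two passes differ only on the span "研究生命"; that ambiguous span is what
# RagTokenizer hands to dfs_() so the highest-scoring segmentation wins.
```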
|
(File diff suppressed because it is too large.)
|
@ -35,22 +35,19 @@ def beAdoc(d, q, a, eng, row_num=-1):
|
|||
|
||||
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
||||
"""
|
||||
Excel and csv(txt) format files are supported.
|
||||
If the file is in excel format, there should be 2 column content and tags without header.
|
||||
And content column is ahead of tags column.
|
||||
And it's O.K if it has multiple sheets as long as the columns are rightly composed.
|
||||
Excel and csv(txt) format files are supported.
|
||||
If the file is in excel format, there should be 2 column content and tags without header.
|
||||
And content column is ahead of tags column.
|
||||
And it's O.K if it has multiple sheets as long as the columns are rightly composed.
|
||||
|
||||
If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate content and tags.
|
||||
If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate content and tags.
|
||||
|
||||
All the deformed lines will be ignored.
|
||||
Every pair will be treated as a chunk.
|
||||
All the deformed lines will be ignored.
|
||||
Every pair will be treated as a chunk.
|
||||
"""
|
||||
eng = lang.lower() == "english"
|
||||
res = []
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||
}
|
||||
doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
|
||||
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
excel_parser = Excel()
|
||||
|
@ -83,11 +80,9 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
|||
content = ""
|
||||
i += 1
|
||||
if len(res) % 999 == 0:
|
||||
callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
|
||||
callback(0.6, ("Extract TAG: {}".format(len(res)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
callback(0.6, ("Extract TAG: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
|
||||
return res
|
||||
|
||||
|
@ -110,40 +105,61 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
|||
res.append(beAdoc(deepcopy(doc), content, row[1], eng, i))
|
||||
content = ""
|
||||
if len(res) % 999 == 0:
|
||||
callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
|
||||
callback(0.6, ("Extract TAG : {}".format(len(res)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
callback(0.6, ("Extract TAG : {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
return res
|
||||
|
||||
raise NotImplementedError(
|
||||
"Excel, csv(txt) format files are supported.")
|
||||
raise NotImplementedError("Excel, csv(txt) format files are supported.")
|
||||
|
||||
|
||||
def label_question(question, kbs):
|
||||
"""
|
||||
标记问题的标签。
|
||||
|
||||
该函数通过给定的问题和知识库列表,对问题进行标签标记。它首先确定哪些知识库配置了标签,
|
||||
然后从缓存中获取这些标签,必要时从设置中检索标签。最后,使用这些标签对问题进行标记。
|
||||
|
||||
参数:
|
||||
question (str): 需要标记的问题。
|
||||
kbs (list): 知识库对象列表,用于标签标记。
|
||||
|
||||
返回:
|
||||
list: 与问题相关的标签列表。
|
||||
"""
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
|
||||
from api import settings
|
||||
|
||||
# 初始化标签和标签知识库ID列表
|
||||
tags = None
|
||||
tag_kb_ids = []
|
||||
|
||||
# 遍历知识库,收集所有标签知识库ID
|
||||
for kb in kbs:
|
||||
if kb.parser_config.get("tag_kb_ids"):
|
||||
tag_kb_ids.extend(kb.parser_config["tag_kb_ids"])
|
||||
|
||||
# 如果存在标签知识库ID,则进一步处理
|
||||
if tag_kb_ids:
|
||||
# 尝试从缓存中获取所有标签
|
||||
all_tags = get_tags_from_cache(tag_kb_ids)
|
||||
|
||||
# 如果缓存中没有标签,从设置中检索标签,并设置缓存
|
||||
if not all_tags:
|
||||
all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
|
||||
set_tags_to_cache(all_tags, tag_kb_ids)
|
||||
else:
|
||||
# 如果缓存中获取到标签,将其解析为JSON格式
|
||||
all_tags = json.loads(all_tags)
|
||||
|
||||
# 根据标签知识库ID获取对应的标签知识库
|
||||
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
|
||||
tags = settings.retrievaler.tag_query(question,
|
||||
list(set([kb.tenant_id for kb in tag_kbs])),
|
||||
tag_kb_ids,
|
||||
all_tags,
|
||||
kb.parser_config.get("topn_tags", 3)
|
||||
)
|
||||
|
||||
# 使用设置中的检索器对问题进行标签标记
|
||||
tags = settings.retrievaler.tag_query(question, list(set([kb.tenant_id for kb in tag_kbs])), tag_kb_ids, all_tags, kb.parser_config.get("topn_tags", 3))
|
||||
|
||||
# 返回标记的标签
|
||||
return tags
|
||||
|
||||
|
||||
|
@ -152,4 +168,5 @@ if __name__ == "__main__":
|
|||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
|
||||
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
|
|
rag/nlp/query.py (187 changed lines)
|
@ -53,6 +53,16 @@ class FulltextQueryer:
|
|||
|
||||
@staticmethod
|
||||
def rmWWW(txt):
|
||||
"""
|
||||
移除文本中的WWW(WHAT、WHO、WHERE等疑问词)。
|
||||
|
||||
本函数通过一系列正则表达式模式来识别并替换文本中的疑问词,以简化文本或为后续处理做准备。
|
||||
参数:
|
||||
- txt: 待处理的文本字符串。
|
||||
|
||||
返回:
|
||||
- 处理后的文本字符串,如果所有疑问词都被移除且文本为空,则返回原始文本。
|
||||
"""
|
||||
patts = [
|
||||
(
|
||||
r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
||||
|
@ -61,7 +71,8 @@ class FulltextQueryer:
|
|||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||
(
|
||||
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
|
||||
" ")
|
||||
" ",
|
||||
),
|
||||
]
|
||||
otxt = txt
|
||||
for r, p in patts:
|
||||
|
@ -70,28 +81,53 @@ class FulltextQueryer:
|
|||
txt = otxt
|
||||
return txt
|
||||
|
||||
@staticmethod
|
||||
def add_space_between_eng_zh(txt):
|
||||
"""
|
||||
在英文和中文之间添加空格。
|
||||
|
||||
该函数通过正则表达式匹配文本中英文和中文相邻的情况,并在它们之间插入空格。
|
||||
这样做可以改善文本的可读性,特别是在混合使用英文和中文时。
|
||||
|
||||
参数:
|
||||
txt (str): 需要处理的文本字符串。
|
||||
|
||||
返回:
|
||||
str: 处理后的文本字符串,其中英文和中文之间添加了空格。
|
||||
"""
|
||||
# (ENG/ENG+NUM) + ZH
|
||||
txt = re.sub(r"([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)", r"\1 \2", txt)
|
||||
# ENG + ZH
|
||||
txt = re.sub(r"([A-Za-z])([\u4e00-\u9fa5]+)", r"\1 \2", txt)
|
||||
# ZH + (ENG/ENG+NUM)
|
||||
txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)", r"\1 \2", txt)
|
||||
txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z])", r"\1 \2", txt)
|
||||
return txt
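
As a quick sanity check of the four substitutions in `add_space_between_eng_zh`, here is a standalone replica (illustration only; the sample query is made up):

```python
import re

def add_space_between_eng_zh(txt: str) -> str:
    # Same four substitutions as the method above.
    txt = re.sub(r"([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)", r"\1 \2", txt)  # ENG+NUM then ZH
    txt = re.sub(r"([A-Za-z])([\u4e00-\u9fa5]+)", r"\1 \2", txt)         # ENG then ZH
    txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)", r"\1 \2", txt)  # ZH then ENG+NUM
    txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z])", r"\1 \2", txt)         # ZH then ENG
    return txt

print(add_space_between_eng_zh("用GPT4做RAG召回"))  # -> "用 GPT4 做 RAG 召回"
```

Separating the scripts this way lets the downstream tokenizer and term-weighting treat "GPT4" and "RAG" as standalone keywords instead of gluing them to the neighbouring Chinese characters.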
|
||||
|
||||
def question(self, txt, tbl="qa", min_match: float = 0.6):
|
||||
"""
|
||||
处理用户问题并生成全文检索表达式
|
||||
|
||||
根据输入的文本生成查询表达式,用于在数据库中匹配相关问题。
|
||||
|
||||
参数:
|
||||
txt: 原始问题文本
|
||||
tbl: 查询表名(默认"qa")
|
||||
min_match: 最小匹配阈值(默认0.6)
|
||||
|
||||
- txt (str): 输入的文本。
|
||||
- tbl (str): 数据表名,默认为"qa"。
|
||||
- min_match (float): 最小匹配度,默认为0.6。
|
||||
|
||||
返回:
|
||||
MatchTextExpr: 全文检索表达式对象
|
||||
list: 提取的关键词列表
|
||||
- MatchTextExpr: 生成的查询表达式对象。
|
||||
- keywords (list): 提取的关键词列表。
|
||||
"""
|
||||
# 1. 文本预处理:去除特殊字符、繁体转简体、全角转半角、转小写
|
||||
txt = FulltextQueryer.add_space_between_eng_zh(txt) # 在英文和中文之间添加空格
|
||||
# 使用正则表达式替换特殊字符为单个空格,并将文本转换为简体中文和小写
|
||||
txt = re.sub(
|
||||
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
||||
" ",
|
||||
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
||||
).strip()
|
||||
txt = FulltextQueryer.rmWWW(txt) # 去除停用词
|
||||
otxt = txt
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
|
||||
# 2. 非中文文本处理
|
||||
# 如果文本不是中文,则进行英文处理
|
||||
if not self.isChinese(txt):
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
tks = rag_tokenizer.tokenize(txt).split()
|
||||
|
@ -106,11 +142,10 @@ class FulltextQueryer:
|
|||
syn = self.syn.lookup(tk)
|
||||
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
|
||||
keywords.extend(syn)
|
||||
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
|
||||
syn = ['"{}"^{:.4f}'.format(s, w / 4.0) for s in syn if s.strip()]
|
||||
syns.append(" ".join(syn))
|
||||
|
||||
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
|
||||
tk and not re.match(r"[.^+\(\)-]", tk)]
|
||||
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if tk and not re.match(r"[.^+\(\)-]", tk)]
|
||||
for i in range(1, len(tks_w)):
|
||||
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
|
||||
if not left or not right:
|
||||
|
@ -126,48 +161,53 @@ class FulltextQueryer:
|
|||
if not q:
|
||||
q.append(txt)
|
||||
query = " ".join(q)
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100
|
||||
), keywords
|
||||
return MatchTextExpr(self.query_fields, query, 100), keywords
|
||||
|
||||
def need_fine_grained_tokenize(tk):
|
||||
"""
|
||||
判断是否需要细粒度分词
|
||||
判断是否需要对词进行细粒度分词。
|
||||
|
||||
参数:
|
||||
tk: 待判断的词条
|
||||
- tk (str): 待判断的词。
|
||||
|
||||
返回:
|
||||
bool: True表示需要细粒度分词
|
||||
- bool: 是否需要进行细粒度分词。
|
||||
"""
|
||||
# 长度小于3的词不处理
|
||||
if len(tk) < 3:
|
||||
return False
|
||||
# 匹配特定模式的词不处理(如数字、字母、符号组合)
|
||||
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
|
||||
return False
|
||||
return True
|
||||
|
||||
txt = FulltextQueryer.rmWWW(txt) # 二次去除停用词
|
||||
qs, keywords = [], [] # 初始化查询表达式和关键词列表
|
||||
# 3. 中文文本处理(最多处理256个词)
|
||||
for tt in self.tw.split(txt)[:256]: # .split():
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
qs, keywords = [], []
|
||||
# 遍历文本分割后的前256个片段(防止处理过长文本)
|
||||
for tt in self.tw.split(txt)[:256]: # 注:这个split似乎是对英文设计,中文不起作用
|
||||
if not tt:
|
||||
continue
|
||||
# 3.1 基础关键词收集
|
||||
# 将当前片段加入关键词列表
|
||||
keywords.append(tt)
|
||||
twts = self.tw.weights([tt]) # 获取词权重
|
||||
syns = self.syn.lookup(tt) # 查询同义词
|
||||
# 3.2 同义词扩展(最多扩展到32个关键词)
|
||||
# 获取当前片段的权重
|
||||
twts = self.tw.weights([tt])
|
||||
# 查找同义词
|
||||
syns = self.syn.lookup(tt)
|
||||
# 如果有同义词且关键词数量未超过32,将同义词加入关键词列表
|
||||
if syns and len(keywords) < 32:
|
||||
keywords.extend(syns)
|
||||
# 调试日志:输出权重信息
|
||||
logging.debug(json.dumps(twts, ensure_ascii=False))
|
||||
# 初始化查询条件列表
|
||||
tms = []
|
||||
# 3.3 处理每个词及其权重
|
||||
# 按权重降序排序处理每个token
|
||||
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
||||
# 3.3.1 细粒度分词处理
|
||||
sm = (
|
||||
rag_tokenizer.fine_grained_tokenize(tk).split()
|
||||
if need_fine_grained_tokenize(tk)
|
||||
else []
|
||||
)
|
||||
# 3.3.2 清洗分词结果
|
||||
# 如果需要细粒度分词,则进行分词处理
|
||||
sm = rag_tokenizer.fine_grained_tokenize(tk).split() if need_fine_grained_tokenize(tk) else []
|
||||
# 对每个分词结果进行清洗:
|
||||
# 1. 去除标点符号和特殊字符
|
||||
# 2. 使用subSpecialChar进一步处理
|
||||
# 3. 过滤掉长度<=1的词
|
||||
sm = [
|
||||
re.sub(
|
||||
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
||||
|
@ -178,59 +218,65 @@ class FulltextQueryer:
|
|||
]
|
||||
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
# 3.3.3 收集关键词(不超过32个)
|
||||
|
||||
# 如果关键词数量未达上限,添加处理后的token和分词结果
|
||||
if len(keywords) < 32:
|
||||
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
||||
keywords.extend(sm)
|
||||
|
||||
# 3.3.4 同义词处理
|
||||
keywords.append(re.sub(r"[ \\\"']+", "", tk)) # 去除转义字符
|
||||
keywords.extend(sm) # 添加分词结果
|
||||
# 获取当前token的同义词并进行处理
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
# 添加有效同义词到关键词列表
|
||||
if len(keywords) < 32:
|
||||
keywords.extend([s for s in tk_syns if s])
|
||||
# 对同义词进行分词处理,并为包含空格的同义词添加引号
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
# 关键词数量限制
|
||||
tk_syns = [f'"{s}"' if s.find(" ") > 0 else s for s in tk_syns]
|
||||
|
||||
# 关键词数量达到上限则停止处理
|
||||
if len(keywords) >= 32:
|
||||
break
|
||||
|
||||
# 3.3.5 构建查询表达式
|
||||
|
||||
# 处理当前token用于构建查询条件:
|
||||
# 1. 特殊字符处理
|
||||
# 2. 为包含空格的token添加引号
|
||||
# 3. 如果有同义词,构建OR条件并降低权重
|
||||
# 4. 如果有分词结果,添加OR条件
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk # 处理短语查询
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns) # 添加同义词查询
|
||||
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
|
||||
if sm:
|
||||
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm)) # 添加细粒度分词查询
|
||||
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
|
||||
if tk.strip():
|
||||
tms.append((tk, w)) # 保存带权重的查询表达式
|
||||
|
||||
# 3.4 合并当前词的查询表达式
|
||||
tms.append((tk, w))
|
||||
|
||||
# 将处理后的查询条件按权重组合成字符串
|
||||
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
||||
|
||||
# 3.5 添加相邻词组合查询(提升短语匹配权重)
|
||||
# 如果有多个权重项,添加短语搜索条件(提高相邻词匹配的权重)
|
||||
if len(twts) > 1:
|
||||
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
|
||||
|
||||
# 3.6 处理同义词查询表达式
|
||||
syns = " OR ".join(
|
||||
[
|
||||
'"%s"'
|
||||
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
|
||||
for s in syns
|
||||
]
|
||||
)
|
||||
# 处理同义词的查询条件
|
||||
syns = " OR ".join(['"%s"' % rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s)) for s in syns])
|
||||
# 组合主查询条件和同义词条件
|
||||
if syns and tms:
|
||||
tms = f"({tms})^5 OR ({syns})^0.7"
|
||||
# 将最终查询条件加入列表
|
||||
qs.append(tms)
|
||||
|
||||
qs.append(tms) # 添加到最终查询列表
|
||||
|
||||
# 4. 生成最终查询表达式
|
||||
if qs:
|
||||
# 处理所有查询条件
|
||||
if qs:
|
||||
# 组合所有查询条件为OR关系
|
||||
query = " OR ".join([f"({t})" for t in qs if t])
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100, {"minimum_should_match": min_match}
|
||||
), keywords
|
||||
# 如果查询条件为空,使用原始文本
|
||||
if not query:
|
||||
query = otxt
|
||||
# 返回匹配文本表达式和关键词
|
||||
return MatchTextExpr(self.query_fields, query, 100, {"minimum_should_match": min_match}), keywords
|
||||
# 如果没有生成查询条件,只返回关键词
|
||||
return None, keywords
|
||||
|
||||
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
|
||||
|
@ -282,7 +328,7 @@ class FulltextQueryer:
|
|||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
tk_syns = [f'"{s}"' if s.find(" ") > 0 else s for s in tk_syns]
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
|
@ -291,5 +337,4 @@ class FulltextQueryer:
|
|||
if tk:
|
||||
keywords.append(f"{tk}^{w}")
|
||||
|
||||
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
|
||||
{"minimum_should_match": min(3, len(keywords) / 10)})
|
||||
return MatchTextExpr(self.query_fields, " ".join(keywords), 100, {"minimum_should_match": min(3, len(keywords) / 10)})
|
||||
|
|
|
@ -22,9 +22,12 @@ import os
|
|||
import re
|
||||
import string
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from hanziconv import HanziConv
|
||||
from nltk import word_tokenize
|
||||
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
|
@ -38,7 +41,7 @@ class RagTokenizer:
|
|||
def loadDict_(self, fnm):
|
||||
logging.info(f"[HUQIE]:Build trie from {fnm}")
|
||||
try:
|
||||
of = open(fnm, "r", encoding='utf-8')
|
||||
of = open(fnm, "r", encoding="utf-8")
|
||||
while True:
|
||||
line = of.readline()
|
||||
if not line:
|
||||
|
@ -46,7 +49,7 @@ class RagTokenizer:
|
|||
line = re.sub(r"[\r\n]+", "", line)
|
||||
line = re.split(r"[ \t]", line)
|
||||
k = self.key_(line[0])
|
||||
F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
|
||||
F = int(math.log(float(line[1]) / self.DENOMINATOR) + 0.5)
|
||||
if k not in self.trie_ or self.trie_[k][0] < F:
|
||||
self.trie_[self.key_(line[0])] = (F, line[2])
|
||||
self.trie_[self.rkey_(line[0])] = 1
|
||||
|
@ -106,8 +109,8 @@ class RagTokenizer:
|
|||
if inside_code == 0x3000:
|
||||
inside_code = 0x0020
|
||||
else:
|
||||
inside_code -= 0xfee0
|
||||
if inside_code < 0x0020 or inside_code > 0x7e: # After the conversion, if it's not a half-width character, return the original character.
|
||||
inside_code -= 0xFEE0
|
||||
if inside_code < 0x0020 or inside_code > 0x7E: # After the conversion, if it's not a half-width character, return the original character.
|
||||
rstring += uchar
|
||||
else:
|
||||
rstring += chr(inside_code)
|
||||
|
@ -126,13 +129,11 @@ class RagTokenizer:
|
|||
# pruning
|
||||
S = s + 1
|
||||
if s + 2 <= len(chars):
|
||||
t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
|
||||
self.key_(t2)):
|
||||
t1, t2 = "".join(chars[s : s + 1]), "".join(chars[s : s + 2])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
|
||||
S = s + 2
|
||||
if len(preTks) > 2 and len(
|
||||
preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
|
||||
t1 = preTks[-1][0] + "".join(chars[s:s + 1])
|
||||
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
|
||||
t1 = preTks[-1][0] + "".join(chars[s : s + 1])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)):
|
||||
S = s + 2
|
||||
|
||||
|
@ -149,18 +150,18 @@ class RagTokenizer:
|
|||
if k in self.trie_:
|
||||
pretks.append((t, self.trie_[k]))
|
||||
else:
|
||||
pretks.append((t, (-12, '')))
|
||||
pretks.append((t, (-12, "")))
|
||||
res = max(res, self.dfs_(chars, e, pretks, tkslist))
|
||||
|
||||
if res > s:
|
||||
return res
|
||||
|
||||
t = "".join(chars[s:s + 1])
|
||||
t = "".join(chars[s : s + 1])
|
||||
k = self.key_(t)
|
||||
if k in self.trie_:
|
||||
preTks.append((t, self.trie_[k]))
|
||||
else:
|
||||
preTks.append((t, (-12, '')))
|
||||
preTks.append((t, (-12, "")))
|
||||
|
||||
return self.dfs_(chars, s + 1, preTks, tkslist)
|
||||
|
||||
|
@ -183,7 +184,7 @@ class RagTokenizer:
|
|||
F += freq
|
||||
L += 0 if len(tk) < 2 else 1
|
||||
tks.append(tk)
|
||||
#F /= len(tks)
|
||||
# F /= len(tks)
|
||||
L /= len(tks)
|
||||
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
|
||||
return tks, B / len(tks) + L + F
|
||||
|
@ -219,8 +220,7 @@ class RagTokenizer:
|
|||
while s < len(line):
|
||||
e = s + 1
|
||||
t = line[s:e]
|
||||
while e < len(line) and self.trie_.has_keys_with_prefix(
|
||||
self.key_(t)):
|
||||
while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
|
||||
e += 1
|
||||
t = line[s:e]
|
||||
|
||||
|
@ -231,7 +231,7 @@ class RagTokenizer:
|
|||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, '')))
|
||||
res.append((t, (0, "")))
|
||||
|
||||
s = e
|
||||
|
||||
|
@ -254,7 +254,7 @@ class RagTokenizer:
|
|||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, '')))
|
||||
res.append((t, (0, "")))
|
||||
|
||||
s -= 1
|
||||
|
||||
|
@ -277,13 +277,13 @@ class RagTokenizer:
|
|||
if _zh == zh:
|
||||
e += 1
|
||||
continue
|
||||
txt_lang_pairs.append((a[s: e], zh))
|
||||
txt_lang_pairs.append((a[s:e], zh))
|
||||
s = e
|
||||
e = s + 1
|
||||
zh = _zh
|
||||
if s >= len(a):
|
||||
continue
|
||||
txt_lang_pairs.append((a[s: e], zh))
|
||||
txt_lang_pairs.append((a[s:e], zh))
|
||||
return txt_lang_pairs
|
||||
|
||||
def tokenize(self, line):
|
||||
|
@ -293,12 +293,11 @@ class RagTokenizer:
|
|||
|
||||
arr = self._split_by_lang(line)
|
||||
res = []
|
||||
for L,lang in arr:
|
||||
for L, lang in arr:
|
||||
if not lang:
|
||||
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
|
||||
continue
|
||||
if len(L) < 2 or re.match(
|
||||
r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
|
||||
if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
|
||||
res.append(L)
|
||||
continue
|
||||
|
||||
|
@ -314,7 +313,7 @@ class RagTokenizer:
|
|||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
same += 1
|
||||
if same > 0:
|
||||
res.append(" ".join(tks[j: j + same]))
|
||||
res.append(" ".join(tks[j : j + same]))
|
||||
_i = i + same
|
||||
_j = j + same
|
||||
j = _j + 1
|
||||
|
@ -341,7 +340,7 @@ class RagTokenizer:
|
|||
same = 1
|
||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
same += 1
|
||||
res.append(" ".join(tks[j: j + same]))
|
||||
res.append(" ".join(tks[j : j + same]))
|
||||
_i = i + same
|
||||
_j = j + same
|
||||
j = _j + 1
|
||||
|
@ -359,31 +358,64 @@ class RagTokenizer:
|
|||
return self.merge_(res)
|
||||
|
||||
def fine_grained_tokenize(self, tks):
|
||||
"""
|
||||
细粒度分词方法,根据文本特征(中英文比例、数字符号等)动态选择分词策略
|
||||
|
||||
参数:
|
||||
tks (str): 待分词的文本字符串
|
||||
|
||||
返回:
|
||||
str: 分词后的结果(用空格连接的词序列)
|
||||
|
||||
处理逻辑:
|
||||
1. 先按空格初步切分文本
|
||||
2. 根据中文占比决定是否启用细粒度分词
|
||||
3. 对特殊格式(短词、纯数字等)直接保留原样
|
||||
4. 对长词或复杂词使用DFS回溯算法寻找最优切分
|
||||
5. 对英文词进行额外校验和规范化处理
|
||||
"""
|
||||
# 初始切分:按空格分割输入文本
|
||||
tks = tks.split()
|
||||
# 计算中文词占比(判断是否主要包含中文内容)
|
||||
zh_num = len([1 for c in tks if c and is_chinese(c[0])])
|
||||
# 如果中文占比低于20%,则按简单规则处理(主要处理英文混合文本)
|
||||
if zh_num < len(tks) * 0.2:
|
||||
res = []
|
||||
for tk in tks:
|
||||
res.extend(tk.split("/"))
|
||||
return " ".join(res)
|
||||
|
||||
# 中文或复杂文本处理流程
|
||||
res = []
|
||||
for tk in tks:
|
||||
# 规则1:跳过短词(长度<3)或纯数字/符号组合(如"3.14")
|
||||
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
|
||||
res.append(tk)
|
||||
continue
|
||||
|
||||
# 初始化候选分词列表
|
||||
tkslist = []
|
||||
|
||||
# 规则2:超长词(长度>10)直接保留不切分
|
||||
if len(tk) > 10:
|
||||
tkslist.append(tk)
|
||||
else:
|
||||
# 使用DFS回溯算法寻找所有可能的分词组合
|
||||
self.dfs_(tk, 0, [], tkslist)
|
||||
|
||||
# 规则3:若无有效切分方案则保留原词
|
||||
if len(tkslist) < 2:
|
||||
res.append(tk)
|
||||
continue
|
||||
|
||||
# 从候选方案中选择最优切分(通过sortTks_排序)
|
||||
stk = self.sortTks_(tkslist)[1][0]
|
||||
|
||||
# 规则4:若切分结果与原词长度相同则视为无效切分
|
||||
if len(stk) == len(tk):
|
||||
stk = tk
|
||||
else:
|
||||
# 英文特殊处理:检查子词长度是否合法
|
||||
if re.match(r"[a-z\.-]+$", tk):
|
||||
for t in stk:
|
||||
if len(t) < 3:
|
||||
|
@ -393,29 +425,28 @@ class RagTokenizer:
|
|||
stk = " ".join(stk)
|
||||
else:
|
||||
stk = " ".join(stk)
|
||||
|
||||
# 中文词直接拼接结果
|
||||
res.append(stk)
|
||||
|
||||
return " ".join(self.english_normalize_(res))
|
||||
|
||||
|
||||
def is_chinese(s):
|
||||
if s >= u'\u4e00' and s <= u'\u9fa5':
|
||||
if s >= "\u4e00" and s <= "\u9fa5":
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_number(s):
|
||||
if s >= u'\u0030' and s <= u'\u0039':
|
||||
if s >= "\u0030" and s <= "\u0039":
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_alphabet(s):
|
||||
if (s >= u'\u0041' and s <= u'\u005a') or (
|
||||
s >= u'\u0061' and s <= u'\u007a'):
|
||||
if (s >= "\u0041" and s <= "\u005a") or (s >= "\u0061" and s <= "\u007a"):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
@ -424,8 +455,7 @@ def is_alphabet(s):
|
|||
def naiveQie(txt):
|
||||
tks = []
|
||||
for t in txt.split():
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
|
||||
) and re.match(r".*[a-zA-Z]$", t):
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t):
|
||||
tks.append(" ")
|
||||
tks.append(t)
|
||||
return tks
|
||||
|
@ -441,43 +471,41 @@ addUserDict = tokenizer.addUserDict
|
|||
tradi2simp = tokenizer._tradi2simp
|
||||
strQ2B = tokenizer._strQ2B
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
tknzr = RagTokenizer(debug=True)
|
||||
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
|
||||
tks = tknzr.tokenize("哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
|
||||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
print(tks)
|
||||
tks = tknzr.tokenize(
|
||||
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("虽然我不怎么玩")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
if len(sys.argv) < 2:
|
||||
sys.exit()
|
||||
tknzr.DEBUG = False
|
||||
tknzr.loadUserDict(sys.argv[1])
|
||||
of = open(sys.argv[2], "r")
|
||||
while True:
|
||||
line = of.readline()
|
||||
if not line:
|
||||
break
|
||||
logging.info(tknzr.tokenize(line))
|
||||
of.close()
|
||||
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。"
|
||||
)
|
||||
print(tks)
|
||||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
# tks = tknzr.tokenize("多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
|
||||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
# tks = tknzr.tokenize("实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
|
||||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
# tks = tknzr.tokenize("虽然我不怎么玩")
|
||||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
# tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
|
||||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
# tks = tknzr.tokenize("涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
|
||||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
# tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
|
||||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
# tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
|
||||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
# tks = tknzr.tokenize("数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
|
||||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
# if len(sys.argv) < 2:
|
||||
# sys.exit()
|
||||
# tknzr.DEBUG = False
|
||||
# tknzr.loadUserDict(sys.argv[1])
|
||||
# of = open(sys.argv[2], "r")
|
||||
# while True:
|
||||
# line = of.readline()
|
||||
# if not line:
|
||||
# break
|
||||
# logging.info(tknzr.tokenize(line))
|
||||
# of.close()
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#
|
||||
import logging
|
||||
import re
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||
|
@ -24,7 +25,8 @@ import numpy as np
|
|||
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
|
||||
|
||||
|
||||
def index_name(uid): return f"ragflow_{uid}"
|
||||
def index_name(uid):
|
||||
return f"ragflow_{uid}"
|
||||
|
||||
|
||||
class Dealer:
|
||||
|
@ -47,11 +49,10 @@ class Dealer:
|
|||
qv, _ = emb_mdl.encode_queries(txt)
|
||||
shape = np.array(qv).shape
|
||||
if len(shape) > 1:
|
||||
raise Exception(
|
||||
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
|
||||
raise Exception(f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
|
||||
embedding_data = [float(v) for v in qv]
|
||||
vector_column_name = f"q_{len(embedding_data)}_vec"
|
||||
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
|
||||
return MatchDenseExpr(vector_column_name, embedding_data, "float", "cosine", topk, {"similarity": similarity})
|
||||
|
||||
def get_filters(self, req):
|
||||
condition = dict()
|
||||
|
@ -64,15 +65,10 @@ class Dealer:
|
|||
condition[key] = req[key]
|
||||
return condition
|
||||
|
||||
def search(self, req, idx_names: str | list[str],
|
||||
kb_ids: list[str],
|
||||
emb_mdl=None,
|
||||
highlight=False,
|
||||
rank_feature: dict | None = None
|
||||
):
|
||||
def search(self, req, idx_names: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False, rank_feature: dict | None = None):
|
||||
"""
|
||||
执行混合检索(全文检索+向量检索)
|
||||
|
||||
|
||||
参数:
|
||||
req: 请求参数字典,包含:
|
||||
- page: 页码
|
||||
|
@ -86,7 +82,7 @@ class Dealer:
|
|||
    emb_mdl: embedding model used for vector retrieval
    highlight: whether to return highlighted content
    rank_feature: ranking-feature configuration

Returns:
    A SearchResult object containing:
        - total: total number of matches
|
||||
|
@ -106,20 +102,39 @@ class Dealer:
|
|||
topk = int(req.get("topk", 1024))
|
||||
ps = int(req.get("size", topk))
|
||||
offset, limit = pg * ps, ps
|
||||
|
||||
|
||||
# 3. 设置返回字段(默认包含文档名、内容等核心字段)
|
||||
src = req.get("fields",
|
||||
["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
|
||||
"doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd",
|
||||
"question_kwd", "question_tks",
|
||||
"available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD])
|
||||
kwds = set([])
|
||||
src = req.get(
|
||||
"fields",
|
||||
[
|
||||
"docnm_kwd",
|
||||
"content_ltks",
|
||||
"kb_id",
|
||||
"img_id",
|
||||
"title_tks",
|
||||
"important_kwd",
|
||||
"position_int",
|
||||
"doc_id",
|
||||
"page_num_int",
|
||||
"top_int",
|
||||
"create_timestamp_flt",
|
||||
"knowledge_graph_kwd",
|
||||
"question_kwd",
|
||||
"question_tks",
|
||||
"available_int",
|
||||
"content_with_weight",
|
||||
PAGERANK_FLD,
|
||||
TAG_FLD,
|
||||
],
|
||||
)
|
||||
kwds = set([]) # 初始化关键词集合
|
||||
|
||||
# 4. 处理查询问题
|
||||
qst = req.get("question", "")
|
||||
q_vec = []
|
||||
qst = req.get("question", "") # 获取查询问题文本
|
||||
print(f"收到前端问题:{qst}")
|
||||
q_vec = [] # 初始化查询向量(如需向量检索)
|
||||
if not qst:
|
||||
# 4.1 无查询文本时的处理(按文档排序)
|
||||
# 4.1 若查询文本为空,执行默认排序检索(通常用于无搜索条件浏览)(注:前端测试检索时会禁止空文本的提交)
|
||||
if req.get("sort"):
|
||||
orderBy.asc("page_num_int")
|
||||
orderBy.asc("top_int")
|
||||
|
@ -128,44 +143,58 @@ class Dealer:
|
|||
total = self.dataStore.getTotal(res)
|
||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||
else:
|
||||
# 4.2 有查询文本时的处理
|
||||
highlightFields = ["content_ltks", "title_tks"] if highlight else []
|
||||
|
||||
# 4.2 若存在查询文本,进入全文/混合检索流程
|
||||
highlightFields = ["content_ltks", "title_tks"] if highlight else [] # highlight当前会一直为False,不起作用
|
||||
# 4.2.1 生成全文检索表达式和关键词
|
||||
matchText, keywords = self.qryr.question(qst, min_match=0.3)
|
||||
print(f"matchText.matching_text: {matchText.matching_text}")
|
||||
print(f"keywords: {keywords}\n")
|
||||
if emb_mdl is None:
|
||||
# 4.2.2 纯全文检索模式
|
||||
# 4.2.2 纯全文检索模式 (未提供向量模型,正常情况不会进入)
|
||||
matchExprs = [matchText]
|
||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||
idx_names, kb_ids, rank_feature=rank_feature)
|
||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.getTotal(res)
|
||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||
else:
|
||||
# 4.2.3 混合检索模式(全文+向量)
|
||||
# 生成查询向量
|
||||
# 4.2.3 混合检索模式(全文+向量)
|
||||
# 生成查询向量
|
||||
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
|
||||
q_vec = matchDense.embedding_data
|
||||
# 在返回字段中加入查询向量字段
|
||||
src.append(f"q_{len(q_vec)}_vec")
|
||||
# 设置混合检索权重(全文5% + 向量95%)
|
||||
# 创建融合表达式:设置向量匹配为95%,全文为5%(可以调整权重)
|
||||
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
|
||||
# 构建混合查询表达式
|
||||
matchExprs = [matchText, matchDense, fusionExpr]
|
||||
|
||||
# 执行混合检索
|
||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||
idx_names, kb_ids, rank_feature=rank_feature)
|
||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.getTotal(res)
|
||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||
|
||||
# If result is empty, try again with lower min_match
|
||||
if total == 0:
|
||||
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
||||
filters.pop("doc_ids", None)
|
||||
matchDense.extra_options["similarity"] = 0.17
|
||||
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
||||
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.getTotal(res)
|
||||
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
|
||||
print(f"共查询到: {total} 条信息")
|
||||
# print(f"查询信息结果: {res}\n")
|
||||
|
||||
# 若未找到结果,则尝试降低匹配门槛后重试
|
||||
if total == 0:
|
||||
if filters.get("doc_id"):
|
||||
# 有特定文档ID时执行无条件查询
|
||||
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
||||
total = self.dataStore.getTotal(res)
|
||||
print(f"针对选中文档,共查询到: {total} 条信息")
|
||||
# print(f"查询信息结果: {res}\n")
|
||||
else:
|
||||
# 否则调整全文和向量匹配参数再次搜索
|
||||
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
||||
filters.pop("doc_id", None)
|
||||
matchDense.extra_options["similarity"] = 0.17
|
||||
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.getTotal(res)
|
||||
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
|
||||
print(f"再次查询,共查询到: {total} 条信息")
|
||||
# print(f"查询信息结果: {res}\n")
|
||||
|
||||
# 4.3 处理关键词(对关键词进行更细粒度的切词)
|
||||
for k in keywords:
|
||||
kwds.add(k)
|
||||
for kk in rag_tokenizer.fine_grained_tokenize(k).split():
|
||||
|
@ -175,27 +204,23 @@ class Dealer:
|
|||
continue
|
||||
kwds.add(kk)
|
||||
|
||||
# 5. 提取检索结果中的ID、字段、聚合和高亮信息
|
||||
logging.debug(f"TOTAL: {total}")
|
||||
ids = self.dataStore.getChunkIds(res)
|
||||
keywords = list(kwds)
|
||||
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
|
||||
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
|
||||
return self.SearchResult(
|
||||
total=total,
|
||||
ids=ids,
|
||||
query_vector=q_vec,
|
||||
aggregation=aggs,
|
||||
highlight=highlight,
|
||||
field=self.dataStore.getFields(res, src),
|
||||
keywords=keywords
|
||||
)
|
||||
ids = self.dataStore.getChunkIds(res) # 提取匹配chunk的ID
|
||||
keywords = list(kwds) # 转为列表格式返回
|
||||
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight") # 获取高亮内容
|
||||
aggs = self.dataStore.getAggregation(res, "docnm_kwd") # 执行基于文档名的聚合分析
|
||||
print(f"ids:{ids}")
|
||||
print(f"keywords:{keywords}")
|
||||
print(f"highlight:{highlight}")
|
||||
print(f"aggs:{aggs}")
|
||||
return self.SearchResult(total=total, ids=ids, query_vector=q_vec, aggregation=aggs, highlight=highlight, field=self.dataStore.getFields(res, src), keywords=keywords)
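A hedged caller sketch for search(), using the signature shown above; `dealer`, `tenant_id`, `kb_ids` and `emb_mdl` are assumed to exist in the calling context:

```python
# Hypothetical caller; the request keys mirror the handling in search() above.
req = {
    "question": "人民币 外汇 投资",
    "page": 1,
    "size": 30,
    "topk": 1024,
    "similarity": 0.1,
}
sres = dealer.search(req, index_name(tenant_id), kb_ids, emb_mdl=emb_mdl, highlight=False)
print(sres.total, sres.keywords)  # total hits and the fine-grained keywords used for matching
```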
|
||||
|
||||
@staticmethod
|
||||
def trans2floats(txt):
|
||||
return [float(t) for t in txt.split("\t")]
|
||||
|
||||
def insert_citations(self, answer, chunks, chunk_v,
|
||||
embd_mdl, tkweight=0.1, vtweight=0.9):
|
||||
def insert_citations(self, answer, chunks, chunk_v, embd_mdl, tkweight=0.1, vtweight=0.9):
|
||||
assert len(chunks) == len(chunk_v)
|
||||
if not chunks:
|
||||
return answer, set([])
|
||||
|
@ -211,12 +236,9 @@ class Dealer:
|
|||
i += 1
|
||||
if i < len(pieces):
|
||||
i += 1
|
||||
pieces_.append("".join(pieces[st: i]) + "\n")
|
||||
pieces_.append("".join(pieces[st:i]) + "\n")
|
||||
else:
|
||||
pieces_.extend(
|
||||
re.split(
|
||||
r"([^\|][;。?!!\n]|[a-z][.?;!][ \n])",
|
||||
pieces[i]))
|
||||
pieces_.extend(re.split(r"([^\|][;。?!!\n]|[a-z][.?;!][ \n])", pieces[i]))
|
||||
i += 1
|
||||
pieces = pieces_
|
||||
else:
|
||||
|
@ -239,30 +261,22 @@ class Dealer:
|
|||
ans_v, _ = embd_mdl.encode(pieces_)
|
||||
for i in range(len(chunk_v)):
|
||||
if len(ans_v[0]) != len(chunk_v[i]):
|
||||
chunk_v[i] = [0.0]*len(ans_v[0])
|
||||
chunk_v[i] = [0.0] * len(ans_v[0])
|
||||
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
|
||||
|
||||
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
|
||||
len(ans_v[0]), len(chunk_v[0]))
|
||||
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[0]))
|
||||
|
||||
chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
|
||||
for ck in chunks]
|
||||
chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split() for ck in chunks]
|
||||
cites = {}
|
||||
thr = 0.63
|
||||
while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
|
||||
for i, a in enumerate(pieces_):
|
||||
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
|
||||
chunk_v,
|
||||
rag_tokenizer.tokenize(
|
||||
self.qryr.rmWWW(pieces_[i])).split(),
|
||||
chunks_tks,
|
||||
tkweight, vtweight)
|
||||
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], chunk_v, rag_tokenizer.tokenize(self.qryr.rmWWW(pieces_[i])).split(), chunks_tks, tkweight, vtweight)
|
||||
mx = np.max(sim) * 0.99
|
||||
logging.debug("{} SIM: {}".format(pieces_[i], mx))
|
||||
if mx < thr:
|
||||
continue
|
||||
cites[idx[i]] = list(
|
||||
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
|
||||
cites[idx[i]] = list(set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
|
||||
thr *= 0.8
|
||||
|
||||
res = ""
|
||||
|
@ -294,7 +308,7 @@ class Dealer:
|
|||
if not query_rfea:
|
||||
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
|
||||
|
||||
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
|
||||
q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
|
||||
for i in search_res.ids:
|
||||
nor, denor = 0, 0
|
||||
for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
|
||||
|
@ -304,13 +318,10 @@ class Dealer:
|
|||
if denor == 0:
|
||||
rank_fea.append(0)
|
||||
else:
|
||||
rank_fea.append(nor/np.sqrt(denor)/q_denor)
|
||||
return np.array(rank_fea)*10. + pageranks
|
||||
rank_fea.append(nor / np.sqrt(denor) / q_denor)
|
||||
return np.array(rank_fea) * 10.0 + pageranks
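Read loosely (the `nor`/`denor` accumulation sits outside this hunk), the score above behaves like a cosine similarity between the query's tag weights and each chunk's tag weights, scaled by 10 and offset by the chunk's PageRank. A standalone sketch of that reading, for illustration only:

```python
# Illustrative only: cosine between two sparse tag-weight dicts, as suggested by the
# q_denor / sqrt(denor) normalisation above.
import numpy as np

def tag_cosine(query_tags: dict, chunk_tags: dict) -> float:
    keys = sorted(set(query_tags) | set(chunk_tags))
    q = np.array([query_tags.get(k, 0.0) for k in keys])
    c = np.array([chunk_tags.get(k, 0.0) for k in keys])
    denom = np.linalg.norm(q) * np.linalg.norm(c)
    return float(q @ c / denom) if denom else 0.0

print(tag_cosine({"金融": 3, "贷款": 1}, {"金融": 2, "房产": 1}))  # ≈ 0.85
```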
|
||||
|
||||
def rerank(self, sres, query, tkweight=0.3,
|
||||
vtweight=0.7, cfield="content_ltks",
|
||||
rank_feature: dict | None = None
|
||||
):
|
||||
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", rank_feature: dict | None = None):
|
||||
_, keywords = self.qryr.question(query)
|
||||
vector_size = len(sres.query_vector)
|
||||
vector_column = f"q_{vector_size}_vec"
|
||||
|
@ -339,16 +350,11 @@ class Dealer:
|
|||
## For rank feature(tag_fea) scores.
|
||||
rank_fea = self._rank_feature_scores(rank_feature, sres)
|
||||
|
||||
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
|
||||
ins_embd,
|
||||
keywords,
|
||||
ins_tw, tkweight, vtweight)
|
||||
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector, ins_embd, keywords, ins_tw, tkweight, vtweight)
|
||||
|
||||
return sim + rank_fea, tksim, vtsim
|
||||
|
||||
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
|
||||
vtweight=0.7, cfield="content_ltks",
|
||||
rank_feature: dict | None = None):
|
||||
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", rank_feature: dict | None = None):
|
||||
_, keywords = self.qryr.question(query)
|
||||
|
||||
for i in sres.ids:
|
||||
|
@ -367,21 +373,31 @@ class Dealer:
|
|||
## For rank feature(tag_fea) scores.
|
||||
rank_fea = self._rank_feature_scores(rank_feature, sres)
|
||||
|
||||
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
|
||||
return tkweight * (np.array(tksim) + rank_fea) + vtweight * vtsim, tksim, vtsim
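For intuition, a worked example of the weighted mix in the return statement above, with the default tkweight=0.3 / vtweight=0.7 and illustrative similarity values:

```python
import numpy as np

tksim = np.array([0.50, 0.10])     # token-overlap similarity per chunk
rank_fea = np.array([0.00, 0.20])  # tag / pagerank feature score per chunk
vtsim = np.array([0.80, 0.90])     # vector cosine similarity per chunk

sim = 0.3 * (tksim + rank_fea) + 0.7 * vtsim
print(sim)  # [0.71 0.72] -- the second chunk overtakes thanks to its rank feature
```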
|
||||
|
||||
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
||||
return self.qryr.hybrid_similarity(ans_embd,
|
||||
ins_embd,
|
||||
rag_tokenizer.tokenize(ans).split(),
|
||||
rag_tokenizer.tokenize(inst).split())
|
||||
return self.qryr.hybrid_similarity(ans_embd, ins_embd, rag_tokenizer.tokenize(ans).split(), rag_tokenizer.tokenize(inst).split())
|
||||
|
||||
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
|
||||
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
|
||||
rerank_mdl=None, highlight=False,
|
||||
rank_feature: dict | None = {PAGERANK_FLD: 10}):
|
||||
def retrieval(
|
||||
self,
|
||||
question,
|
||||
embd_mdl,
|
||||
tenant_ids,
|
||||
kb_ids,
|
||||
page,
|
||||
page_size,
|
||||
similarity_threshold=0.2,
|
||||
vector_similarity_weight=0.3,
|
||||
top=1024,
|
||||
doc_ids=None,
|
||||
aggs=True,
|
||||
rerank_mdl=None,
|
||||
highlight=False,
|
||||
rank_feature: dict | None = {PAGERANK_FLD: 10},
|
||||
):
|
||||
"""
|
||||
执行检索操作,根据问题查询相关文档片段
|
||||
|
||||
|
||||
参数说明:
|
||||
- question: 用户输入的查询问题
|
||||
- embd_mdl: 嵌入模型,用于将文本转换为向量
|
||||
|
@ -397,68 +413,58 @@ class Dealer:
|
|||
    - rerank_mdl: rerank model
    - highlight: whether to highlight matched content
    - rank_feature: ranking features, e.g. PageRank values

Returns:
    A dict with the retrieval results: total count, matched chunks, and per-document aggregation.
"""
|
||||
# 初始化结果字典
|
||||
# 初始化结果字典
|
||||
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
||||
if not question:
|
||||
return ranks
|
||||
# 设置重排序页面限制
|
||||
RERANK_PAGE_LIMIT = 3
|
||||
RERANK_LIMIT = 64
|
||||
RERANK_LIMIT = int(RERANK_LIMIT // page_size + ((RERANK_LIMIT % page_size) / (page_size * 1.0) + 0.5)) * page_size if page_size > 1 else 1
|
||||
if RERANK_LIMIT < 1:
|
||||
RERANK_LIMIT = 1
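A quick worked example of the snapping above (the page_size value is illustrative): the 64-candidate rerank window is rounded to a multiple of the page size.

```python
page_size = 30
limit = int(64 // page_size + ((64 % page_size) / (page_size * 1.0) + 0.5)) * page_size
# 64 // 30 = 2, (64 % 30) / 30 + 0.5 ≈ 0.63, int(2 + 0.63) = 2  ->  limit = 60
print(limit)  # 60
```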
|
||||
# 构建检索请求参数
|
||||
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size * RERANK_PAGE_LIMIT, 128),
|
||||
"question": question, "vector": True, "topk": top,
|
||||
"similarity": similarity_threshold,
|
||||
"available_int": 1}
|
||||
|
||||
# 如果页码超过重排序限制,直接请求指定页的数据
|
||||
if page > RERANK_PAGE_LIMIT:
|
||||
req["page"] = page
|
||||
req["size"] = page_size
|
||||
req = {
|
||||
"kb_ids": kb_ids,
|
||||
"doc_ids": doc_ids,
|
||||
"page": math.ceil(page_size * page / RERANK_LIMIT),
|
||||
"size": RERANK_LIMIT,
|
||||
"question": question,
|
||||
"vector": True,
|
||||
"topk": top,
|
||||
"similarity": similarity_threshold,
|
||||
"available_int": 1,
|
||||
}
|
||||
|
||||
# 处理租户ID格式
|
||||
if isinstance(tenant_ids, str):
|
||||
tenant_ids = tenant_ids.split(",")
|
||||
|
||||
|
||||
# 执行搜索操作
|
||||
sres = self.search(req, [index_name(tid) for tid in tenant_ids],
|
||||
kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
|
||||
ranks["total"] = sres.total
|
||||
|
||||
# 根据页码决定是否需要重排序
|
||||
if page <= RERANK_PAGE_LIMIT:
|
||||
# 前几页需要重排序以提高结果质量
|
||||
if rerank_mdl and sres.total > 0:
|
||||
# 使用重排序模型进行重排序
|
||||
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
|
||||
sres, question, 1 - vector_similarity_weight,
|
||||
vector_similarity_weight,
|
||||
rank_feature=rank_feature)
|
||||
else:
|
||||
# 使用默认方法进行重排序
|
||||
sim, tsim, vsim = self.rerank(
|
||||
sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
|
||||
rank_feature=rank_feature)
|
||||
# 根据相似度降序排序,并选择当前页的结果
|
||||
idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size]
|
||||
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
|
||||
|
||||
if rerank_mdl and sres.total > 0:
|
||||
sim, tsim, vsim = self.rerank_by_model(rerank_mdl, sres, question, 1 - vector_similarity_weight, vector_similarity_weight, rank_feature=rank_feature)
|
||||
else:
|
||||
# 后续页面不需要重排序,直接使用搜索结果
|
||||
sim = tsim = vsim = [1] * len(sres.ids)
|
||||
idx = list(range(len(sres.ids)))
|
||||
|
||||
# 获取向量维度和列名
|
||||
sim, tsim, vsim = self.rerank(sres, question, 1 - vector_similarity_weight, vector_similarity_weight, rank_feature=rank_feature)
|
||||
# Already paginated in search function
|
||||
idx = np.argsort(sim * -1)[(page - 1) * page_size : page * page_size]
|
||||
|
||||
dim = len(sres.query_vector)
|
||||
vector_column = f"q_{dim}_vec"
|
||||
zero_vector = [0.0] * dim
|
||||
|
||||
# 处理每个检索结果
|
||||
if doc_ids:
|
||||
similarity_threshold = 0
|
||||
page_size = 30
|
||||
sim_np = np.array(sim)
|
||||
filtered_count = (sim_np >= similarity_threshold).sum()
|
||||
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
|
||||
for i in idx:
|
||||
# 过滤低于阈值的结果
|
||||
if sim[i] < similarity_threshold:
|
||||
break
|
||||
# 控制返回结果数量
|
||||
if len(ranks["chunks"]) >= page_size:
|
||||
if aggs:
|
||||
continue
|
||||
|
@ -468,7 +474,6 @@ class Dealer:
|
|||
dnm = chunk.get("docnm_kwd", "")
|
||||
did = chunk.get("doc_id", "")
|
||||
position_int = chunk.get("position_int", [])
|
||||
# 构建结果字典
|
||||
d = {
|
||||
"chunk_id": id,
|
||||
"content_ltks": chunk["content_ltks"],
|
||||
|
@ -483,9 +488,8 @@ class Dealer:
|
|||
"term_similarity": tsim[i],
|
||||
"vector": chunk.get(vector_column, zero_vector),
|
||||
"positions": position_int,
|
||||
"doc_type_kwd": chunk.get("doc_type_kwd", ""),
|
||||
}
|
||||
|
||||
# 处理高亮内容
|
||||
if highlight and sres.highlight:
|
||||
if id in sres.highlight:
|
||||
d["highlight"] = rmSpace(sres.highlight[id])
|
||||
|
@ -495,12 +499,7 @@ class Dealer:
|
|||
if dnm not in ranks["doc_aggs"]:
|
||||
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
|
||||
ranks["doc_aggs"][dnm]["count"] += 1
|
||||
# 将文档聚合信息转换为列表格式,并按计数降序排序
|
||||
ranks["doc_aggs"] = [{"doc_name": k,
|
||||
"doc_id": v["doc_id"],
|
||||
"count": v["count"]} for k,
|
||||
v in sorted(ranks["doc_aggs"].items(),
|
||||
key=lambda x: x[1]["count"] * -1)]
|
||||
ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k, v in sorted(ranks["doc_aggs"].items(), key=lambda x: x[1]["count"] * -1)]
|
||||
ranks["chunks"] = ranks["chunks"][:page_size]
|
||||
|
||||
return ranks
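A hedged end-to-end sketch of calling retrieval() with the signature defined above; `dealer`, `tenant_id`, `kb_ids` and `embd_mdl` are assumed to be set up by the caller:

```python
ranks = dealer.retrieval(
    question="境外投资者如何办理外汇资金兑换?",
    embd_mdl=embd_mdl,
    tenant_ids=tenant_id,      # a plain str is split on "," internally
    kb_ids=kb_ids,
    page=1,
    page_size=30,
    similarity_threshold=0.2,
    vector_similarity_weight=0.3,
)
print(ranks["total"], len(ranks["chunks"]))
print(ranks["doc_aggs"][:3])   # [{"doc_name": ..., "doc_id": ..., "count": ...}, ...]
```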
|
||||
|
@ -509,16 +508,12 @@ class Dealer:
|
|||
tbl = self.dataStore.sql(sql, fetch_size, format)
|
||||
return tbl
|
||||
|
||||
def chunk_list(self, doc_id: str, tenant_id: str,
|
||||
kb_ids: list[str], max_count=1024,
|
||||
offset=0,
|
||||
fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
||||
def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, offset=0, fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
||||
condition = {"doc_id": doc_id}
|
||||
res = []
|
||||
bs = 128
|
||||
for p in range(offset, max_count, bs):
|
||||
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
|
||||
kb_ids)
|
||||
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id), kb_ids)
|
||||
dict_chunks = self.dataStore.getFields(es_res, fields)
|
||||
for id, doc in dict_chunks.items():
|
||||
doc["id"] = id
|
||||
|
@ -548,8 +543,7 @@ class Dealer:
|
|||
if not aggs:
|
||||
return False
|
||||
cnt = np.sum([c for _, c in aggs])
|
||||
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
|
||||
key=lambda x: x[1] * -1)[:topn_tags]
|
||||
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], key=lambda x: x[1] * -1)[:topn_tags]
|
||||
doc[TAG_FLD] = {a: c for a, c in tag_fea if c > 0}
|
||||
return True
|
||||
|
||||
|
@ -564,6 +558,5 @@ class Dealer:
|
|||
if not aggs:
|
||||
return {}
|
||||
cnt = np.sum([c for _, c in aggs])
|
||||
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
|
||||
key=lambda x: x[1] * -1)[:topn_tags]
|
||||
return {a: max(1, c) for a, c in tag_fea}
|
||||
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], key=lambda x: x[1] * -1)[:topn_tags]
|
||||
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
|
||||
|
|
|
@ -25,20 +25,18 @@ from api.utils.file_utils import get_project_base_directory
|
|||
|
||||
class Dealer:
|
||||
def __init__(self, redis=None):
|
||||
|
||||
self.lookup_num = 100000000
|
||||
self.load_tm = time.time() - 1000000
|
||||
self.dictionary = None
|
||||
path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
|
||||
try:
|
||||
self.dictionary = json.load(open(path, 'r'))
|
||||
self.dictionary = json.load(open(path, "r"))
|
||||
except Exception:
|
||||
logging.warning("Missing synonym.json")
|
||||
self.dictionary = {}
|
||||
|
||||
if not redis:
|
||||
logging.warning(
|
||||
"Realtime synonym is disabled, since no redis connection.")
|
||||
logging.warning("Realtime synonym is disabled, since no redis connection.")
|
||||
if not len(self.dictionary.keys()):
|
||||
logging.warning("Fail to load synonym")
|
||||
|
||||
|
@ -67,18 +65,36 @@ class Dealer:
|
|||
logging.error("Fail to load synonym!" + str(e))
|
||||
|
||||
def lookup(self, tk, topn=8):
|
||||
"""
|
||||
查找输入词条(tk)的同义词,支持英文和中文混合处理
|
||||
|
||||
参数:
|
||||
tk (str): 待查询的词条(如"happy"或"苹果")
|
||||
topn (int): 最多返回的同义词数量,默认为8
|
||||
|
||||
返回:
|
||||
list: 同义词列表,可能为空(无同义词时)
|
||||
|
||||
处理逻辑:
|
||||
1. 英文单词:使用WordNet语义网络查询
|
||||
2. 中文/其他:从预加载的自定义词典查询
|
||||
"""
|
||||
# 英文单词处理分支
|
||||
if re.match(r"[a-z]+$", tk):
|
||||
res = list(set([re.sub("_", " ", syn.name().split(".")[0]) for syn in wordnet.synsets(tk)]) - set([tk]))
|
||||
return [t for t in res if t]
|
||||
|
||||
# 中文/其他词条处理
|
||||
self.lookup_num += 1
|
||||
self.load()
|
||||
self.load() # 自定义词典
|
||||
# 从字典获取同义词,默认返回空列表
|
||||
res = self.dictionary.get(re.sub(r"[ \t]+", " ", tk.lower()), [])
|
||||
# 兼容处理:如果字典值是字符串,转为单元素列表
|
||||
if isinstance(res, str):
|
||||
res = [res]
|
||||
return res[:topn]
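A hedged illustration of the English branch above (requires NLTK with the WordNet corpus downloaded; the exact output depends on the corpus version):

```python
from nltk.corpus import wordnet

tk = "happy"
res = {syn.name().split(".")[0].replace("_", " ") for syn in wordnet.synsets(tk)} - {tk}
print(sorted(res))  # e.g. ['felicitous', 'glad', 'well-chosen']
```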
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
dl = Dealer()
|
||||
print(dl.dictionary)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -26,37 +26,41 @@ from api.utils.file_utils import get_project_base_directory
|
|||
|
||||
class Dealer:
|
||||
def __init__(self):
|
||||
self.stop_words = set(["请问",
|
||||
"您",
|
||||
"你",
|
||||
"我",
|
||||
"他",
|
||||
"是",
|
||||
"的",
|
||||
"就",
|
||||
"有",
|
||||
"于",
|
||||
"及",
|
||||
"即",
|
||||
"在",
|
||||
"为",
|
||||
"最",
|
||||
"有",
|
||||
"从",
|
||||
"以",
|
||||
"了",
|
||||
"将",
|
||||
"与",
|
||||
"吗",
|
||||
"吧",
|
||||
"中",
|
||||
"#",
|
||||
"什么",
|
||||
"怎么",
|
||||
"哪个",
|
||||
"哪些",
|
||||
"啥",
|
||||
"相关"])
|
||||
self.stop_words = set(
|
||||
[
|
||||
"请问",
|
||||
"您",
|
||||
"你",
|
||||
"我",
|
||||
"他",
|
||||
"是",
|
||||
"的",
|
||||
"就",
|
||||
"有",
|
||||
"于",
|
||||
"及",
|
||||
"即",
|
||||
"在",
|
||||
"为",
|
||||
"最",
|
||||
"有",
|
||||
"从",
|
||||
"以",
|
||||
"了",
|
||||
"将",
|
||||
"与",
|
||||
"吗",
|
||||
"吧",
|
||||
"中",
|
||||
"#",
|
||||
"什么",
|
||||
"怎么",
|
||||
"哪个",
|
||||
"哪些",
|
||||
"啥",
|
||||
"相关",
|
||||
]
|
||||
)
|
||||
|
||||
def load_dict(fnm):
|
||||
res = {}
|
||||
|
@ -90,50 +94,45 @@ class Dealer:
|
|||
logging.warning("Load term.freq FAIL!")
|
||||
|
||||
def pretoken(self, txt, num=False, stpwd=True):
|
||||
patt = [
|
||||
r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
|
||||
]
|
||||
rewt = [
|
||||
]
|
||||
patt = [r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"]
|
||||
rewt = []
|
||||
for p, r in rewt:
|
||||
txt = re.sub(p, r, txt)
|
||||
|
||||
res = []
|
||||
for t in rag_tokenizer.tokenize(txt).split():
|
||||
tk = t
|
||||
if (stpwd and tk in self.stop_words) or (
|
||||
re.match(r"[0-9]$", tk) and not num):
|
||||
if (stpwd and tk in self.stop_words) or (re.match(r"[0-9]$", tk) and not num):
|
||||
continue
|
||||
for p in patt:
|
||||
if re.match(p, t):
|
||||
tk = "#"
|
||||
break
|
||||
#tk = re.sub(r"([\+\\-])", r"\\\1", tk)
|
||||
# tk = re.sub(r"([\+\\-])", r"\\\1", tk)
|
||||
if tk != "#" and tk:
|
||||
res.append(tk)
|
||||
return res
|
||||
|
||||
def tokenMerge(self, tks):
|
||||
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
|
||||
def oneTerm(t):
|
||||
return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
|
||||
|
||||
res, i = [], 0
|
||||
while i < len(tks):
|
||||
j = i
|
||||
if i == 0 and oneTerm(tks[i]) and len(
|
||||
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
|
||||
if i == 0 and oneTerm(tks[i]) and len(tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
|
||||
res.append(" ".join(tks[0:2]))
|
||||
i = 2
|
||||
continue
|
||||
|
||||
while j < len(
|
||||
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
|
||||
while j < len(tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
|
||||
j += 1
|
||||
if j - i > 1:
|
||||
if j - i < 5:
|
||||
res.append(" ".join(tks[i:j]))
|
||||
i = j
|
||||
else:
|
||||
res.append(" ".join(tks[i:i + 2]))
|
||||
res.append(" ".join(tks[i : i + 2]))
|
||||
i = i + 2
|
||||
else:
|
||||
if len(tks[i]) > 0:
|
||||
|
@ -153,15 +152,13 @@ class Dealer:
|
|||
A special tokenization pass that mainly merges consecutive English words.

Args:
    txt: the text to tokenize

Returns:
    The processed token list
"""
|
||||
tks = []
|
||||
for t in re.sub(r"[ \t]+", " ", txt).split():
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
|
||||
re.match(r".*[a-zA-Z]$", t) and tks and \
|
||||
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t) and tks and self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
|
||||
tks[-1] = tks[-1] + " " + t
|
||||
else:
|
||||
tks.append(t)
|
||||
|
@ -180,8 +177,7 @@ class Dealer:
|
|||
return 0.01
|
||||
if not self.ne or t not in self.ne:
|
||||
return 1
|
||||
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
|
||||
"firstnm": 1}
|
||||
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, "firstnm": 1}
|
||||
return m[self.ne[t]]
|
||||
|
||||
def postag(t):
|
||||
|
@ -208,7 +204,7 @@ class Dealer:
|
|||
if not s and len(t) >= 4:
|
||||
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
s = np.min([freq(tt) for tt in s]) / 6.
|
||||
s = np.min([freq(tt) for tt in s]) / 6.0
|
||||
else:
|
||||
s = 0
|
||||
|
||||
|
@ -224,18 +220,18 @@ class Dealer:
|
|||
elif len(t) >= 4:
|
||||
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
return max(3, np.min([df(tt) for tt in s]) / 6.)
|
||||
return max(3, np.min([df(tt) for tt in s]) / 6.0)
|
||||
|
||||
return 3
|
||||
|
||||
def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
|
||||
def idf(s, N):
|
||||
return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
|
||||
|
||||
tw = []
|
||||
if not preprocess:
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tks])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
np.array([ner(t) * postag(t) for t in tks])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * np.array([ner(t) * postag(t) for t in tks])
|
||||
wts = [s for s in wts]
|
||||
tw = list(zip(tks, wts))
|
||||
else:
|
||||
|
@ -243,8 +239,7 @@ class Dealer:
|
|||
tt = self.tokenMerge(self.pretoken(tk, True))
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
np.array([ner(t) * postag(t) for t in tt])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * np.array([ner(t) * postag(t) for t in tt])
|
||||
wts = [s for s in wts]
|
||||
tw.extend(zip(tt, wts))
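To make the weighting above concrete, a worked example with the same smoothed idf and the 0.3/0.7 blend (the frequency counts are illustrative):

```python
import math

def idf(s, N):
    return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

idf1 = idf(500, 10_000_000)        # term-frequency table, N = 1e7      -> ≈ 4.30
idf2 = idf(20_000, 1_000_000_000)  # document-frequency table, N = 1e9  -> ≈ 4.70
w = 0.3 * idf1 + 0.7 * idf2        # ≈ 4.58, then multiplied by ner(t) * postag(t)
print(round(w, 2))
```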
|
||||
|
||||
|
|
|
@ -42,6 +42,10 @@ def chunks_format(reference):
|
|||
"image_id": get_value(chunk, "image_id", "img_id"),
|
||||
"positions": get_value(chunk, "positions", "position_int"),
|
||||
"url": chunk.get("url"),
|
||||
"similarity": chunk.get("similarity"),
|
||||
"vector_similarity": chunk.get("vector_similarity"),
|
||||
"term_similarity": chunk.get("term_similarity"),
|
||||
"doc_type": chunk.get("doc_type_kwd"),
|
||||
}
|
||||
for chunk in reference.get("chunks", [])
|
||||
]
|
||||
|
@ -145,15 +149,17 @@ def kb_prompt(kbinfos, max_tokens):
|
|||
|
||||
doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
|
||||
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
|
||||
doc2chunks[ck["docnm_kwd"]]["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + f"ID: {i}\n" + ck["content_with_weight"])
|
||||
cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
|
||||
cnt += ck["content_with_weight"]
|
||||
doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
|
||||
doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
|
||||
|
||||
knowledges = []
|
||||
for nm, cks_meta in doc2chunks.items():
|
||||
txt = f"\nDocument: {nm} \n"
|
||||
txt = f"\n文档: {nm} \n"
|
||||
for k, v in cks_meta["meta"].items():
|
||||
txt += f"{k}: {v}\n"
|
||||
txt += "Relevant fragments as following:\n"
|
||||
txt += "相关片段如下:\n"
|
||||
for i, chunk in enumerate(cks_meta["chunks"], 1):
|
||||
txt += f"{chunk}\n"
|
||||
knowledges.append(txt)
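For reference, one knowledge block assembled by the loop above ends up shaped roughly like this (document name, metadata and URL are illustrative):

```python
example_block = (
    "\n文档: 2024年度报告.pdf \n"
    "author: 示例作者\n"                 # per-document metadata, if any
    "相关片段如下:\n"
    "---\nID: 0\nURL: https://example.com/report\n<chunk content ...>\n"
    "---\nID: 1\n<chunk content ...>\n"
)
```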
|
||||
|
@ -388,3 +394,57 @@ Output:
|
|||
except Exception as e:
|
||||
logging.exception(f"JSON parsing error: {result} -> {e}")
|
||||
raise e
|
||||
|
||||
|
||||
def vision_llm_describe_prompt(page=None) -> str:
|
||||
prompt_en = """
|
||||
INSTRUCTION:
|
||||
Transcribe the content from the provided PDF page image into clean Markdown format.
|
||||
- Only output the content transcribed from the image.
|
||||
- Do NOT output this instruction or any other explanation.
|
||||
- If the content is missing or you do not understand the input, return an empty string.
|
||||
|
||||
RULES:
|
||||
1. Do NOT generate examples, demonstrations, or templates.
|
||||
2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
|
||||
3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
|
||||
4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
|
||||
5. Do NOT explain Markdown or mention that you are using Markdown.
|
||||
6. Do NOT wrap the output in ```markdown or ``` blocks.
|
||||
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
|
||||
8. Preserve the original language, information, and order exactly as shown in the image.
|
||||
"""
|
||||
|
||||
if page is not None:
|
||||
prompt_en += f"\nAt the end of the transcription, add the page divider: `--- Page {page} ---`."
|
||||
|
||||
prompt_en += """
|
||||
FAILURE HANDLING:
|
||||
- If you do not detect valid content in the image, return an empty string.
|
||||
"""
|
||||
return prompt_en
|
||||
|
||||
|
||||
def vision_llm_figure_describe_prompt() -> str:
|
||||
prompt = """
|
||||
You are an expert visual data analyst. Analyze the image and provide a comprehensive description of its content. Focus on identifying the type of visual data representation (e.g., bar chart, pie chart, line graph, table, flowchart), its structure, and any text captions or labels included in the image.
|
||||
|
||||
Tasks:
|
||||
1. Describe the overall structure of the visual representation. Specify if it is a chart, graph, table, or diagram.
|
||||
2. Identify and extract any axes, legends, titles, or labels present in the image. Provide the exact text where available.
|
||||
3. Extract the data points from the visual elements (e.g., bar heights, line graph coordinates, pie chart segments, table rows and columns).
|
||||
4. Analyze and explain any trends, comparisons, or patterns shown in the data.
|
||||
5. Capture any annotations, captions, or footnotes, and explain their relevance to the image.
|
||||
6. Only include details that are explicitly present in the image. If an element (e.g., axis, legend, or caption) does not exist or is not visible, do not mention it.
|
||||
|
||||
Output format (include only sections relevant to the image content):
|
||||
- Visual Type: [Type]
|
||||
- Title: [Title text, if available]
|
||||
- Axes / Legends / Labels: [Details, if available]
|
||||
- Data Points: [Extracted data]
|
||||
- Trends / Insights: [Analysis and interpretation]
|
||||
- Captions / Annotations: [Text and relevance, if available]
|
||||
|
||||
Ensure high accuracy, clarity, and completeness in your analysis, and includes only the information present in the image. Avoid unnecessary statements about missing elements.
|
||||
"""
|
||||
return prompt
|
||||
|
|
10542
rag/res/synonym.json
10542
rag/res/synonym.json
File diff suppressed because it is too large