Fix: add tokenization during parsing and fix keyword matching failing at recall (issue #133)

This commit is contained in:
zstar 2025-06-01 21:00:23 +08:00 committed by GitHub
commit 401e3d81c4
16 changed files with 1387 additions and 11417 deletions

.gitignore vendored (1 line added)
View File

@ -54,3 +54,4 @@ docker/models
management/web/types/auto
web/node_modules/.cache/logger/umi.log
management/models--slanet_plus
node_modules/.cache/logger/umi.log

View File

@ -76,23 +76,32 @@ ollama pull bge-m3:latest
#### 1. 使用Docker Compose运行
- 使用GPU运行(需保证首张显卡有6GB以上剩余显存)
1. 在宿主机安装nvidia-container-runtime,让 Docker 自动挂载 GPU 设备和驱动:
```bash
sudo apt install -y nvidia-container-runtime
```
2. 在项目根目录下执行
```bash
docker compose -f docker/docker-compose_gpu.yml up -d
```
- 使用CPU运行
在项目根目录下执行
```bash
docker compose -f docker/docker-compose.yml up -d
```
访问地址:`服务器ip:80`,进入到前台界面
访问地址:`服务器ip:8888`,进入到后台管理界面
图文教程:[https://blog.csdn.net/qq1198768105/article/details/147475488](https://blog.csdn.net/qq1198768105/article/details/147475488)
#### 2. 源码运行(mysql、minio、es等组件仍需docker启动) #### 2. 源码运行(mysql、minio、es等组件仍需docker启动)
@ -100,29 +109,29 @@ docker compose -f docker/docker-compose.yml up -d
- 启动后端:进入到`management/server`,执行:
```bash
python app.py
```
- 启动前端:进入到`management\web`,执行:
```bash
pnpm dev
```
2. 启动前台交互系统:
- 启动后端:项目根目录下执行:
```bash
python -m api.ragflow_server
```
- 启动前端:进入到`web`,执行:
```bash
pnpm dev
```
> [!NOTE]
> 源码部署需要注意:如果用到MinerU后台解析,需要参考MinerU的文档下载模型文件,并安装LibreOffice、配置环境变量,以适配支持除pdf之外的类型文件。

View File

@ -22,7 +22,8 @@ from flask_login import login_required, current_user
from rag.app.qa import rmPrefix, beAdoc from rag.app.qa import rmPrefix, beAdoc
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp import search, rag_tokenizer from rag.nlp import search, rag_tokenizer
from rag.prompts import keyword_extraction
# from rag.prompts import keyword_extraction, cross_languages
from rag.settings import PAGERANK_FLD from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace from rag.utils import rmSpace
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
@ -37,9 +38,9 @@ import xxhash
import re import re
@manager.route('/list', methods=['POST']) # noqa: F821 @manager.route("/list", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("doc_id") @validate_request("doc_id") # 验证请求中必须包含 doc_id 参数
def list_chunk(): def list_chunk():
req = request.json req = request.json
doc_id = req["doc_id"] doc_id = req["doc_id"]
@ -54,9 +55,7 @@ def list_chunk():
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
query = {"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True}
if "available_int" in req: if "available_int" in req:
query["available_int"] = int(req["available_int"]) query["available_int"] = int(req["available_int"])
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True) sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
@ -64,9 +63,7 @@ def list_chunk():
for id in sres.ids: for id in sres.ids:
d = { d = {
"chunk_id": id, "chunk_id": id,
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[ "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get("content_with_weight", ""),
id].get(
"content_with_weight", ""),
"doc_id": sres.field[id]["doc_id"], "doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"], "docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []), "important_kwd": sres.field[id].get("important_kwd", []),
@ -81,12 +78,11 @@ def list_chunk():
return get_json_result(data=res) return get_json_result(data=res)
except Exception as e: except Exception as e:
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_json_result(data=False, message="No chunk found!", code=settings.RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)
@manager.route('/get', methods=['GET']) # noqa: F821 @manager.route("/get", methods=["GET"]) # noqa: F821
@login_required @login_required
def get(): def get():
chunk_id = request.args["chunk_id"] chunk_id = request.args["chunk_id"]
@ -112,19 +108,16 @@ def get():
return get_json_result(data=chunk) return get_json_result(data=chunk)
except Exception as e: except Exception as e:
if str(e).find("NotFoundError") >= 0: if str(e).find("NotFoundError") >= 0:
return get_json_result(data=False, message="Chunk not found!", code=settings.RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)
@manager.route('/set', methods=['POST']) # noqa: F821 @manager.route("/set", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("doc_id", "chunk_id", "content_with_weight") @validate_request("doc_id", "chunk_id", "content_with_weight")
def set(): def set():
req = request.json req = request.json
d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"]) d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if "important_kwd" in req: if "important_kwd" in req:
@ -153,13 +146,9 @@ def set():
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if doc.parser_id == ParserType.QA: if doc.parser_id == ParserType.QA:
arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t) > 1]
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
d = beAdoc(d, q, a, not any([rag_tokenizer.is_chinese(t) for t in q + a]))
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])]) v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
@ -170,7 +159,7 @@ def set():
return server_error_response(e) return server_error_response(e)
@manager.route('/switch', methods=['POST']) # noqa: F821 @manager.route("/switch", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("chunk_ids", "available_int", "doc_id") @validate_request("chunk_ids", "available_int", "doc_id")
def switch(): def switch():
@ -180,20 +169,19 @@ def switch():
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
for cid in req["chunk_ids"]: for cid in req["chunk_ids"]:
if not settings.docStoreConn.update({"id": cid}, {"available_int": int(req["available_int"])}, search.index_name(DocumentService.get_tenant_id(req["doc_id"])), doc.kb_id):
return get_data_error_result(message="Index updating failure")
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/rm', methods=['POST']) # noqa: F821 @manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("chunk_ids", "doc_id") @validate_request("chunk_ids", "doc_id")
def rm(): def rm():
from rag.utils.storage_factory import STORAGE_IMPL
req = request.json req = request.json
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
@ -204,19 +192,21 @@ def rm():
deleted_chunk_ids = req["chunk_ids"] deleted_chunk_ids = req["chunk_ids"]
chunk_number = len(deleted_chunk_ids) chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0) DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
for cid in deleted_chunk_ids:
if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
STORAGE_IMPL.rm(doc.kb_id, cid)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/create', methods=['POST']) # noqa: F821 @manager.route("/create", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("doc_id", "content_with_weight") @validate_request("doc_id", "content_with_weight")
def create(): def create():
req = request.json req = request.json
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest() chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), "content_with_weight": req["content_with_weight"]}
"content_with_weight": req["content_with_weight"]}
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = req.get("important_kwd", []) d["important_kwd"] = req.get("important_kwd", [])
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", []))) d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", [])))
@ -252,14 +242,35 @@ def create():
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
return get_json_result(data={"chunk_id": chunck_id}) return get_json_result(data={"chunk_id": chunck_id})
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
"""
{
"similarity_threshold": 0.2,
"vector_similarity_weight": 0.30000000000000004,
"question": "香港",
"doc_ids": [],
"kb_id": "4b071030bc8e43f1bfb8b7831f320d2f",
"page": 1,
"size": 10
},
{
"similarity_threshold": 0.2,
"vector_similarity_weight": 0.30000000000000004,
"question": "显著优势",
"doc_ids": [],
"kb_id": "1848bc54384611f0b33e4e66786d0323",
"page": 1,
"size": 10
}
"""
@manager.route("/retrieval_test", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("kb_id", "question") @validate_request("kb_id", "question")
def retrieval_test(): def retrieval_test():
@ -268,57 +279,53 @@ def retrieval_test():
size = int(req.get("size", 30)) size = int(req.get("size", 30))
question = req["question"] question = req["question"]
kb_ids = req["kb_id"] kb_ids = req["kb_id"]
# 如果kb_ids是字符串将其转换为列表
if isinstance(kb_ids, str): if isinstance(kb_ids, str):
kb_ids = [kb_ids] kb_ids = [kb_ids]
doc_ids = req.get("doc_ids", []) doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.0)) similarity_threshold = float(req.get("similarity_threshold", 0.0))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024)) top = int(req.get("top_k", 1024))
# langs = req.get("cross_languages", []) # 获取跨语言设定
tenant_ids = [] tenant_ids = []
try: try:
# 查询当前用户所属的租户
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
# 验证知识库权限
for kb_id in kb_ids: for kb_id in kb_ids:
for tenant in tenants: for tenant in tenants:
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
tenant_ids.append(tenant.tenant_id)
break
else:
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
# 获取知识库信息
e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e: if not e:
return get_data_error_result(message="Knowledgebase not found!") return get_data_error_result(message="Knowledgebase not found!")
# if langs:
# question = cross_languages(kb.tenant_id, None, question, langs) # 跨语言处理
# 加载嵌入模型
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
# 加载重排序模型(如果指定)
rerank_mdl = None rerank_mdl = None
if req.get("rerank_id"): if req.get("rerank_id"):
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
if req.get("keyword", False): # 对问题进行标签化
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) # labels = label_question(question, [kb])
question += keyword_extraction(chat_mdl, question) labels = None
labels = label_question(question, [kb]) # 执行检索操作
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size, ranks = settings.retrievaler.retrieval(
similarity_threshold, vector_similarity_weight, top, question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), )
rank_feature=labels
)
if use_kg:
ck = settings.kg_retrievaler.retrieval(question,
tenant_ids,
kb_ids,
embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
# 移除不必要的向量信息
for c in ranks["chunks"]: for c in ranks["chunks"]:
c.pop("vector", None) c.pop("vector", None)
ranks["labels"] = labels ranks["labels"] = labels
@ -326,47 +333,5 @@ def retrieval_test():
return get_json_result(data=ranks)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=settings.RetCode.DATA_ERROR)
return server_error_response(e)
@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
@login_required
def knowledge_graph():
doc_id = request.args["doc_id"]
tenant_id = DocumentService.get_tenant_id(doc_id)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
req = {
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]
try:
content_json = json.loads(sres.field[id]["content_with_weight"])
except Exception:
continue
if ty == 'mind_map':
node_dict = {}
def repeat_deal(content_json, node_dict):
if 'id' in content_json:
if content_json['id'] in node_dict:
node_name = content_json['id']
content_json['id'] += f"({node_dict[content_json['id']]})"
node_dict[node_name] += 1
else:
node_dict[content_json['id']] = 1
if 'children' in content_json and content_json['children']:
for item in content_json['children']:
repeat_deal(item, node_dict)
repeat_deal(content_json, node_dict)
obj[ty] = content_json
return get_json_result(data=obj)
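For reference, a minimal sketch of exercising the reworked `retrieval_test` endpoint with one of the sample payloads quoted in the docstring above. The host, port, `/v1/chunk` route prefix and the session cookie are assumptions about a typical deployment, not something this diff specifies; the endpoint itself is `@login_required`.

```python
import requests

BASE_URL = "http://127.0.0.1:80"  # assumed API address; adjust to your deployment
payload = {
    "kb_id": "4b071030bc8e43f1bfb8b7831f320d2f",  # sample kb_id taken from the docstring above
    "question": "香港",
    "similarity_threshold": 0.2,
    "vector_similarity_weight": 0.3,
    "doc_ids": [],
    "page": 1,
    "size": 10,
}

# The /v1/chunk prefix is a guess at how this blueprint is mounted.
resp = requests.post(
    f"{BASE_URL}/v1/chunk/retrieval_test",
    json=payload,
    cookies={"session": "<your login session cookie>"},  # an authenticated session is required
    timeout=30,
)
print(resp.status_code, resp.json())
```

With keyword extraction removed and `labels` forced to `None`, recall quality now hinges on how well `content_ltks` was tokenized at parse time, which is what the new `rag_tokenizer.py` further down addresses.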

View File

@ -4,9 +4,11 @@ import redis
from minio import Minio
from dotenv import load_dotenv
from elasticsearch import Elasticsearch
from pathlib import Path
# 加载环境变量
env_path = Path(__file__).parent.parent.parent / "docker" / ".env"
load_dotenv(env_path)
# 检测是否在Docker容器中运行
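A small sketch of what the new `Path`-based lookup resolves to, replacing the old working-directory-dependent `load_dotenv("../../docker/.env")`. The `management/server/database.py` location used here is an assumption for illustration.

```python
from pathlib import Path

# Hypothetical absolute location of database.py inside the repository.
here = Path("/opt/ragflow-plus/management/server/database.py")

# Three .parent hops climb server -> management -> repo root,
# so the .env file is found no matter where the process was started from.
env_path = here.parent.parent.parent / "docker" / ".env"
print(env_path)  # /opt/ragflow-plus/docker/.env
```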

View File

@ -25,3 +25,4 @@ omegaconf==2.3.0
rapid-table==1.0.3
openai==1.70.0
redis==6.2.0
tokenizer==3.4.5

View File

@ -2,12 +2,12 @@ import os
import tempfile
import shutil
import json
import mysql.connector
import time
import traceback
import re
import requests
from bs4 import BeautifulSoup
from io import BytesIO
from datetime import datetime
from database import MINIO_CONFIG, DB_CONFIG, get_minio_client, get_es_client
@ -17,17 +17,18 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.read_api import read_local_office, read_local_images from magic_pdf.data.read_api import read_local_office, read_local_images
from utils import generate_uuid from utils import generate_uuid
from .rag_tokenizer import RagTokenizer
tknzr = RagTokenizer()
# 自定义tokenizer和文本处理函数替代rag.nlp中的功能
def tokenize_text(text):
-     """将文本分词替代rag_tokenizer功能"""
-     # 简单实现,未来可能需要改成更复杂的分词逻辑
-     return text.split()
+     return tknzr.tokenize(text)
def merge_chunks(sections, chunk_token_num=128, delimiter="\n。;!?"): def merge_chunks(sections, chunk_token_num=128, delimiter="\n。;!?"):
"""合并文本块替代naive_merge功能""" """合并文本块替代naive_merge功能(预留函数)"""
if not sections: if not sections:
return [] return []
@ -149,25 +150,7 @@ def _create_task_record(doc_id, chunk_ids_list):
progress, progress_msg, retry_count, digest, chunk_ids, task_type, priority progress, progress_msg, retry_count, digest, chunk_ids, task_type, priority
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""" """
task_params = [task_id, current_timestamp, current_date_only, current_timestamp, current_date_only, doc_id, 0, 1, None, 0.0, 1.0, "MinerU解析完成", 1, digest, chunk_ids_str, "", 0]
cursor.execute(task_insert, task_params) cursor.execute(task_insert, task_params)
conn.commit() conn.commit()
print(f"[Parser-INFO] Task记录创建成功Task ID: {task_id}") print(f"[Parser-INFO] Task记录创建成功Task ID: {task_id}")
@ -214,13 +197,13 @@ def process_table_content(content_list):
new_content_list = [] new_content_list = []
for item in content_list: for item in content_list:
if 'table_body' in item and item['table_body']: if "table_body" in item and item["table_body"]:
# 使用BeautifulSoup解析HTML表格 # 使用BeautifulSoup解析HTML表格
soup = BeautifulSoup(item['table_body'], 'html.parser') soup = BeautifulSoup(item["table_body"], "html.parser")
table = soup.find('table') table = soup.find("table")
if table: if table:
rows = table.find_all('tr') rows = table.find_all("tr")
# 获取表头(第一行) # 获取表头(第一行)
header_row = rows[0] if rows else None header_row = rows[0] if rows else None
@ -230,7 +213,7 @@ def process_table_content(content_list):
new_item = item.copy() new_item = item.copy()
# 创建只包含当前行的表格 # 创建只包含当前行的表格
new_table = soup.new_tag('table') new_table = soup.new_tag("table")
# 如果有表头,添加表头 # 如果有表头,添加表头
if header_row and i > 0: if header_row and i > 0:
@ -241,7 +224,7 @@ def process_table_content(content_list):
# 创建新的HTML结构 # 创建新的HTML结构
new_html = f"<html><body>{str(new_table)}</body></html>" new_html = f"<html><body>{str(new_table)}</body></html>"
new_item['table_body'] = f"\n\n{new_html}\n\n" new_item["table_body"] = f"\n\n{new_html}\n\n"
# 添加到新的内容列表 # 添加到新的内容列表
new_content_list.append(new_item) new_content_list.append(new_item)
@ -252,6 +235,7 @@ def process_table_content(content_list):
return new_content_list return new_content_list
def perform_parse(doc_id, doc_info, file_info, embedding_config): def perform_parse(doc_id, doc_info, file_info, embedding_config):
""" """
执行文档解析的核心逻辑 执行文档解析的核心逻辑
@ -305,7 +289,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
if normalized_base_url.endswith("/v1"): if normalized_base_url.endswith("/v1"):
# 如果 base_url 已经是 http://host/v1 形式 # 如果 base_url 已经是 http://host/v1 形式
embedding_url = normalized_base_url + "/" + endpoint_segment embedding_url = normalized_base_url + "/" + endpoint_segment
elif normalized_base_url.endswith('/embeddings'): elif normalized_base_url.endswith("/embeddings"):
# 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API无需再进行处理) # 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API无需再进行处理)
embedding_url = normalized_base_url embedding_url = normalized_base_url
else: else:
@ -403,7 +387,7 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
middle_content = pipe_result.get_middle_json() middle_content = pipe_result.get_middle_json()
middle_json_content = json.loads(middle_content) middle_json_content = json.loads(middle_content)
# 对excel文件单独进行处理 # 对excel文件单独进行处理
elif file_type.endswith("excel") : elif file_type.endswith("excel"):
update_progress(0.3, "使用MinerU解析器") update_progress(0.3, "使用MinerU解析器")
# 创建临时文件保存文件内容 # 创建临时文件保存文件内容
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
@ -613,7 +597,6 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
) )
# 准备ES文档 # 准备ES文档
content_tokens = tokenize_text(content) # 分词
current_time_es = datetime.now().strftime("%Y-%m-%d %H:%M:%S") current_time_es = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
current_timestamp_es = datetime.now().timestamp() current_timestamp_es = datetime.now().timestamp()
@ -625,11 +608,11 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
"doc_id": doc_id, "doc_id": doc_id,
"kb_id": kb_id, "kb_id": kb_id,
"docnm_kwd": doc_info["name"], "docnm_kwd": doc_info["name"],
"title_tks": doc_info["name"], "title_tks": tokenize_text(doc_info["name"]),
"title_sm_tks": doc_info["name"], "title_sm_tks": tokenize_text(doc_info["name"]),
"content_with_weight": content, "content_with_weight": content,
"content_ltks": " ".join(content_tokens), # 字符串类型 "content_ltks": tokenize_text(content),
"content_sm_ltks": " ".join(content_tokens), # 字符串类型 "content_sm_ltks": tokenize_text(content),
"page_num_int": [page_idx + 1], "page_num_int": [page_idx + 1],
"position_int": [[page_idx + 1] + bbox_reordered], # 格式: [[page, x1, x2, y1, y2]] "position_int": [[page_idx + 1] + bbox_reordered], # 格式: [[page, x1, x2, y1, y2]]
"top_int": [1], "top_int": [1],
@ -755,7 +738,6 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
traceback.print_exc() # 打印详细错误堆栈 traceback.print_exc() # 打印详细错误堆栈
# 更新文档状态为失败 # 更新文档状态为失败
_update_document_progress(doc_id, status="1", run="0", message=error_message, process_duration=process_duration) # status=1表示完成run=0表示失败 _update_document_progress(doc_id, status="1", run="0", message=error_message, process_duration=process_duration) # status=1表示完成run=0表示失败
# 不抛出异常,让调用者知道任务已结束(但失败)
return {"success": False, "error": error_message} return {"success": False, "error": error_message}
finally: finally:
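To make the effect of this file's changes concrete: `tokenize_text` previously returned `text.split()`, whose whitespace tokens were joined back into the ES fields and yield nothing useful for Chinese text, so keyword matching at recall time found no hits. A minimal sketch of the new behaviour, assuming the `rag_tokenizer.py` added below is importable on the path; the sample strings are illustrative only.

```python
# In the service this is a relative import: from .rag_tokenizer import RagTokenizer
from rag_tokenizer import RagTokenizer

tknzr = RagTokenizer()

title = "光流估计研究.pdf"
content = "本文研究基于动态视觉相机的光流估计方法。"

# Space-separated token strings, as expected by the full-text ES fields above.
es_fields = {
    "title_tks": tknzr.tokenize(title),
    "title_sm_tks": tknzr.tokenize(title),
    "content_ltks": tknzr.tokenize(content),
    "content_sm_ltks": tknzr.tokenize(content),  # same coarse tokens; no fine-grained pass here
}
for field, value in es_fields.items():
    print(field, "->", value)
```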

View File

@ -0,0 +1,412 @@
import logging
import copy
import datrie
import math
import os
import re
import string
from hanziconv import HanziConv
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
# def get_project_base_directory():
# return os.path.abspath(
# os.path.join(
# os.path.dirname(os.path.realpath(__file__)),
# os.pardir,
# os.pardir,
# )
# )
class RagTokenizer:
def key_(self, line):
return str(line.lower().encode("utf-8"))[2:-1]
def rkey_(self, line):
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
def loadDict_(self, fnm):
print(f"[HUQIE]:Build trie from {fnm}")
try:
of = open(fnm, "r", encoding="utf-8")
while True:
line = of.readline()
if not line:
break
line = re.sub(r"[\r\n]+", "", line)
line = re.split(r"[ \t]", line)
k = self.key_(line[0])
F = int(math.log(float(line[1]) / self.DENOMINATOR) + 0.5)
if k not in self.trie_ or self.trie_[k][0] < F:
self.trie_[self.key_(line[0])] = (F, line[2])
self.trie_[self.rkey_(line[0])] = 1
dict_file_cache = fnm + ".trie"
print(f"[HUQIE]:Build trie cache to {dict_file_cache}")
self.trie_.save(dict_file_cache)
of.close()
except Exception:
logging.exception(f"[HUQIE]:Build trie {fnm} failed")
def __init__(self):
self.DENOMINATOR = 1000000
self.DIR_ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "res", "huqie")
self.stemmer = PorterStemmer()
self.lemmatizer = WordNetLemmatizer()
self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"
trie_file_name = self.DIR_ + ".txt.trie"
# check if trie file existence
if os.path.exists(trie_file_name):
try:
# load trie from file
self.trie_ = datrie.Trie.load(trie_file_name)
return
except Exception:
# fail to load trie from file, build default trie
logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
self.trie_ = datrie.Trie(string.printable)
else:
# file not exist, build default trie
print(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
self.trie_ = datrie.Trie(string.printable)
# load data from dict file and save to trie file
self.loadDict_(self.DIR_ + ".txt")
def _strQ2B(self, ustring):
"""全角转半角,转小写"""
rstring = ""
for uchar in ustring:
inside_code = ord(uchar)
if inside_code == 0x3000:
inside_code = 0x0020
else:
inside_code -= 0xFEE0
if inside_code < 0x0020 or inside_code > 0x7E: # After the conversion, if it's not a half-width character, return the original character.
rstring += uchar
else:
rstring += chr(inside_code)
return rstring
def _tradi2simp(self, line):
"""繁体转简体"""
return HanziConv.toSimplified(line)
def dfs_(self, chars, s, preTks, tkslist):
res = s
# if s > MAX_L or s>= len(chars):
if s >= len(chars):
tkslist.append(preTks)
return res
# pruning
S = s + 1
if s + 2 <= len(chars):
t1, t2 = "".join(chars[s : s + 1]), "".join(chars[s : s + 2])
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
S = s + 2
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
t1 = preTks[-1][0] + "".join(chars[s : s + 1])
if self.trie_.has_keys_with_prefix(self.key_(t1)):
S = s + 2
for e in range(S, len(chars) + 1):
t = "".join(chars[s:e])
k = self.key_(t)
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
break
if k in self.trie_:
pretks = copy.deepcopy(preTks)
if k in self.trie_:
pretks.append((t, self.trie_[k]))
else:
pretks.append((t, (-12, "")))
res = max(res, self.dfs_(chars, e, pretks, tkslist))
if res > s:
return res
t = "".join(chars[s : s + 1])
k = self.key_(t)
if k in self.trie_:
preTks.append((t, self.trie_[k]))
else:
preTks.append((t, (-12, "")))
return self.dfs_(chars, s + 1, preTks, tkslist)
def freq(self, tk):
k = self.key_(tk)
if k not in self.trie_:
return 0
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
def score_(self, tfts):
B = 30
F, L, tks = 0, 0, []
for tk, (freq, tag) in tfts:
F += freq
L += 0 if len(tk) < 2 else 1
tks.append(tk)
# F /= len(tks)
L /= len(tks)
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
return tks, B / len(tks) + L + F
def sortTks_(self, tkslist):
res = []
for tfts in tkslist:
tks, s = self.score_(tfts)
res.append((tks, s))
return sorted(res, key=lambda x: x[1], reverse=True)
def merge_(self, tks):
# if split chars is part of token
res = []
tks = re.sub(r"[ ]+", " ", tks).split()
s = 0
while True:
if s >= len(tks):
break
E = s + 1
for e in range(s + 2, min(len(tks) + 2, s + 6)):
tk = "".join(tks[s:e])
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
E = e
res.append("".join(tks[s:E]))
s = E
return " ".join(res)
def maxForward_(self, line):
res = []
s = 0
while s < len(line):
e = s + 1
t = line[s:e]
while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
e += 1
t = line[s:e]
while e - 1 > s and self.key_(t) not in self.trie_:
e -= 1
t = line[s:e]
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, "")))
s = e
return self.score_(res)
def maxBackward_(self, line):
res = []
s = len(line) - 1
while s >= 0:
e = s + 1
t = line[s:e]
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
s -= 1
t = line[s:e]
while s + 1 < e and self.key_(t) not in self.trie_:
s += 1
t = line[s:e]
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, "")))
s -= 1
return self.score_(res[::-1])
def _split_by_lang(self, line):
"""根据语言进行切分"""
txt_lang_pairs = []
arr = re.split(self.SPLIT_CHAR, line)
for a in arr:
if not a:
continue
s = 0
e = s + 1
zh = is_chinese(a[s])
while e < len(a):
_zh = is_chinese(a[e])
if _zh == zh:
e += 1
continue
txt_lang_pairs.append((a[s:e], zh))
s = e
e = s + 1
zh = _zh
if s >= len(a):
continue
txt_lang_pairs.append((a[s:e], zh))
return txt_lang_pairs
def tokenize(self, line: str) -> str:
"""
对输入文本进行分词支持中英文混合处理
分词流程
1. 预处理
- 将所有非单词字符字母数字下划线以外的替换为空格
- 全角字符转半角
- 转换为小写
- 繁体中文转简体中文
2. 按语言切分
- 将预处理后的文本按语言中文/非中文分割成多个片段
3. 分段处理
- 对于非中文通常是英文片段
- 使用 NLTK `word_tokenize` 进行分词
- 对分词结果进行词干提取 (PorterStemmer) 和词形还原 (WordNetLemmatizer)
- 对于中文片段
- 如果片段过短长度<2或为纯粹的英文/数字模式 "abc-def", "123.45"则直接保留该片段
- 否则采用基于词典的混合分词策略
a. 执行正向最大匹配 (FMM) 和逆向最大匹配 (BMM) 得到两组分词结果 (`tks` `tks1`)
b. 比较 FMM BMM 的结果
i. 找到两者从开头开始最长的相同分词序列这部分通常是无歧义的直接加入结果
ii. 对于 FMM BMM 结果不一致的歧义部分即从第一个不同点开始的子串
- 提取出这段有歧义的原始文本
- 调用 `self.dfs_` (深度优先搜索) 在这段文本上探索所有可能的分词组合
- `self.dfs_` 会利用Trie词典并由 `self.sortTks_` 对所有组合进行评分和排序
- 选择得分最高的分词方案作为该歧义段落的结果
iii.继续处理 FMM BMM 结果中歧义段落之后的部分重复步骤 i ii直到两个序列都处理完毕
c. 如果在比较完所有对应部分后FMM BMM 仍有剩余理论上如果实现正确且输入相同剩余部分也应相同
则对这部分剩余的原始文本同样使用 `self.dfs_` 进行最优分词
4. 后处理
- 将所有处理过的片段英文词元中文词元用空格连接起来
- 调用 `self.merge_` 对连接后的结果进行进一步的合并操作
尝试合并一些可能被错误分割但实际是一个完整词的片段基于词典检查
5. 返回最终分词结果字符串词元间用空格分隔
Args:
line (str): 待分词的原始输入字符串
Returns:
str: 分词后的字符串词元之间用空格分隔
"""
# 1. 预处理
line = re.sub(r"\W+", " ", line) # 将非字母数字下划线替换为空格
line = self._strQ2B(line).lower() # 全角转半角,转小写
line = self._tradi2simp(line) # 繁体转简体
# 2. 按语言切分
arr = self._split_by_lang(line) # 将文本分割成 (文本片段, 是否为中文) 的列表
res = [] # 存储最终分词结果的列表
# 3. 分段处理
for L, lang in arr: # L 是文本片段lang 是布尔值表示是否为中文
if not lang: # 如果不是中文
# 使用NLTK进行分词、词干提取和词形还原
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
continue # 处理下一个片段
# 如果是中文但长度小于2或匹配纯英文/数字模式,则直接添加,不进一步切分
if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
res.append(L)
continue # 处理下一个片段
# 对较长的中文片段执行FMM和BMM
tks, s = self.maxForward_(L) # tks: FMM结果列表, s: FMM评分 FMM (Forward Maximum Matching - 正向最大匹配)
tks1, s1 = self.maxBackward_(L) # tks1: BMM结果列表, s1: BMM评分 BMM (Backward Maximum Matching - 逆向最大匹配)
# 初始化用于比较FMM和BMM结果的指针
i, j = 0, 0 # i 指向 tks1 (BMM), j 指向 tks (FMM)
_i, _j = 0, 0 # _i, _j 记录上一段歧义处理的结束位置
# 3.b.i. 查找 FMM 和 BMM 从头开始的最长相同前缀
same = 0 # 相同词元的数量
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
if same > 0: # 如果存在相同前缀
res.append(" ".join(tks[j : j + same])) # 将FMM中的相同部分加入结果
# 更新指针到不同部分的开始
_i = i + same
_j = j + same
# 准备开始处理可能存在的歧义部分
j = _j + 1 # FMM指针向后移动一位或多位取决于下面tk的构造
i = _i + 1 # BMM指针向后移动一位或多位
# 3.b.ii. 迭代处理 FMM 和 BMM 结果中的歧义部分
while i < len(tks1) and j < len(tks):
# tk1 是 BMM 从上一个同步点 _i 到当前指针 i 形成的字符串
# tk 是 FMM 从上一个同步点 _j 到当前指针 j 形成的字符串
tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
if tk1 != tk: # 如果这两个子串不相同说明FMM和BMM的切分路径出现分叉
# 尝试通过移动较短子串的指针来寻找下一个可能的同步点
if len(tk1) > len(tk):
j += 1
else:
i += 1
continue # 继续外层while循环
# 如果子串相同但当前位置的单个词元不同则这也是一个需要DFS解决的歧义点
if tks1[i] != tks[j]: # 注意这里比较的是tks1[i]和tks[j]而不是tk1和tk的最后一个词
i += 1
j += 1
continue
# 从_j到j (不包括j处的词) 这段 FMM 产生的文本是歧义的需要用DFS解决。
tkslist = []
self.dfs_("".join(tks[_j:j]), 0, [], tkslist) # 对这段FMM子串进行DFS
if tkslist: # 确保DFS有结果
res.append(" ".join(self.sortTks_(tkslist)[0][0])) # 取最优DFS结果
# 处理当前这个相同的词元 (tks[j] 或 tks1[i]) 以及之后连续相同的词元
same = 1
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
res.append(" ".join(tks[j : j + same])) # 将FMM中从j开始的连续相同部分加入结果
# 更新指针到下一个不同部分的开始
_i = i + same
_j = j + same
j = _j + 1
i = _i + 1
# 3.c. 处理 FMM 或 BMM 可能的尾部剩余部分
# 如果 _i (BMM的已处理指针) 还没有到达 tks1 的末尾
# (并且假设 _j (FMM的已处理指针) 也未到 tks 的末尾,且剩余部分代表相同的原始文本)
if _i < len(tks1):
# 断言确保FMM的已处理指针也未到末尾
assert _j < len(tks)
# 断言FMM和BMM的剩余部分代表相同的原始字符串
assert "".join(tks1[_i:]) == "".join(tks[_j:])
# 对FMM的剩余部分代表了原始文本的尾部进行DFS分词
tkslist = []
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
if tkslist: # 确保DFS有结果
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
# 4. 后处理
res_str = " ".join(res) # 将所有分词结果用空格连接
return self.merge_(res_str) # 返回经过合并处理的最终分词结果
def is_chinese(s):
if s >= "\u4e00" and s <= "\u9fa5":
return True
else:
return False
if __name__ == "__main__":
tknzr = RagTokenizer()
tks = tknzr.tokenize("基于动态视觉相机的光流估计研究_孙文义.pdf")
print(tks)
tks = tknzr.tokenize("图3-1 事件流输入表征。a事件帧b时间面c体素网格\n(a)\n(b)\n(c)")
print(tks)
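The trie above is built from a `huqie` frequency dictionary that is not part of this diff; the line below is a made-up example only to show the format `loadDict_` expects (term, raw frequency, POS tag separated by spaces or tabs) and how the stored weight relates to `freq()`.

```python
import math

DENOMINATOR = 1000000  # same constant as RagTokenizer.DENOMINATOR

line = "光流 3200 n"  # hypothetical dictionary entry: term, raw frequency, tag
term, raw_freq, tag = line.split()

# loadDict_ stores a log-scaled weight F in the trie for the term...
F = int(math.log(float(raw_freq) / DENOMINATOR) + 0.5)

# ...and freq() roughly inverts it when scoring candidate segmentations,
# so only the order of magnitude of the original count survives the round trip.
recovered = int(math.exp(F) * DENOMINATOR + 0.5)
print(F, recovered)
```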

View File

@ -7,6 +7,7 @@ import time
from datetime import datetime from datetime import datetime
from utils import generate_uuid from utils import generate_uuid
from database import DB_CONFIG from database import DB_CONFIG
# 解析相关模块 # 解析相关模块
from .document_parser import perform_parse, _update_document_progress from .document_parser import perform_parse, _update_document_progress
@ -14,15 +15,15 @@ from .document_parser import perform_parse, _update_document_progress
# 结构: { kb_id: {"status": "running/completed/failed", "total": N, "current": M, "message": "...", "start_time": timestamp} } # 结构: { kb_id: {"status": "running/completed/failed", "total": N, "current": M, "message": "...", "start_time": timestamp} }
SEQUENTIAL_BATCH_TASKS = {} SEQUENTIAL_BATCH_TASKS = {}
class KnowledgebaseService:
class KnowledgebaseService:
@classmethod @classmethod
def _get_db_connection(cls): def _get_db_connection(cls):
"""创建数据库连接""" """创建数据库连接"""
return mysql.connector.connect(**DB_CONFIG) return mysql.connector.connect(**DB_CONFIG)
@classmethod @classmethod
def get_knowledgebase_list(cls, page=1, size=10, name='', sort_by="create_time", sort_order="desc"): def get_knowledgebase_list(cls, page=1, size=10, name="", sort_by="create_time", sort_order="desc"):
"""获取知识库列表""" """获取知识库列表"""
conn = cls._get_db_connection() conn = cls._get_db_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
@ -57,7 +58,7 @@ class KnowledgebaseService:
query += f" {sort_clause}" query += f" {sort_clause}"
query += " LIMIT %s OFFSET %s" query += " LIMIT %s OFFSET %s"
params.extend([size, (page-1)*size]) params.extend([size, (page - 1) * size])
cursor.execute(query, params) cursor.execute(query, params)
results = cursor.fetchall() results = cursor.fetchall()
@ -65,33 +66,30 @@ class KnowledgebaseService:
# 处理结果 # 处理结果
for result in results: for result in results:
# 处理空描述 # 处理空描述
if not result.get('description'): if not result.get("description"):
result['description'] = "暂无描述" result["description"] = "暂无描述"
# 处理时间格式 # 处理时间格式
if result.get('create_date'): if result.get("create_date"):
if isinstance(result['create_date'], datetime): if isinstance(result["create_date"], datetime):
result['create_date'] = result['create_date'].strftime('%Y-%m-%d %H:%M:%S') result["create_date"] = result["create_date"].strftime("%Y-%m-%d %H:%M:%S")
elif isinstance(result['create_date'], str): elif isinstance(result["create_date"], str):
try: try:
# 尝试解析已有字符串格式 # 尝试解析已有字符串格式
datetime.strptime(result['create_date'], '%Y-%m-%d %H:%M:%S') datetime.strptime(result["create_date"], "%Y-%m-%d %H:%M:%S")
except ValueError: except ValueError:
result['create_date'] = "" result["create_date"] = ""
# 获取总数 # 获取总数
count_query = "SELECT COUNT(*) as total FROM knowledgebase" count_query = "SELECT COUNT(*) as total FROM knowledgebase"
if name: if name:
count_query += " WHERE name LIKE %s" count_query += " WHERE name LIKE %s"
cursor.execute(count_query, params[:1] if name else []) cursor.execute(count_query, params[:1] if name else [])
total = cursor.fetchone()['total'] total = cursor.fetchone()["total"]
cursor.close() cursor.close()
conn.close() conn.close()
return {"list": results, "total": total}
@classmethod @classmethod
def get_knowledgebase_detail(cls, kb_id): def get_knowledgebase_detail(cls, kb_id):
@ -115,17 +113,17 @@ class KnowledgebaseService:
if result: if result:
# 处理空描述 # 处理空描述
if not result.get('description'): if not result.get("description"):
result['description'] = "暂无描述" result["description"] = "暂无描述"
# 处理时间格式 # 处理时间格式
if result.get('create_date'): if result.get("create_date"):
if isinstance(result['create_date'], datetime): if isinstance(result["create_date"], datetime):
result['create_date'] = result['create_date'].strftime('%Y-%m-%d %H:%M:%S') result["create_date"] = result["create_date"].strftime("%Y-%m-%d %H:%M:%S")
elif isinstance(result['create_date'], str): elif isinstance(result["create_date"], str):
try: try:
datetime.strptime(result['create_date'], '%Y-%m-%d %H:%M:%S') datetime.strptime(result["create_date"], "%Y-%m-%d %H:%M:%S")
except ValueError: except ValueError:
result['create_date'] = "" result["create_date"] = ""
cursor.close() cursor.close()
conn.close() conn.close()
@ -157,7 +155,7 @@ class KnowledgebaseService:
try: try:
# 检查知识库名称是否已存在 # 检查知识库名称是否已存在
exists = cls._check_name_exists(data['name']) exists = cls._check_name_exists(data["name"])
if exists: if exists:
raise Exception("知识库名称已存在") raise Exception("知识库名称已存在")
@ -165,8 +163,8 @@ class KnowledgebaseService:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 使用传入的 creator_id 作为 tenant_id 和 created_by # 使用传入的 creator_id 作为 tenant_id 和 created_by
tenant_id = data.get('creator_id') tenant_id = data.get("creator_id")
created_by = data.get('creator_id') created_by = data.get("creator_id")
if not tenant_id: if not tenant_id:
# 如果没有提供 creator_id则使用默认值 # 如果没有提供 creator_id则使用默认值
@ -181,8 +179,8 @@ class KnowledgebaseService:
earliest_user = cursor.fetchone() earliest_user = cursor.fetchone()
if earliest_user: if earliest_user:
tenant_id = earliest_user['id'] tenant_id = earliest_user["id"]
created_by = earliest_user['id'] created_by = earliest_user["id"]
print(f"使用创建时间最早的用户ID作为tenant_id和created_by: {tenant_id}") print(f"使用创建时间最早的用户ID作为tenant_id和created_by: {tenant_id}")
else: else:
# 如果找不到用户,使用默认值 # 如果找不到用户,使用默认值
@ -196,10 +194,9 @@ class KnowledgebaseService:
else: else:
print(f"使用传入的 creator_id 作为 tenant_id 和 created_by: {tenant_id}") print(f"使用传入的 creator_id 作为 tenant_id 和 created_by: {tenant_id}")
# --- 获取动态 embd_id --- # --- 获取动态 embd_id ---
dynamic_embd_id = None dynamic_embd_id = None
default_embd_id = 'bge-m3' # Fallback default default_embd_id = "bge-m3" # Fallback default
try: try:
query_embedding_model = """ query_embedding_model = """
SELECT llm_name SELECT llm_name
@ -211,8 +208,8 @@ class KnowledgebaseService:
cursor.execute(query_embedding_model) cursor.execute(query_embedding_model)
embedding_model = cursor.fetchone() embedding_model = cursor.fetchone()
if embedding_model and embedding_model.get('llm_name'): if embedding_model and embedding_model.get("llm_name"):
dynamic_embd_id = embedding_model['llm_name'] dynamic_embd_id = embedding_model["llm_name"]
# 对硅基流动平台进行特异性处理 # 对硅基流动平台进行特异性处理
if dynamic_embd_id == "netease-youdao/bce-embedding-base_v1": if dynamic_embd_id == "netease-youdao/bce-embedding-base_v1":
dynamic_embd_id = "BAAI/bge-m3" dynamic_embd_id = "BAAI/bge-m3"
@ -223,10 +220,10 @@ class KnowledgebaseService:
except Exception as e: except Exception as e:
dynamic_embd_id = default_embd_id dynamic_embd_id = default_embd_id
print(f"查询 embedding 模型失败: {str(e)},使用默认值: {dynamic_embd_id}") print(f"查询 embedding 模型失败: {str(e)},使用默认值: {dynamic_embd_id}")
traceback.print_exc() # Log the full traceback for debugging traceback.print_exc() # Log the full traceback for debugging
current_time = datetime.now() current_time = datetime.now()
create_date = current_time.strftime('%Y-%m-%d %H:%M:%S') create_date = current_time.strftime("%Y-%m-%d %H:%M:%S")
create_time = int(current_time.timestamp() * 1000) # 毫秒级时间戳 create_time = int(current_time.timestamp() * 1000) # 毫秒级时间戳
update_date = create_date update_date = create_date
update_time = create_time update_time = create_time
@ -249,42 +246,47 @@ class KnowledgebaseService:
""" """
# 设置默认值 # 设置默认值
default_parser_config = json.dumps(
{
"layout_recognize": "MinerU",
"chunk_token_num": 512,
"delimiter": "\n!?;。;!?",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False,
"raptor": {"use_raptor": False},
"graphrag": {"use_graphrag": False},
}
)
kb_id = generate_uuid() kb_id = generate_uuid()
cursor.execute(
query,
(
kb_id,  # id
create_time,  # create_time
create_date,  # create_date
update_time,  # update_time
update_date,  # update_date
None,  # avatar
tenant_id,  # tenant_id
data["name"],  # name
data.get("language", "Chinese"),  # language
data.get("description", ""),  # description
dynamic_embd_id,  # embd_id
data.get("permission", "me"),  # permission
created_by,  # created_by - 使用内部获取的值
0,  # doc_num
0,  # token_num
0,  # chunk_num
0.7,  # similarity_threshold
0.3,  # vector_similarity_weight
"naive",  # parser_id
default_parser_config,  # parser_config
0,  # pagerank
"1",  # status
),
)
conn.commit() conn.commit()
cursor.close() cursor.close()
@ -310,8 +312,8 @@ class KnowledgebaseService:
cursor = conn.cursor() cursor = conn.cursor()
# 如果要更新名称,先检查名称是否已存在 # 如果要更新名称,先检查名称是否已存在
if data.get('name') and data['name'] != kb['name']: if data.get("name") and data["name"] != kb["name"]:
exists = cls._check_name_exists(data['name']) exists = cls._check_name_exists(data["name"])
if exists: if exists:
raise Exception("知识库名称已存在") raise Exception("知识库名称已存在")
@ -319,21 +321,21 @@ class KnowledgebaseService:
update_fields = [] update_fields = []
params = [] params = []
if data.get('name'): if data.get("name"):
update_fields.append("name = %s") update_fields.append("name = %s")
params.append(data['name']) params.append(data["name"])
if 'description' in data: if "description" in data:
update_fields.append("description = %s") update_fields.append("description = %s")
params.append(data['description']) params.append(data["description"])
if 'permission' in data: if "permission" in data:
update_fields.append("permission = %s") update_fields.append("permission = %s")
params.append(data['permission']) params.append(data["permission"])
# 更新时间 # 更新时间
current_time = datetime.now() current_time = datetime.now()
update_date = current_time.strftime('%Y-%m-%d %H:%M:%S') update_date = current_time.strftime("%Y-%m-%d %H:%M:%S")
update_fields.append("update_date = %s") update_fields.append("update_date = %s")
params.append(update_date) params.append(update_date)
@ -344,7 +346,7 @@ class KnowledgebaseService:
# 构建并执行更新语句 # 构建并执行更新语句
query = f""" query = f"""
UPDATE knowledgebase UPDATE knowledgebase
SET {', '.join(update_fields)} SET {", ".join(update_fields)}
WHERE id = %s WHERE id = %s
""" """
params.append(kb_id) params.append(kb_id)
@ -396,8 +398,7 @@ class KnowledgebaseService:
cursor = conn.cursor() cursor = conn.cursor()
# 检查所有ID是否存在 # 检查所有ID是否存在
check_query = "SELECT id FROM knowledgebase WHERE id IN (%s)" % \ check_query = "SELECT id FROM knowledgebase WHERE id IN (%s)" % ",".join(["%s"] * len(kb_ids))
','.join(['%s'] * len(kb_ids))
cursor.execute(check_query, kb_ids) cursor.execute(check_query, kb_ids)
existing_ids = [row[0] for row in cursor.fetchall()] existing_ids = [row[0] for row in cursor.fetchall()]
@ -406,8 +407,7 @@ class KnowledgebaseService:
raise Exception(f"以下知识库不存在: {', '.join(missing_ids)}") raise Exception(f"以下知识库不存在: {', '.join(missing_ids)}")
# 执行批量删除 # 执行批量删除
delete_query = "DELETE FROM knowledgebase WHERE id IN (%s)" % \ delete_query = "DELETE FROM knowledgebase WHERE id IN (%s)" % ",".join(["%s"] * len(kb_ids))
','.join(['%s'] * len(kb_ids))
cursor.execute(delete_query, kb_ids) cursor.execute(delete_query, kb_ids)
conn.commit() conn.commit()
@ -420,7 +420,7 @@ class KnowledgebaseService:
raise Exception(f"批量删除知识库失败: {str(e)}") raise Exception(f"批量删除知识库失败: {str(e)}")
@classmethod @classmethod
def get_knowledgebase_documents(cls, kb_id, page=1, size=10, name='', sort_by="create_time", sort_order="desc"): def get_knowledgebase_documents(cls, kb_id, page=1, size=10, name="", sort_by="create_time", sort_order="desc"):
"""获取知识库下的文档列表""" """获取知识库下的文档列表"""
try: try:
conn = cls._get_db_connection() conn = cls._get_db_connection()
@ -466,15 +466,15 @@ class KnowledgebaseService:
query += f" {sort_clause}" query += f" {sort_clause}"
query += " LIMIT %s OFFSET %s" query += " LIMIT %s OFFSET %s"
params.extend([size, (page-1)*size]) params.extend([size, (page - 1) * size])
cursor.execute(query, params) cursor.execute(query, params)
results = cursor.fetchall() results = cursor.fetchall()
# 处理日期时间格式 # 处理日期时间格式
for result in results: for result in results:
if result.get('create_date'): if result.get("create_date"):
result['create_date'] = result['create_date'].strftime('%Y-%m-%d %H:%M:%S') result["create_date"] = result["create_date"].strftime("%Y-%m-%d %H:%M:%S")
# 获取总数 # 获取总数
count_query = "SELECT COUNT(*) as total FROM document WHERE kb_id = %s" count_query = "SELECT COUNT(*) as total FROM document WHERE kb_id = %s"
@ -484,15 +484,12 @@ class KnowledgebaseService:
count_params.append(f"%{name}%") count_params.append(f"%{name}%")
cursor.execute(count_query, count_params) cursor.execute(count_query, count_params)
total = cursor.fetchone()['total'] total = cursor.fetchone()["total"]
cursor.close() cursor.close()
conn.close() conn.close()
return {"list": results, "total": total}
except Exception as e: except Exception as e:
print(f"获取知识库文档列表失败: {str(e)}") print(f"获取知识库文档列表失败: {str(e)}")
@ -519,10 +516,10 @@ class KnowledgebaseService:
earliest_user = cursor.fetchone() earliest_user = cursor.fetchone()
if earliest_user: if earliest_user:
created_by = earliest_user['id'] created_by = earliest_user["id"]
print(f"使用创建时间最早的用户ID: {created_by}") print(f"使用创建时间最早的用户ID: {created_by}")
else: else:
created_by = 'system' created_by = "system"
print("未找到用户, 使用默认用户ID: system") print("未找到用户, 使用默认用户ID: system")
cursor.close() cursor.close()
@ -543,7 +540,7 @@ class KnowledgebaseService:
SELECT id, name, location, size, type SELECT id, name, location, size, type
FROM file FROM file
WHERE id IN (%s) WHERE id IN (%s)
""" % ','.join(['%s'] * len(file_ids)) """ % ",".join(["%s"] * len(file_ids))
print(f"[DEBUG] 执行文件查询SQL: {file_query}") print(f"[DEBUG] 执行文件查询SQL: {file_query}")
print(f"[DEBUG] 查询参数: {file_ids}") print(f"[DEBUG] 查询参数: {file_ids}")
@ -592,20 +589,18 @@ class KnowledgebaseService:
# 设置默认值 # 设置默认值
default_parser_id = "naive" default_parser_id = "naive"
default_parser_config = json.dumps(
{
"layout_recognize": "MinerU",
"chunk_token_num": 512,
"delimiter": "\n!?;。;!?",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False,
"raptor": {"use_raptor": False},
"graphrag": {"use_graphrag": False},
}
)
default_source_type = "local" default_source_type = "local"
# 插入document表 # 插入document表
@ -626,11 +621,30 @@ class KnowledgebaseService:
""" """
doc_params = [
doc_id,
create_time,
current_date,
create_time,
current_date,  # ID和时间
None,
kb_id,
default_parser_id,
default_parser_config,
default_source_type,  # thumbnail到source_type
file_type,
created_by,
file_name,
file_location,
file_size,  # type到size
0,
0,
0.0,
None,
None,  # token_num到process_begin_at
0.0,
None,
"0",
"1",  # process_duation到status
]
cursor.execute(doc_query, doc_params) cursor.execute(doc_query, doc_params)
@ -647,10 +661,7 @@ class KnowledgebaseService:
) )
""" """
f2d_params = [f2d_id, create_time, current_date, create_time, current_date, file_id, doc_id]
cursor.execute(f2d_query, f2d_params) cursor.execute(f2d_query, f2d_params)
@ -673,14 +684,13 @@ class KnowledgebaseService:
cursor.close() cursor.close()
conn.close() conn.close()
return {"added_count": added_count}
except Exception as e: except Exception as e:
print(f"[ERROR] 添加文档失败: {str(e)}") print(f"[ERROR] 添加文档失败: {str(e)}")
print(f"[ERROR] 错误类型: {type(e)}") print(f"[ERROR] 错误类型: {type(e)}")
import traceback import traceback
print(f"[ERROR] 堆栈信息: {traceback.format_exc()}") print(f"[ERROR] 堆栈信息: {traceback.format_exc()}")
raise Exception(f"添加文档到知识库失败: {str(e)}") raise Exception(f"添加文档到知识库失败: {str(e)}")
@ -757,7 +767,7 @@ class KnowledgebaseService:
f2d_result = cursor.fetchone() f2d_result = cursor.fetchone()
if not f2d_result: if not f2d_result:
raise Exception("无法找到文件到文档的映射关系") raise Exception("无法找到文件到文档的映射关系")
file_id = f2d_result['file_id'] file_id = f2d_result["file_id"]
file_query = "SELECT parent_id FROM file WHERE id = %s" file_query = "SELECT parent_id FROM file WHERE id = %s"
cursor.execute(file_query, (file_id,)) cursor.execute(file_query, (file_id,))
@ -767,10 +777,10 @@ class KnowledgebaseService:
cursor.close() cursor.close()
conn.close() conn.close()
conn = None # 确保连接已关闭 conn = None # 确保连接已关闭
# 2. 更新文档状态为处理中 (使用 parser 模块的函数) # 2. 更新文档状态为处理中 (使用 parser 模块的函数)
_update_document_progress(doc_id, status='2', run='1', progress=0.0, message='开始解析') _update_document_progress(doc_id, status="2", run="1", progress=0.0, message="开始解析")
# 3. 调用后台解析函数 # 3. 调用后台解析函数
embedding_config = cls.get_system_embedding_config() embedding_config = cls.get_system_embedding_config()
@ -783,9 +793,9 @@ class KnowledgebaseService:
print(f"文档解析启动或执行过程中出错 (Doc ID: {doc_id}): {str(e)}") print(f"文档解析启动或执行过程中出错 (Doc ID: {doc_id}): {str(e)}")
# 确保在异常时更新状态为失败 # 确保在异常时更新状态为失败
try: try:
_update_document_progress(doc_id, status='1', run='0', message=f"解析失败: {str(e)}") _update_document_progress(doc_id, status="1", run="0", message=f"解析失败: {str(e)}")
except Exception as update_err: except Exception as update_err:
print(f"更新文档失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}") print(f"更新文档失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}")
# raise Exception(f"文档解析失败: {str(e)}") # raise Exception(f"文档解析失败: {str(e)}")
return {"success": False, "error": f"文档解析失败: {str(e)}"} return {"success": False, "error": f"文档解析失败: {str(e)}"}
@ -801,22 +811,21 @@ class KnowledgebaseService:
try: try:
# 启动后台线程执行同步的 parse_document 方法 # 启动后台线程执行同步的 parse_document 方法
thread = threading.Thread(target=cls.parse_document, args=(doc_id,)) thread = threading.Thread(target=cls.parse_document, args=(doc_id,))
thread.daemon = True # 设置为守护线程,主程序退出时线程也退出 thread.daemon = True # 设置为守护线程,主程序退出时线程也退出
thread.start() thread.start()
# 立即返回,表示任务已提交 # 立即返回,表示任务已提交
return {
"task_id": doc_id,  # 使用 doc_id 作为任务标识符
"status": "processing",
"message": "文档解析任务已提交到后台处理",
}
except Exception as e: except Exception as e:
print(f"启动异步解析任务失败 (Doc ID: {doc_id}): {str(e)}") print(f"启动异步解析任务失败 (Doc ID: {doc_id}): {str(e)}")
# 可以在这里尝试更新文档状态为失败
try: try:
_update_document_progress(doc_id, status='1', run='0', message=f"启动解析失败: {str(e)}") _update_document_progress(doc_id, status="1", run="0", message=f"启动解析失败: {str(e)}")
except Exception as update_err: except Exception as update_err:
print(f"更新文档启动失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}") print(f"更新文档启动失败状态时出错 (Doc ID: {doc_id}): {str(update_err)}")
raise Exception(f"启动异步解析任务失败: {str(e)}") raise Exception(f"启动异步解析任务失败: {str(e)}")
@classmethod @classmethod
@ -845,7 +854,7 @@ class KnowledgebaseService:
try: try:
progress_value = float(result["progress"]) progress_value = float(result["progress"])
except (ValueError, TypeError): except (ValueError, TypeError):
progress_value = 0.0 # 或记录错误 progress_value = 0.0 # 或记录错误
return { return {
"progress": progress_value, "progress": progress_value,
@ -876,7 +885,7 @@ class KnowledgebaseService:
cursor.execute(query) cursor.execute(query)
result = cursor.fetchone() result = cursor.fetchone()
if result: if result:
return result[0] # 返回用户 ID return result[0] # 返回用户 ID
else: else:
print("警告: 数据库中没有用户!") print("警告: 数据库中没有用户!")
return None return None
@ -904,26 +913,26 @@ class KnowledgebaseService:
payload = {"input": ["Test connection"], "model": model_name} payload = {"input": ["Test connection"], "model": model_name}
if not base_url.startswith(('http://', 'https://')): if not base_url.startswith(("http://", "https://")):
base_url = 'http://' + base_url base_url = "http://" + base_url
if not base_url.endswith('/'): if not base_url.endswith("/"):
base_url += '/' base_url += "/"
# --- URL 拼接优化 --- # --- URL 拼接优化 ---
endpoint_segment = "embeddings" endpoint_segment = "embeddings"
full_endpoint_path = "v1/embeddings" full_endpoint_path = "v1/embeddings"
# 移除末尾斜杠以方便判断 # 移除末尾斜杠以方便判断
normalized_base_url = base_url.rstrip('/') normalized_base_url = base_url.rstrip("/")
if normalized_base_url.endswith('/v1'): if normalized_base_url.endswith("/v1"):
# 如果 base_url 已经是 http://host/v1 形式 # 如果 base_url 已经是 http://host/v1 形式
current_test_url = normalized_base_url + '/' + endpoint_segment current_test_url = normalized_base_url + "/" + endpoint_segment
elif normalized_base_url.endswith('/embeddings'): elif normalized_base_url.endswith("/embeddings"):
# 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API无需再进行处理) # 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API无需再进行处理)
current_test_url = normalized_base_url current_test_url = normalized_base_url
else: else:
# 如果 base_url 是 http://host 或 http://host/api 形式 # 如果 base_url 是 http://host 或 http://host/api 形式
current_test_url = normalized_base_url + '/' + full_endpoint_path current_test_url = normalized_base_url + "/" + full_endpoint_path
# --- 结束 URL 拼接优化 --- # --- 结束 URL 拼接优化 ---
print(f"尝试请求 URL: {current_test_url}") print(f"尝试请求 URL: {current_test_url}")
@ -933,8 +942,9 @@ class KnowledgebaseService:
if response.status_code == 200: if response.status_code == 200:
res_json = response.json() res_json = response.json()
if ("data" in res_json and isinstance(res_json["data"], list) and len(res_json["data"]) > 0 and "embedding" in res_json["data"][0] and len(res_json["data"][0]["embedding"]) > 0) or \ if (
(isinstance(res_json, list) and len(res_json) > 0 and isinstance(res_json[0], list) and len(res_json[0]) > 0): "data" in res_json and isinstance(res_json["data"], list) and len(res_json["data"]) > 0 and "embedding" in res_json["data"][0] and len(res_json["data"][0]["embedding"]) > 0
) or (isinstance(res_json, list) and len(res_json) > 0 and isinstance(res_json[0], list) and len(res_json[0]) > 0):
print(f"连接测试成功: {current_test_url}") print(f"连接测试成功: {current_test_url}")
return True, "连接成功" return True, "连接成功"
else: else:
@ -958,7 +968,7 @@ class KnowledgebaseService:
cursor = None cursor = None
try: try:
conn = cls._get_db_connection() conn = cls._get_db_connection()
cursor = conn.cursor(dictionary=True) # 使用字典游标方便访问列名 cursor = conn.cursor(dictionary=True) # 使用字典游标方便访问列名
# 1. 找到最早创建的用户ID # 1. 找到最早创建的用户ID
query_earliest_user = """ query_earliest_user = """
@ -971,13 +981,9 @@ class KnowledgebaseService:
if not earliest_user: if not earliest_user:
# 如果没有用户,返回空配置 # 如果没有用户,返回空配置
return { return {"llm_name": "", "api_key": "", "api_base": ""}
"llm_name": "",
"api_key": "",
"api_base": ""
}
earliest_user_id = earliest_user['id'] earliest_user_id = earliest_user["id"]
# 2. 根据最早用户ID查询 tenant_llm 表中 model_type 为 embedding 的配置 # 2. 根据最早用户ID查询 tenant_llm 表中 model_type 为 embedding 的配置
query_embedding_config = """ query_embedding_config = """
@ -995,8 +1001,8 @@ class KnowledgebaseService:
api_key = config.get("api_key", "") api_key = config.get("api_key", "")
api_base = config.get("api_base", "") api_base = config.get("api_base", "")
# 对模型名称进行处理 (可选,根据需要保留或移除) # 对模型名称进行处理 (可选,根据需要保留或移除)
if llm_name and '___' in llm_name: if llm_name and "___" in llm_name:
llm_name = llm_name.split('___')[0] llm_name = llm_name.split("___")[0]
# (对硅基流动平台进行特异性处理) # (对硅基流动平台进行特异性处理)
if llm_name == "netease-youdao/bce-embedding-base_v1": if llm_name == "netease-youdao/bce-embedding-base_v1":
@ -1007,18 +1013,10 @@ class KnowledgebaseService:
api_base = "https://api.siliconflow.cn/v1/embeddings" api_base = "https://api.siliconflow.cn/v1/embeddings"
# 如果有配置,返回 # 如果有配置,返回
return { return {"llm_name": llm_name, "api_key": api_key, "api_base": api_base}
"llm_name": llm_name,
"api_key": api_key,
"api_base": api_base
}
else: else:
# 如果最早的用户没有 embedding 配置,返回空 # 如果最早的用户没有 embedding 配置,返回空
return { return {"llm_name": "", "api_key": "", "api_base": ""}
"llm_name": "",
"api_key": "",
"api_base": ""
}
except Exception as e: except Exception as e:
print(f"获取系统 Embedding 配置时出错: {e}") print(f"获取系统 Embedding 配置时出错: {e}")
traceback.print_exc() traceback.print_exc()
@ -1040,11 +1038,7 @@ class KnowledgebaseService:
print(f"开始设置系统 Embedding 配置: {llm_name}, {api_base}, {api_key}") print(f"开始设置系统 Embedding 配置: {llm_name}, {api_base}, {api_key}")
# 执行连接测试 # 执行连接测试
is_connected, message = cls._test_embedding_connection( is_connected, message = cls._test_embedding_connection(base_url=api_base, model_name=llm_name, api_key=api_key)
base_url=api_base,
model_name=llm_name,
api_key=api_key
)
if not is_connected: if not is_connected:
# 返回具体的测试失败原因给调用者(路由层)处理 # 返回具体的测试失败原因给调用者(路由层)处理
@ -1098,10 +1092,10 @@ class KnowledgebaseService:
# # 返回 False 和错误信息给路由层 # # 返回 False 和错误信息给路由层
# return False, f"保存配置时数据库出错: {e}" # return False, f"保存配置时数据库出错: {e}"
# finally: # finally:
# if cursor: # if cursor:
# cursor.close() # cursor.close()
# if conn and conn.is_connected(): # if conn and conn.is_connected():
# conn.close() # conn.close()
# 顺序批量解析 (核心逻辑,在后台线程运行) # 顺序批量解析 (核心逻辑,在后台线程运行)
@classmethod @classmethod
@ -1111,7 +1105,7 @@ class KnowledgebaseService:
task_info = SEQUENTIAL_BATCH_TASKS.get(kb_id) task_info = SEQUENTIAL_BATCH_TASKS.get(kb_id)
if not task_info: if not task_info:
print(f"[Seq Batch ERROR] Task info for KB {kb_id} not found at start.") print(f"[Seq Batch ERROR] Task info for KB {kb_id} not found at start.")
return # 理论上不应发生 return # 理论上不应发生
conn = None conn = None
cursor = None cursor = None
@ -1139,7 +1133,7 @@ class KnowledgebaseService:
task_info["message"] = f"共找到 {total_count} 个文档待解析。" task_info["message"] = f"共找到 {total_count} 个文档待解析。"
task_info["start_time"] = time.time() task_info["start_time"] = time.time()
start_time = time.time() start_time = time.time()
SEQUENTIAL_BATCH_TASKS[kb_id] = task_info # 更新字典 SEQUENTIAL_BATCH_TASKS[kb_id] = task_info # 更新字典
if not documents_to_parse: if not documents_to_parse:
task_info["status"] = "completed" task_info["status"] = "completed"
@ -1152,14 +1146,14 @@ class KnowledgebaseService:
# 按顺序解析每个文档 # 按顺序解析每个文档
for i, doc in enumerate(documents_to_parse): for i, doc in enumerate(documents_to_parse):
doc_id = doc['id'] doc_id = doc["id"]
doc_name = doc['name'] doc_name = doc["name"]
# 更新当前进度 # 更新当前进度
task_info["current"] = i + 1 task_info["current"] = i + 1
task_info["message"] = f"正在解析: {doc_name} ({i+1}/{total_count})" task_info["message"] = f"正在解析: {doc_name} ({i + 1}/{total_count})"
SEQUENTIAL_BATCH_TASKS[kb_id] = task_info SEQUENTIAL_BATCH_TASKS[kb_id] = task_info
print(f"[Seq Batch] KB {kb_id}: ({i+1}/{total_count}) Parsing {doc_name} (ID: {doc_id})...") print(f"[Seq Batch] KB {kb_id}: ({i + 1}/{total_count}) Parsing {doc_name} (ID: {doc_id})...")
try: try:
# 调用同步的 parse_document 方法 # 调用同步的 parse_document 方法
@ -1172,16 +1166,15 @@ class KnowledgebaseService:
failed_count += 1 failed_count += 1
error_msg = result.get("message", "未知错误") if result else "未知错误" error_msg = result.get("message", "未知错误") if result else "未知错误"
print(f"[Seq Batch] KB {kb_id}: Document {doc_id} parsing failed: {error_msg}") print(f"[Seq Batch] KB {kb_id}: Document {doc_id} parsing failed: {error_msg}")
# 即使单个失败,也继续处理下一个
except Exception as e: except Exception as e:
failed_count += 1 failed_count += 1
print(f"[Seq Batch ERROR] KB {kb_id}: Error calling parse_document for {doc_id}: {str(e)}") print(f"[Seq Batch ERROR] KB {kb_id}: Error calling parse_document for {doc_id}: {str(e)}")
traceback.print_exc() traceback.print_exc()
# 尝试更新文档状态为失败,以防 parse_document 内部未处理 # 更新文档状态为失败
try: try:
_update_document_progress(doc_id, status='1', run='0', progress=0.0, message=f"批量任务中解析失败: {str(e)[:255]}") _update_document_progress(doc_id, status="1", run="0", progress=0.0, message=f"批量任务中解析失败: {str(e)[:255]}")
except Exception as update_err: except Exception as update_err:
print(f"[Service-ERROR] 更新文档 {doc_id} 失败状态时出错: {str(update_err)}") print(f"[Service-ERROR] 更新文档 {doc_id} 失败状态时出错: {str(update_err)}")
# 任务完成 # 任务完成
end_time = time.time() end_time = time.time()
@ -1189,7 +1182,7 @@ class KnowledgebaseService:
final_message = f"批量顺序解析完成。总计 {total_count} 个,成功 {parsed_count} 个,失败 {failed_count} 个。耗时 {duration} 秒。" final_message = f"批量顺序解析完成。总计 {total_count} 个,成功 {parsed_count} 个,失败 {failed_count} 个。耗时 {duration} 秒。"
task_info["status"] = "completed" task_info["status"] = "completed"
task_info["message"] = final_message task_info["message"] = final_message
task_info["current"] = total_count # 确保 current 等于 total task_info["current"] = total_count
SEQUENTIAL_BATCH_TASKS[kb_id] = task_info SEQUENTIAL_BATCH_TASKS[kb_id] = task_info
print(f"[Seq Batch] KB {kb_id}: {final_message}") print(f"[Seq Batch] KB {kb_id}: {final_message}")
@ -1217,13 +1210,7 @@ class KnowledgebaseService:
# 初始化任务状态 # 初始化任务状态
start_time = time.time() start_time = time.time()
SEQUENTIAL_BATCH_TASKS[kb_id] = { SEQUENTIAL_BATCH_TASKS[kb_id] = {"status": "starting", "total": 0, "current": 0, "message": "任务准备启动...", "start_time": start_time}
"status": "starting",
"total": 0,
"current": 0,
"message": "任务准备启动...",
"start_time": start_time
}
try: try:
# 启动后台线程执行顺序解析逻辑 # 启动后台线程执行顺序解析逻辑
@ -1239,13 +1226,7 @@ class KnowledgebaseService:
print(f"[Seq Batch ERROR] KB {kb_id}: {error_message}") print(f"[Seq Batch ERROR] KB {kb_id}: {error_message}")
traceback.print_exc() traceback.print_exc()
# 更新任务状态为失败 # 更新任务状态为失败
SEQUENTIAL_BATCH_TASKS[kb_id] = { SEQUENTIAL_BATCH_TASKS[kb_id] = {"status": "failed", "total": 0, "current": 0, "message": error_message, "start_time": start_time}
"status": "failed",
"total": 0,
"current": 0,
"message": error_message,
"start_time": start_time
}
return {"success": False, "message": error_message} return {"success": False, "message": error_message}
# 获取顺序批量解析进度 # 获取顺序批量解析进度
@ -1294,10 +1275,7 @@ class KnowledgebaseService:
doc["status"] = doc.get("status", "0") doc["status"] = doc.get("status", "0")
doc["run"] = doc.get("run", "0") doc["run"] = doc.get("run", "0")
return {"documents": documents_status}
return {
"documents": documents_status
}
except Exception as e: except Exception as e:
print(f"获取知识库 {kb_id} 文档进度失败: {str(e)}") print(f"获取知识库 {kb_id} 文档进度失败: {str(e)}")

View File

@ -35,22 +35,19 @@ def beAdoc(d, q, a, eng, row_num=-1):
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
""" """
Excel and csv(txt) format files are supported. Excel and csv(txt) format files are supported.
If the file is in excel format, there should be 2 column content and tags without header. If the file is in excel format, there should be 2 column content and tags without header.
And content column is ahead of tags column. And content column is ahead of tags column.
And it's O.K if it has multiple sheets as long as the columns are rightly composed. And it's O.K if it has multiple sheets as long as the columns are rightly composed.
If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate content and tags. If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate content and tags.
All the deformed lines will be ignored. All the deformed lines will be ignored.
Every pair will be treated as a chunk. Every pair will be treated as a chunk.
""" """
eng = lang.lower() == "english" eng = lang.lower() == "english"
res = [] res = []
doc = { doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
if re.search(r"\.xlsx?$", filename, re.IGNORECASE): if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
excel_parser = Excel() excel_parser = Excel()
@ -83,11 +80,9 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
content = "" content = ""
i += 1 i += 1
if len(res) % 999 == 0: if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + ( callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract TAG: {}".format(len(res)) + ( callback(0.6, ("Extract TAG: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res return res
@ -110,40 +105,61 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
res.append(beAdoc(deepcopy(doc), content, row[1], eng, i)) res.append(beAdoc(deepcopy(doc), content, row[1], eng, i))
content = "" content = ""
if len(res) % 999 == 0: if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + ( callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract TAG : {}".format(len(res)) + ( callback(0.6, ("Extract TAG : {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res return res
raise NotImplementedError( raise NotImplementedError("Excel, csv(txt) format files are supported.")
"Excel, csv(txt) format files are supported.")
def label_question(question, kbs): def label_question(question, kbs):
"""
标记问题的标签
该函数通过给定的问题和知识库列表对问题进行标签标记。它首先确定哪些知识库配置了标签,
然后从缓存中获取这些标签(必要时从设置中检索并写入缓存),最后使用这些标签对问题进行标记。
参数:
question (str): 需要标记的问题
kbs (list): 知识库对象列表,用于标签标记
返回:
list: 与问题相关的标签列表
"""
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from graphrag.utils import get_tags_from_cache, set_tags_to_cache from graphrag.utils import get_tags_from_cache, set_tags_to_cache
from api import settings from api import settings
# 初始化标签和标签知识库ID列表
tags = None tags = None
tag_kb_ids = [] tag_kb_ids = []
# 遍历知识库收集所有标签知识库ID
for kb in kbs: for kb in kbs:
if kb.parser_config.get("tag_kb_ids"): if kb.parser_config.get("tag_kb_ids"):
tag_kb_ids.extend(kb.parser_config["tag_kb_ids"]) tag_kb_ids.extend(kb.parser_config["tag_kb_ids"])
# 如果存在标签知识库ID则进一步处理
if tag_kb_ids: if tag_kb_ids:
# 尝试从缓存中获取所有标签
all_tags = get_tags_from_cache(tag_kb_ids) all_tags = get_tags_from_cache(tag_kb_ids)
# 如果缓存中没有标签,从设置中检索标签,并设置缓存
if not all_tags: if not all_tags:
all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids) all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
set_tags_to_cache(all_tags, tag_kb_ids) set_tags_to_cache(all_tags, tag_kb_ids)
else: else:
# 如果缓存中获取到标签将其解析为JSON格式
all_tags = json.loads(all_tags) all_tags = json.loads(all_tags)
# 根据标签知识库ID获取对应的标签知识库
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids) tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
tags = settings.retrievaler.tag_query(question,
list(set([kb.tenant_id for kb in tag_kbs])), # 使用设置中的检索器对问题进行标签标记
tag_kb_ids, tags = settings.retrievaler.tag_query(question, list(set([kb.tenant_id for kb in tag_kbs])), tag_kb_ids, all_tags, kb.parser_config.get("topn_tags", 3))
all_tags,
kb.parser_config.get("topn_tags", 3) # 返回标记的标签
)
return tags return tags
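The flow documented above is a cache-aside lookup: collect `tag_kb_ids` from the knowledge bases, try the cache, fall back to the retriever and write the cache on a miss, otherwise parse the cached JSON. A stripped-down sketch with a plain dict standing in for the real cache backend; the names here are illustrative, not the project's API.

```python
import json

_CACHE = {}  # stand-in for the real tag cache


def get_tags_from_cache(key):
    return _CACHE.get(key)


def set_tags_to_cache(tags, key):
    _CACHE[key] = json.dumps(tags, ensure_ascii=False)


def load_all_tags(key, compute):
    """Cache-aside: return cached tags, computing and caching them on a miss."""
    cached = get_tags_from_cache(key)
    if not cached:
        tags = compute()          # e.g. retrievaler.all_tags_in_portion(...)
        set_tags_to_cache(tags, key)
        return tags
    return json.loads(cached)


if __name__ == "__main__":
    key = "kb1,kb2"
    print(load_all_tags(key, lambda: {"部署": 3, "检索": 5}))  # miss -> compute and cache
    print(load_all_tags(key, lambda: {"ignored": 0}))          # hit -> cached value wins
```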
@ -152,4 +168,5 @@ if __name__ == "__main__":
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@ -53,6 +53,16 @@ class FulltextQueryer:
@staticmethod @staticmethod
def rmWWW(txt): def rmWWW(txt):
"""
移除文本中的 WWW(WHAT、WHO、WHERE 等疑问词)
本函数通过一系列正则表达式模式来识别并替换文本中的疑问词,以简化文本或为后续处理做准备。
参数:
- txt: 待处理的文本字符串
返回:
- 处理后的文本字符串;如果所有疑问词都被移除导致文本为空,则返回原始文本
"""
patts = [ patts = [
( (
r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*", r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
@ -61,7 +71,8 @@ class FulltextQueryer:
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "), (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
( (
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ", r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
" ") " ",
),
] ]
otxt = txt otxt = txt
for r, p in patts: for r, p in patts:
@ -70,28 +81,53 @@ class FulltextQueryer:
txt = otxt txt = otxt
return txt return txt
def question(self, txt, tbl="qa", min_match: float = 0.6): @staticmethod
def add_space_between_eng_zh(txt):
""" """
处理用户问题并生成全文检索表达式 在英文和中文之间添加空格
该函数通过正则表达式匹配文本中英文和中文相邻的情况,并在它们之间插入空格。
这样做可以改善文本的可读性,特别是在混合使用英文和中文时。
参数: 参数:
txt: 原始问题文本 txt (str): 需要处理的文本字符串
tbl: 查询表名(默认"qa")
min_match: 最小匹配阈值(默认0.6)
返回: 返回:
MatchTextExpr: 全文检索表达式对象 str: 处理后的文本字符串,其中英文和中文之间添加了空格
list: 提取的关键词列表
""" """
# 1. 文本预处理:去除特殊字符、繁体转简体、全角转半角、转小写 # (ENG/ENG+NUM) + ZH
txt = re.sub(r"([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)", r"\1 \2", txt)
# ENG + ZH
txt = re.sub(r"([A-Za-z])([\u4e00-\u9fa5]+)", r"\1 \2", txt)
# ZH + (ENG/ENG+NUM)
txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)", r"\1 \2", txt)
txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z])", r"\1 \2", txt)
return txt
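`add_space_between_eng_zh` is the new pre-processing step inserted before query parsing: four substitutions that add a space wherever Latin text (optionally followed by digits) touches CJK text, so mixed queries split into separate keywords. A standalone demo of the same four patterns:

```python
import re


def add_space_between_eng_zh(txt: str) -> str:
    # Same four substitutions as in the queryer above.
    txt = re.sub(r"([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)", r"\1 \2", txt)  # ENG+NUM then ZH
    txt = re.sub(r"([A-Za-z])([\u4e00-\u9fa5]+)", r"\1 \2", txt)         # ENG then ZH
    txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)", r"\1 \2", txt)  # ZH then ENG+NUM
    txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z])", r"\1 \2", txt)         # ZH then ENG
    return txt


if __name__ == "__main__":
    print(add_space_between_eng_zh("下载Python3教程"))    # 下载 Python3 教程
    print(add_space_between_eng_zh("使用bge模型做召回"))  # 使用 bge 模型做召回
```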
def question(self, txt, tbl="qa", min_match: float = 0.6):
"""
根据输入的文本生成查询表达式,用于在数据库中匹配相关问题
参数: 参数:
txt: 原始问题文本 - txt (str): 输入的文本
tbl: 查询表名(默认"qa") - tbl (str): 数据表名,默认为"qa"
min_match: 最小匹配阈值(默认0.6) - min_match (float): 最小匹配度,默认为0.6
返回:
- MatchTextExpr: 生成的查询表达式对象
- keywords (list): 提取的关键词列表
"""
txt = FulltextQueryer.add_space_between_eng_zh(txt) # 在英文和中文之间添加空格
# 使用正则表达式替换特殊字符为单个空格,并将文本转换为简体中文和小写
txt = re.sub( txt = re.sub(
r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+", r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
" ", " ",
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())), rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
).strip() ).strip()
txt = FulltextQueryer.rmWWW(txt) # 去除停用词 otxt = txt
txt = FulltextQueryer.rmWWW(txt)
# 2. 非中文文本处理 # 如果文本不是中文,则进行英文处理
if not self.isChinese(txt): if not self.isChinese(txt):
txt = FulltextQueryer.rmWWW(txt) txt = FulltextQueryer.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split() tks = rag_tokenizer.tokenize(txt).split()
@ -106,11 +142,10 @@ class FulltextQueryer:
syn = self.syn.lookup(tk) syn = self.syn.lookup(tk)
syn = rag_tokenizer.tokenize(" ".join(syn)).split() syn = rag_tokenizer.tokenize(" ".join(syn)).split()
keywords.extend(syn) keywords.extend(syn)
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()] syn = ['"{}"^{:.4f}'.format(s, w / 4.0) for s in syn if s.strip()]
syns.append(" ".join(syn)) syns.append(" ".join(syn))
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if tk and not re.match(r"[.^+\(\)-]", tk)]
tk and not re.match(r"[.^+\(\)-]", tk)]
for i in range(1, len(tks_w)): for i in range(1, len(tks_w)):
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip() left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
if not left or not right: if not left or not right:
@ -126,48 +161,53 @@ class FulltextQueryer:
if not q: if not q:
q.append(txt) q.append(txt)
query = " ".join(q) query = " ".join(q)
return MatchTextExpr( return MatchTextExpr(self.query_fields, query, 100), keywords
self.query_fields, query, 100
), keywords
def need_fine_grained_tokenize(tk): def need_fine_grained_tokenize(tk):
""" """
判断是否需要细粒度分词 判断是否需要对词进行细粒度分词
参数: 参数:
tk: 待判断的词条 - tk (str): 待判断的词
返回: 返回:
bool: True表示需要细粒度分词 - bool: 是否需要进行细粒度分词
""" """
# 长度小于3的词不处理
if len(tk) < 3: if len(tk) < 3:
return False return False
# 匹配特定模式的词不处理(如数字、字母、符号组合)
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk): if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
return False return False
return True return True
txt = FulltextQueryer.rmWWW(txt) # 二次去除停用词 txt = FulltextQueryer.rmWWW(txt)
qs, keywords = [], [] # 初始化查询表达式和关键词列表 qs, keywords = [], []
# 3. 中文文本处理最多处理256个词 # 遍历文本分割后的前256个片段防止处理过长文本
for tt in self.tw.split(txt)[:256]: # .split(): for tt in self.tw.split(txt)[:256]: # 这个split似乎是对英文设计中文不起作用
if not tt: if not tt:
continue continue
# 3.1 基础关键词收集 # 将当前片段加入关键词列表
keywords.append(tt) keywords.append(tt)
twts = self.tw.weights([tt]) # 获取词权重 # 获取当前片段的权重
syns = self.syn.lookup(tt) # 查询同义词 twts = self.tw.weights([tt])
# 3.2 同义词扩展最多扩展到32个关键词 # 查找同义词
syns = self.syn.lookup(tt)
# 如果有同义词且关键词数量未超过32将同义词加入关键词列表
if syns and len(keywords) < 32: if syns and len(keywords) < 32:
keywords.extend(syns) keywords.extend(syns)
# 调试日志:输出权重信息
logging.debug(json.dumps(twts, ensure_ascii=False)) logging.debug(json.dumps(twts, ensure_ascii=False))
# 初始化查询条件列表
tms = [] tms = []
# 3.3 处理每个词及其权重 # 按权重降序排序处理每个token
for tk, w in sorted(twts, key=lambda x: x[1] * -1): for tk, w in sorted(twts, key=lambda x: x[1] * -1):
# 3.3.1 细粒度分词处理 # 如果需要细粒度分词,则进行分词处理
sm = ( sm = rag_tokenizer.fine_grained_tokenize(tk).split() if need_fine_grained_tokenize(tk) else []
rag_tokenizer.fine_grained_tokenize(tk).split() # 对每个分词结果进行清洗:
if need_fine_grained_tokenize(tk) # 1. 去除标点符号和特殊字符
else [] # 2. 使用subSpecialChar进一步处理
) # 3. 过滤掉长度<=1的词
# 3.3.2 清洗分词结果
sm = [ sm = [
re.sub( re.sub(
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+", r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
@ -178,59 +218,65 @@ class FulltextQueryer:
] ]
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1] sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1] sm = [m for m in sm if len(m) > 1]
# 3.3.3 收集关键词不超过32个
if len(keywords) < 32:
keywords.append(re.sub(r"[ \\\"']+", "", tk))
keywords.extend(sm)
# 3.3.4 同义词处理 # 如果关键词数量未达上限添加处理后的token和分词结果
if len(keywords) < 32:
keywords.append(re.sub(r"[ \\\"']+", "", tk)) # 去除转义字符
keywords.extend(sm) # 添加分词结果
# 获取当前token的同义词并进行处理
tk_syns = self.syn.lookup(tk) tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns] tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
# 添加有效同义词到关键词列表
if len(keywords) < 32: if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s]) keywords.extend([s for s in tk_syns if s])
# 对同义词进行分词处理,并为包含空格的同义词添加引号
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] tk_syns = [f'"{s}"' if s.find(" ") > 0 else s for s in tk_syns]
# 关键词数量限制
# 关键词数量达到上限则停止处理
if len(keywords) >= 32: if len(keywords) >= 32:
break break
# 3.3.5 构建查询表达式 # 处理当前token用于构建查询条件
# 1. 特殊字符处理
# 2. 为包含空格的token添加引号
# 3. 如果有同义词构建OR条件并降低权重
# 4. 如果有分词结果添加OR条件
tk = FulltextQueryer.subSpecialChar(tk) tk = FulltextQueryer.subSpecialChar(tk)
if tk.find(" ") > 0: if tk.find(" ") > 0:
tk = '"%s"' % tk # 处理短语查询 tk = '"%s"' % tk
if tk_syns: if tk_syns:
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns) # 添加同义词查询 tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
if sm: if sm:
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm)) # 添加细粒度分词查询 tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
if tk.strip(): if tk.strip():
tms.append((tk, w)) # 保存带权重的查询表达式 tms.append((tk, w))
# 3.4 合并当前词的查询表达式 # 将处理后的查询条件按权重组合成字符串
tms = " ".join([f"({t})^{w}" for t, w in tms]) tms = " ".join([f"({t})^{w}" for t, w in tms])
# 3.5 添加相邻词组合查询(提升短语匹配权重) # 如果有多个权重项,添加短语搜索条件(提高相邻词匹配的权重)
if len(twts) > 1: if len(twts) > 1:
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt) tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
# 3.6 处理同义词查询表达式 # 处理同义词的查询条件
syns = " OR ".join( syns = " OR ".join(['"%s"' % rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s)) for s in syns])
[ # 组合主查询条件和同义词条件
'"%s"'
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
for s in syns
]
)
if syns and tms: if syns and tms:
tms = f"({tms})^5 OR ({syns})^0.7" tms = f"({tms})^5 OR ({syns})^0.7"
# 将最终查询条件加入列表
qs.append(tms)
qs.append(tms) # 添加到最终查询列表 # 处理所有查询条件
# 4. 生成最终查询表达式
if qs: if qs:
# 组合所有查询条件为OR关系
query = " OR ".join([f"({t})" for t in qs if t]) query = " OR ".join([f"({t})" for t in qs if t])
return MatchTextExpr( # 如果查询条件为空,使用原始文本
self.query_fields, query, 100, {"minimum_should_match": min_match} if not query:
), keywords query = otxt
# 返回匹配文本表达式和关键词
return MatchTextExpr(self.query_fields, query, 100, {"minimum_should_match": min_match}), keywords
# 如果没有生成查询条件,只返回关键词
return None, keywords return None, keywords
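The visible behavioral change in this hunk is the `otxt` fallback: the cleaned question is kept aside, and if stop-word removal plus token filtering leave an empty query string, the original text is used instead, so keyword recall no longer silently matches nothing. A reduced sketch of that guard; the weighting and synonym logic is omitted, and `STOPWORDS` here is an illustrative list, not the project's `rmWWW` patterns.

```python
import re

STOPWORDS = {"是", "吗", "呢", "请问", "什么", "的"}  # illustrative only


def build_query(question: str) -> str:
    """Build a naive OR query, falling back to the original text when filtering empties it."""
    otxt = question.strip()
    tokens = [t for t in re.split(r"\s+", otxt) if t and t not in STOPWORDS]
    # In the real queryer each token also gets a weight and synonym expansion.
    query = " OR ".join(f"({t})" for t in tokens)
    if not query:  # everything was filtered out, keep recall working
        query = otxt
    return query


if __name__ == "__main__":
    print(build_query("RAGFlow 部署 报错"))  # (RAGFlow) OR (部署) OR (报错)
    print(build_query("是 什么 呢"))         # falls back to the original text
```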
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7): def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
@ -282,7 +328,7 @@ class FulltextQueryer:
tk_syns = self.syn.lookup(tk) tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns] tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] tk_syns = [f'"{s}"' if s.find(" ") > 0 else s for s in tk_syns]
tk = FulltextQueryer.subSpecialChar(tk) tk = FulltextQueryer.subSpecialChar(tk)
if tk.find(" ") > 0: if tk.find(" ") > 0:
tk = '"%s"' % tk tk = '"%s"' % tk
@ -291,5 +337,4 @@ class FulltextQueryer:
if tk: if tk:
keywords.append(f"{tk}^{w}") keywords.append(f"{tk}^{w}")
return MatchTextExpr(self.query_fields, " ".join(keywords), 100, return MatchTextExpr(self.query_fields, " ".join(keywords), 100, {"minimum_should_match": min(3, len(keywords) / 10)})
{"minimum_should_match": min(3, len(keywords) / 10)})

View File

@ -22,9 +22,12 @@ import os
import re import re
import string import string
import sys import sys
from pathlib import Path
from hanziconv import HanziConv from hanziconv import HanziConv
from nltk import word_tokenize from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer from nltk.stem import PorterStemmer, WordNetLemmatizer
sys.path.append(str(Path(__file__).parent.parent.parent))
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
@ -38,7 +41,7 @@ class RagTokenizer:
def loadDict_(self, fnm): def loadDict_(self, fnm):
logging.info(f"[HUQIE]:Build trie from {fnm}") logging.info(f"[HUQIE]:Build trie from {fnm}")
try: try:
of = open(fnm, "r", encoding='utf-8') of = open(fnm, "r", encoding="utf-8")
while True: while True:
line = of.readline() line = of.readline()
if not line: if not line:
@ -46,7 +49,7 @@ class RagTokenizer:
line = re.sub(r"[\r\n]+", "", line) line = re.sub(r"[\r\n]+", "", line)
line = re.split(r"[ \t]", line) line = re.split(r"[ \t]", line)
k = self.key_(line[0]) k = self.key_(line[0])
F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5) F = int(math.log(float(line[1]) / self.DENOMINATOR) + 0.5)
if k not in self.trie_ or self.trie_[k][0] < F: if k not in self.trie_ or self.trie_[k][0] < F:
self.trie_[self.key_(line[0])] = (F, line[2]) self.trie_[self.key_(line[0])] = (F, line[2])
self.trie_[self.rkey_(line[0])] = 1 self.trie_[self.rkey_(line[0])] = 1
@ -106,8 +109,8 @@ class RagTokenizer:
if inside_code == 0x3000: if inside_code == 0x3000:
inside_code = 0x0020 inside_code = 0x0020
else: else:
inside_code -= 0xfee0 inside_code -= 0xFEE0
if inside_code < 0x0020 or inside_code > 0x7e: # After the conversion, if it's not a half-width character, return the original character. if inside_code < 0x0020 or inside_code > 0x7E: # After the conversion, if it's not a half-width character, return the original character.
rstring += uchar rstring += uchar
else: else:
rstring += chr(inside_code) rstring += chr(inside_code)
@ -126,13 +129,11 @@ class RagTokenizer:
# pruning # pruning
S = s + 1 S = s + 1
if s + 2 <= len(chars): if s + 2 <= len(chars):
t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2]) t1, t2 = "".join(chars[s : s + 1]), "".join(chars[s : s + 2])
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix( if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
self.key_(t2)):
S = s + 2 S = s + 2
if len(preTks) > 2 and len( if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1: t1 = preTks[-1][0] + "".join(chars[s : s + 1])
t1 = preTks[-1][0] + "".join(chars[s:s + 1])
if self.trie_.has_keys_with_prefix(self.key_(t1)): if self.trie_.has_keys_with_prefix(self.key_(t1)):
S = s + 2 S = s + 2
@ -149,18 +150,18 @@ class RagTokenizer:
if k in self.trie_: if k in self.trie_:
pretks.append((t, self.trie_[k])) pretks.append((t, self.trie_[k]))
else: else:
pretks.append((t, (-12, ''))) pretks.append((t, (-12, "")))
res = max(res, self.dfs_(chars, e, pretks, tkslist)) res = max(res, self.dfs_(chars, e, pretks, tkslist))
if res > s: if res > s:
return res return res
t = "".join(chars[s:s + 1]) t = "".join(chars[s : s + 1])
k = self.key_(t) k = self.key_(t)
if k in self.trie_: if k in self.trie_:
preTks.append((t, self.trie_[k])) preTks.append((t, self.trie_[k]))
else: else:
preTks.append((t, (-12, ''))) preTks.append((t, (-12, "")))
return self.dfs_(chars, s + 1, preTks, tkslist) return self.dfs_(chars, s + 1, preTks, tkslist)
@ -183,7 +184,7 @@ class RagTokenizer:
F += freq F += freq
L += 0 if len(tk) < 2 else 1 L += 0 if len(tk) < 2 else 1
tks.append(tk) tks.append(tk)
#F /= len(tks) # F /= len(tks)
L /= len(tks) L /= len(tks)
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F)) logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
return tks, B / len(tks) + L + F return tks, B / len(tks) + L + F
@ -219,8 +220,7 @@ class RagTokenizer:
while s < len(line): while s < len(line):
e = s + 1 e = s + 1
t = line[s:e] t = line[s:e]
while e < len(line) and self.trie_.has_keys_with_prefix( while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
self.key_(t)):
e += 1 e += 1
t = line[s:e] t = line[s:e]
@ -231,7 +231,7 @@ class RagTokenizer:
if self.key_(t) in self.trie_: if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)])) res.append((t, self.trie_[self.key_(t)]))
else: else:
res.append((t, (0, ''))) res.append((t, (0, "")))
s = e s = e
@ -254,7 +254,7 @@ class RagTokenizer:
if self.key_(t) in self.trie_: if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)])) res.append((t, self.trie_[self.key_(t)]))
else: else:
res.append((t, (0, ''))) res.append((t, (0, "")))
s -= 1 s -= 1
@ -277,13 +277,13 @@ class RagTokenizer:
if _zh == zh: if _zh == zh:
e += 1 e += 1
continue continue
txt_lang_pairs.append((a[s: e], zh)) txt_lang_pairs.append((a[s:e], zh))
s = e s = e
e = s + 1 e = s + 1
zh = _zh zh = _zh
if s >= len(a): if s >= len(a):
continue continue
txt_lang_pairs.append((a[s: e], zh)) txt_lang_pairs.append((a[s:e], zh))
return txt_lang_pairs return txt_lang_pairs
def tokenize(self, line): def tokenize(self, line):
@ -293,12 +293,11 @@ class RagTokenizer:
arr = self._split_by_lang(line) arr = self._split_by_lang(line)
res = [] res = []
for L,lang in arr: for L, lang in arr:
if not lang: if not lang:
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)]) res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
continue continue
if len(L) < 2 or re.match( if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
res.append(L) res.append(L)
continue continue
@ -314,7 +313,7 @@ class RagTokenizer:
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]: while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1 same += 1
if same > 0: if same > 0:
res.append(" ".join(tks[j: j + same])) res.append(" ".join(tks[j : j + same]))
_i = i + same _i = i + same
_j = j + same _j = j + same
j = _j + 1 j = _j + 1
@ -341,7 +340,7 @@ class RagTokenizer:
same = 1 same = 1
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]: while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1 same += 1
res.append(" ".join(tks[j: j + same])) res.append(" ".join(tks[j : j + same]))
_i = i + same _i = i + same
_j = j + same _j = j + same
j = _j + 1 j = _j + 1
@ -359,31 +358,64 @@ class RagTokenizer:
return self.merge_(res) return self.merge_(res)
def fine_grained_tokenize(self, tks): def fine_grained_tokenize(self, tks):
"""
细粒度分词方法:根据文本特征(中英文比例、数字、符号等)动态选择分词策略
参数:
tks (str): 待分词的文本字符串
返回:
str: 分词后的结果,用空格连接的词序列
处理逻辑:
1. 先按空格初步切分文本
2. 根据中文占比决定是否启用细粒度分词
3. 对特殊格式(短词、纯数字等)直接保留原样
4. 对长词或复杂词使用DFS回溯算法寻找最优切分
5. 对英文词进行额外校验和规范化处理
"""
# 初始切分:按空格分割输入文本
tks = tks.split() tks = tks.split()
# 计算中文词占比(判断是否主要包含中文内容)
zh_num = len([1 for c in tks if c and is_chinese(c[0])]) zh_num = len([1 for c in tks if c and is_chinese(c[0])])
# 如果中文占比低于20%,则按简单规则处理(主要处理英文混合文本)
if zh_num < len(tks) * 0.2: if zh_num < len(tks) * 0.2:
res = [] res = []
for tk in tks: for tk in tks:
res.extend(tk.split("/")) res.extend(tk.split("/"))
return " ".join(res) return " ".join(res)
# 中文或复杂文本处理流程
res = [] res = []
for tk in tks: for tk in tks:
# 规则1跳过短词长度<3或纯数字/符号组合(如"3.14"
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk): if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
res.append(tk) res.append(tk)
continue continue
# 初始化候选分词列表
tkslist = [] tkslist = []
# 规则2超长词长度>10直接保留不切分
if len(tk) > 10: if len(tk) > 10:
tkslist.append(tk) tkslist.append(tk)
else: else:
# 使用DFS回溯算法寻找所有可能的分词组合
self.dfs_(tk, 0, [], tkslist) self.dfs_(tk, 0, [], tkslist)
# 规则3若无有效切分方案则保留原词
if len(tkslist) < 2: if len(tkslist) < 2:
res.append(tk) res.append(tk)
continue continue
# 从候选方案中选择最优切分通过sortTks_排序
stk = self.sortTks_(tkslist)[1][0] stk = self.sortTks_(tkslist)[1][0]
# 规则4若切分结果与原词长度相同则视为无效切分
if len(stk) == len(tk): if len(stk) == len(tk):
stk = tk stk = tk
else: else:
# 英文特殊处理:检查子词长度是否合法
if re.match(r"[a-z\.-]+$", tk): if re.match(r"[a-z\.-]+$", tk):
for t in stk: for t in stk:
if len(t) < 3: if len(t) < 3:
@ -393,29 +425,28 @@ class RagTokenizer:
stk = " ".join(stk) stk = " ".join(stk)
else: else:
stk = " ".join(stk) stk = " ".join(stk)
# 中文词直接拼接结果
res.append(stk) res.append(stk)
return " ".join(self.english_normalize_(res)) return " ".join(self.english_normalize_(res))
def is_chinese(s): def is_chinese(s):
if s >= u'\u4e00' and s <= u'\u9fa5': if s >= "\u4e00" and s <= "\u9fa5":
return True return True
else: else:
return False return False
def is_number(s): def is_number(s):
if s >= u'\u0030' and s <= u'\u0039': if s >= "\u0030" and s <= "\u0039":
return True return True
else: else:
return False return False
def is_alphabet(s): def is_alphabet(s):
if (s >= u'\u0041' and s <= u'\u005a') or ( if (s >= "\u0041" and s <= "\u005a") or (s >= "\u0061" and s <= "\u007a"):
s >= u'\u0061' and s <= u'\u007a'):
return True return True
else: else:
return False return False
@ -424,8 +455,7 @@ def is_alphabet(s):
def naiveQie(txt): def naiveQie(txt):
tks = [] tks = []
for t in txt.split(): for t in txt.split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1] if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t):
) and re.match(r".*[a-zA-Z]$", t):
tks.append(" ") tks.append(" ")
tks.append(t) tks.append(t)
return tks return tks
@ -441,43 +471,41 @@ addUserDict = tokenizer.addUserDict
tradi2simp = tokenizer._tradi2simp tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B strQ2B = tokenizer._strQ2B
if __name__ == '__main__': if __name__ == "__main__":
tknzr = RagTokenizer(debug=True) tknzr = RagTokenizer(debug=True)
# huqie.addUserDict("/tmp/tmp.new.tks.dict") # huqie.addUserDict("/tmp/tmp.new.tks.dict")
tks = tknzr.tokenize("哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
# logging.info(tknzr.fine_grained_tokenize(tks))
print(tks)
tks = tknzr.tokenize( tks = tknzr.tokenize(
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈") "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。"
logging.info(tknzr.fine_grained_tokenize(tks)) )
tks = tknzr.tokenize( print(tks)
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。") # logging.info(tknzr.fine_grained_tokenize(tks))
logging.info(tknzr.fine_grained_tokenize(tks)) # tks = tknzr.tokenize("多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
tks = tknzr.tokenize( # logging.info(tknzr.fine_grained_tokenize(tks))
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥") # tks = tknzr.tokenize("实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
logging.info(tknzr.fine_grained_tokenize(tks)) # logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize( # tks = tknzr.tokenize("虽然我不怎么玩")
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa") # logging.info(tknzr.fine_grained_tokenize(tks))
logging.info(tknzr.fine_grained_tokenize(tks)) # tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
tks = tknzr.tokenize("虽然我不怎么玩") # logging.info(tknzr.fine_grained_tokenize(tks))
logging.info(tknzr.fine_grained_tokenize(tks)) # tks = tknzr.tokenize("涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的") # logging.info(tknzr.fine_grained_tokenize(tks))
logging.info(tknzr.fine_grained_tokenize(tks)) # tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
tks = tknzr.tokenize( # logging.info(tknzr.fine_grained_tokenize(tks))
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了") # tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
logging.info(tknzr.fine_grained_tokenize(tks)) # logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?") # tks = tknzr.tokenize("数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
logging.info(tknzr.fine_grained_tokenize(tks)) # logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ") # if len(sys.argv) < 2:
logging.info(tknzr.fine_grained_tokenize(tks)) # sys.exit()
tks = tknzr.tokenize( # tknzr.DEBUG = False
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-") # tknzr.loadUserDict(sys.argv[1])
logging.info(tknzr.fine_grained_tokenize(tks)) # of = open(sys.argv[2], "r")
if len(sys.argv) < 2: # while True:
sys.exit() # line = of.readline()
tknzr.DEBUG = False # if not line:
tknzr.loadUserDict(sys.argv[1]) # break
of = open(sys.argv[2], "r") # logging.info(tknzr.tokenize(line))
while True: # of.close()
line = of.readline()
if not line:
break
logging.info(tknzr.tokenize(line))
of.close()

View File

@ -15,6 +15,7 @@
# #
import logging import logging
import re import re
import math
from dataclasses import dataclass from dataclasses import dataclass
from rag.settings import TAG_FLD, PAGERANK_FLD from rag.settings import TAG_FLD, PAGERANK_FLD
@ -24,7 +25,8 @@ import numpy as np
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
def index_name(uid): return f"ragflow_{uid}" def index_name(uid):
return f"ragflow_{uid}"
class Dealer: class Dealer:
@ -47,11 +49,10 @@ class Dealer:
qv, _ = emb_mdl.encode_queries(txt) qv, _ = emb_mdl.encode_queries(txt)
shape = np.array(qv).shape shape = np.array(qv).shape
if len(shape) > 1: if len(shape) > 1:
raise Exception( raise Exception(f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
embedding_data = [float(v) for v in qv] embedding_data = [float(v) for v in qv]
vector_column_name = f"q_{len(embedding_data)}_vec" vector_column_name = f"q_{len(embedding_data)}_vec"
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity}) return MatchDenseExpr(vector_column_name, embedding_data, "float", "cosine", topk, {"similarity": similarity})
def get_filters(self, req): def get_filters(self, req):
condition = dict() condition = dict()
@ -64,12 +65,7 @@ class Dealer:
condition[key] = req[key] condition[key] = req[key]
return condition return condition
def search(self, req, idx_names: str | list[str], def search(self, req, idx_names: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False, rank_feature: dict | None = None):
kb_ids: list[str],
emb_mdl=None,
highlight=False,
rank_feature: dict | None = None
):
""" """
执行混合检索(全文检索 + 向量检索) 执行混合检索(全文检索 + 向量检索)
@ -108,18 +104,37 @@ class Dealer:
offset, limit = pg * ps, ps offset, limit = pg * ps, ps
# 3. 设置返回字段(默认包含文档名、内容等核心字段) # 3. 设置返回字段(默认包含文档名、内容等核心字段)
src = req.get("fields", src = req.get(
["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int", "fields",
"doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd", [
"question_kwd", "question_tks", "docnm_kwd",
"available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD]) "content_ltks",
kwds = set([]) "kb_id",
"img_id",
"title_tks",
"important_kwd",
"position_int",
"doc_id",
"page_num_int",
"top_int",
"create_timestamp_flt",
"knowledge_graph_kwd",
"question_kwd",
"question_tks",
"available_int",
"content_with_weight",
PAGERANK_FLD,
TAG_FLD,
],
)
kwds = set([]) # 初始化关键词集合
# 4. 处理查询问题 # 4. 处理查询问题
qst = req.get("question", "") qst = req.get("question", "") # 获取查询问题文本
q_vec = [] print(f"收到前端问题:{qst}")
q_vec = [] # 初始化查询向量(如需向量检索)
if not qst: if not qst:
# 4.1 无查询文本时的处理(按文档排序) # 4.1 若查询文本为空,执行默认排序检索(通常用于无搜索条件浏览)(注:前端测试检索时会禁止空文本的提交)
if req.get("sort"): if req.get("sort"):
orderBy.asc("page_num_int") orderBy.asc("page_num_int")
orderBy.asc("top_int") orderBy.asc("top_int")
@ -128,44 +143,58 @@ class Dealer:
total = self.dataStore.getTotal(res) total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total)) logging.debug("Dealer.search TOTAL: {}".format(total))
else: else:
# 4.2 有查询文本时的处理 # 4.2 若存在查询文本,进入全文/混合检索流程
highlightFields = ["content_ltks", "title_tks"] if highlight else [] highlightFields = ["content_ltks", "title_tks"] if highlight else [] # highlight当前会一直为False不起作用
# 4.2.1 生成全文检索表达式和关键词 # 4.2.1 生成全文检索表达式和关键词
matchText, keywords = self.qryr.question(qst, min_match=0.3) matchText, keywords = self.qryr.question(qst, min_match=0.3)
print(f"matchText.matching_text: {matchText.matching_text}")
print(f"keywords: {keywords}\n")
if emb_mdl is None: if emb_mdl is None:
# 4.2.2 纯全文检索模式 # 4.2.2 纯全文检索模式 (未提供向量模型,正常情况不会进入)
matchExprs = [matchText] matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res) total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total)) logging.debug("Dealer.search TOTAL: {}".format(total))
else: else:
# 4.2.3 混合检索模式(全文+向量) # 4.2.3 混合检索模式(全文+向量)
# 生成查询向量 # 生成查询向量
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1)) matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
q_vec = matchDense.embedding_data q_vec = matchDense.embedding_data
# 在返回字段中加入查询向量字段
src.append(f"q_{len(q_vec)}_vec") src.append(f"q_{len(q_vec)}_vec")
# 设置混合检索权重全文5% + 向量95% # 创建融合表达式:设置向量匹配为95%全文为5%(可以调整权重
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"}) fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
# 构建混合查询表达式
matchExprs = [matchText, matchDense, fusionExpr] matchExprs = [matchText, matchDense, fusionExpr]
# 执行混合检索 # 执行混合检索
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res) total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total)) logging.debug("Dealer.search TOTAL: {}".format(total))
# If result is empty, try again with lower min_match print(f"共查询到: {total} 条信息")
if total == 0: # print(f"查询信息结果: {res}\n")
matchText, _ = self.qryr.question(qst, min_match=0.1)
filters.pop("doc_ids", None)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
# 若未找到结果,则尝试降低匹配门槛后重试
if total == 0:
if filters.get("doc_id"):
# 有特定文档ID时执行无条件查询
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.getTotal(res)
print(f"针对选中文档,共查询到: {total} 条信息")
# print(f"查询信息结果: {res}\n")
else:
# 否则调整全文和向量匹配参数再次搜索
matchText, _ = self.qryr.question(qst, min_match=0.1)
filters.pop("doc_id", None)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
print(f"再次查询,共查询到: {total} 条信息")
# print(f"查询信息结果: {res}\n")
# 4.3 处理关键词(对关键词进行更细粒度的切词)
for k in keywords: for k in keywords:
kwds.add(k) kwds.add(k)
for kk in rag_tokenizer.fine_grained_tokenize(k).split(): for kk in rag_tokenizer.fine_grained_tokenize(k).split():
@ -175,27 +204,23 @@ class Dealer:
continue continue
kwds.add(kk) kwds.add(kk)
# 5. 提取检索结果中的ID、字段、聚合和高亮信息
logging.debug(f"TOTAL: {total}") logging.debug(f"TOTAL: {total}")
ids = self.dataStore.getChunkIds(res) ids = self.dataStore.getChunkIds(res) # 提取匹配chunk的ID
keywords = list(kwds) keywords = list(kwds) # 转为列表格式返回
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight") highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight") # 获取高亮内容
aggs = self.dataStore.getAggregation(res, "docnm_kwd") aggs = self.dataStore.getAggregation(res, "docnm_kwd") # 执行基于文档名的聚合分析
return self.SearchResult( print(f"ids:{ids}")
total=total, print(f"keywords:{keywords}")
ids=ids, print(f"highlight:{highlight}")
query_vector=q_vec, print(f"aggs:{aggs}")
aggregation=aggs, return self.SearchResult(total=total, ids=ids, query_vector=q_vec, aggregation=aggs, highlight=highlight, field=self.dataStore.getFields(res, src), keywords=keywords)
highlight=highlight,
field=self.dataStore.getFields(res, src),
keywords=keywords
)
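The `FusionExpr("weighted_sum", ..., {"weights": "0.05, 0.95"})` above asks the datastore to blend full-text and vector scores at 5%/95%. The actual blending happens inside the storage backend, which is not shown here; the sketch below is one plausible reading of a weighted-sum fusion over per-chunk scores, with field names and max-normalization as assumptions.

```python
def weighted_sum_fusion(text_scores: dict[str, float],
                        vector_scores: dict[str, float],
                        w_text: float = 0.05,
                        w_vec: float = 0.95) -> dict[str, float]:
    """Blend two score maps per chunk id; missing scores count as 0."""
    def normalize(scores: dict[str, float]) -> dict[str, float]:
        mx = max(scores.values(), default=0.0)
        return {k: (v / mx if mx > 0 else 0.0) for k, v in scores.items()}

    t, v = normalize(text_scores), normalize(vector_scores)
    ids = set(t) | set(v)
    fused = {i: w_text * t.get(i, 0.0) + w_vec * v.get(i, 0.0) for i in ids}
    return dict(sorted(fused.items(), key=lambda kv: kv[1], reverse=True))


if __name__ == "__main__":
    text = {"chunk_a": 12.0, "chunk_b": 3.0}
    vec = {"chunk_b": 0.82, "chunk_c": 0.78}
    print(weighted_sum_fusion(text, vec))
```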
@staticmethod @staticmethod
def trans2floats(txt): def trans2floats(txt):
return [float(t) for t in txt.split("\t")] return [float(t) for t in txt.split("\t")]
def insert_citations(self, answer, chunks, chunk_v, def insert_citations(self, answer, chunks, chunk_v, embd_mdl, tkweight=0.1, vtweight=0.9):
embd_mdl, tkweight=0.1, vtweight=0.9):
assert len(chunks) == len(chunk_v) assert len(chunks) == len(chunk_v)
if not chunks: if not chunks:
return answer, set([]) return answer, set([])
@ -211,12 +236,9 @@ class Dealer:
i += 1 i += 1
if i < len(pieces): if i < len(pieces):
i += 1 i += 1
pieces_.append("".join(pieces[st: i]) + "\n") pieces_.append("".join(pieces[st:i]) + "\n")
else: else:
pieces_.extend( pieces_.extend(re.split(r"([^\|][;。?!\n]|[a-z][.?;!][ \n])", pieces[i]))
re.split(
r"([^\|][;。?!\n]|[a-z][.?;!][ \n])",
pieces[i]))
i += 1 i += 1
pieces = pieces_ pieces = pieces_
else: else:
@ -239,30 +261,22 @@ class Dealer:
ans_v, _ = embd_mdl.encode(pieces_) ans_v, _ = embd_mdl.encode(pieces_)
for i in range(len(chunk_v)): for i in range(len(chunk_v)):
if len(ans_v[0]) != len(chunk_v[i]): if len(ans_v[0]) != len(chunk_v[i]):
chunk_v[i] = [0.0]*len(ans_v[0]) chunk_v[i] = [0.0] * len(ans_v[0])
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i]))) logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[0]))
len(ans_v[0]), len(chunk_v[0]))
chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split() chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split() for ck in chunks]
for ck in chunks]
cites = {} cites = {}
thr = 0.63 thr = 0.63
while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks: while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
for i, a in enumerate(pieces_): for i, a in enumerate(pieces_):
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], chunk_v, rag_tokenizer.tokenize(self.qryr.rmWWW(pieces_[i])).split(), chunks_tks, tkweight, vtweight)
chunk_v,
rag_tokenizer.tokenize(
self.qryr.rmWWW(pieces_[i])).split(),
chunks_tks,
tkweight, vtweight)
mx = np.max(sim) * 0.99 mx = np.max(sim) * 0.99
logging.debug("{} SIM: {}".format(pieces_[i], mx)) logging.debug("{} SIM: {}".format(pieces_[i], mx))
if mx < thr: if mx < thr:
continue continue
cites[idx[i]] = list( cites[idx[i]] = list(set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
thr *= 0.8 thr *= 0.8
res = "" res = ""
@ -294,7 +308,7 @@ class Dealer:
if not query_rfea: if not query_rfea:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD])) q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids: for i in search_res.ids:
nor, denor = 0, 0 nor, denor = 0, 0
for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items(): for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
@ -304,13 +318,10 @@ class Dealer:
if denor == 0: if denor == 0:
rank_fea.append(0) rank_fea.append(0)
else: else:
rank_fea.append(nor/np.sqrt(denor)/q_denor) rank_fea.append(nor / np.sqrt(denor) / q_denor)
return np.array(rank_fea)*10. + pageranks return np.array(rank_fea) * 10.0 + pageranks
def rerank(self, sres, query, tkweight=0.3, def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", rank_feature: dict | None = None):
vtweight=0.7, cfield="content_ltks",
rank_feature: dict | None = None
):
_, keywords = self.qryr.question(query) _, keywords = self.qryr.question(query)
vector_size = len(sres.query_vector) vector_size = len(sres.query_vector)
vector_column = f"q_{vector_size}_vec" vector_column = f"q_{vector_size}_vec"
@ -339,16 +350,11 @@ class Dealer:
## For rank feature(tag_fea) scores. ## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres) rank_fea = self._rank_feature_scores(rank_feature, sres)
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector, sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector, ins_embd, keywords, ins_tw, tkweight, vtweight)
ins_embd,
keywords,
ins_tw, tkweight, vtweight)
return sim + rank_fea, tksim, vtsim return sim + rank_fea, tksim, vtsim
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3, def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", rank_feature: dict | None = None):
vtweight=0.7, cfield="content_ltks",
rank_feature: dict | None = None):
_, keywords = self.qryr.question(query) _, keywords = self.qryr.question(query)
for i in sres.ids: for i in sres.ids:
@ -367,18 +373,28 @@ class Dealer:
## For rank feature(tag_fea) scores. ## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres) rank_fea = self._rank_feature_scores(rank_feature, sres)
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim return tkweight * (np.array(tksim) + rank_fea) + vtweight * vtsim, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst): def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
return self.qryr.hybrid_similarity(ans_embd, return self.qryr.hybrid_similarity(ans_embd, ins_embd, rag_tokenizer.tokenize(ans).split(), rag_tokenizer.tokenize(inst).split())
ins_embd,
rag_tokenizer.tokenize(ans).split(),
rag_tokenizer.tokenize(inst).split())
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2, def retrieval(
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, self,
rerank_mdl=None, highlight=False, question,
rank_feature: dict | None = {PAGERANK_FLD: 10}): embd_mdl,
tenant_ids,
kb_ids,
page,
page_size,
similarity_threshold=0.2,
vector_similarity_weight=0.3,
top=1024,
doc_ids=None,
aggs=True,
rerank_mdl=None,
highlight=False,
rank_feature: dict | None = {PAGERANK_FLD: 10},
):
""" """
执行检索操作,根据问题查询相关文档片段 执行检索操作,根据问题查询相关文档片段
@ -406,59 +422,49 @@ class Dealer:
if not question: if not question:
return ranks return ranks
# 设置重排序页面限制 # 设置重排序页面限制
RERANK_PAGE_LIMIT = 3 RERANK_LIMIT = 64
RERANK_LIMIT = int(RERANK_LIMIT // page_size + ((RERANK_LIMIT % page_size) / (page_size * 1.0) + 0.5)) * page_size if page_size > 1 else 1
if RERANK_LIMIT < 1:
RERANK_LIMIT = 1
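This replaces the old `RERANK_PAGE_LIMIT = 3` branching with a fixed rerank window of 64 that is rounded to the nearest multiple of `page_size` (never below 1). A worked check of that arithmetic, assuming the formula exactly as written above:

```python
def rerank_limit(page_size: int, base: int = 64) -> int:
    if page_size <= 1:
        return 1
    limit = int(base // page_size + ((base % page_size) / (page_size * 1.0) + 0.5)) * page_size
    return max(limit, 1)


if __name__ == "__main__":
    print(rerank_limit(10))  # 60: 64 rounded down to a multiple of 10
    print(rerank_limit(40))  # 80: 64 rounded up to a multiple of 40
    print(rerank_limit(64))  # 64: already a multiple
```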
# 构建检索请求参数 # 构建检索请求参数
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size * RERANK_PAGE_LIMIT, 128), req = {
"question": question, "vector": True, "topk": top, "kb_ids": kb_ids,
"similarity": similarity_threshold, "doc_ids": doc_ids,
"available_int": 1} "page": math.ceil(page_size * page / RERANK_LIMIT),
"size": RERANK_LIMIT,
# 如果页码超过重排序限制,直接请求指定页的数据 "question": question,
if page > RERANK_PAGE_LIMIT: "vector": True,
req["page"] = page "topk": top,
req["size"] = page_size "similarity": similarity_threshold,
"available_int": 1,
}
# 处理租户ID格式 # 处理租户ID格式
if isinstance(tenant_ids, str): if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",") tenant_ids = tenant_ids.split(",")
# 执行搜索操作 # 执行搜索操作
sres = self.search(req, [index_name(tid) for tid in tenant_ids], sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
ranks["total"] = sres.total
# 根据页码决定是否需要重排序 if rerank_mdl and sres.total > 0:
if page <= RERANK_PAGE_LIMIT: sim, tsim, vsim = self.rerank_by_model(rerank_mdl, sres, question, 1 - vector_similarity_weight, vector_similarity_weight, rank_feature=rank_feature)
# 前几页需要重排序以提高结果质量
if rerank_mdl and sres.total > 0:
# 使用重排序模型进行重排序
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
sres, question, 1 - vector_similarity_weight,
vector_similarity_weight,
rank_feature=rank_feature)
else:
# 使用默认方法进行重排序
sim, tsim, vsim = self.rerank(
sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
rank_feature=rank_feature)
# 根据相似度降序排序,并选择当前页的结果
idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size]
else: else:
# 后续页面不需要重排序,直接使用搜索结果 sim, tsim, vsim = self.rerank(sres, question, 1 - vector_similarity_weight, vector_similarity_weight, rank_feature=rank_feature)
sim = tsim = vsim = [1] * len(sres.ids) # Already paginated in search function
idx = list(range(len(sres.ids))) idx = np.argsort(sim * -1)[(page - 1) * page_size : page * page_size]
# 获取向量维度和列名
dim = len(sres.query_vector) dim = len(sres.query_vector)
vector_column = f"q_{dim}_vec" vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim zero_vector = [0.0] * dim
if doc_ids:
# 处理每个检索结果 similarity_threshold = 0
page_size = 30
sim_np = np.array(sim)
filtered_count = (sim_np >= similarity_threshold).sum()
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
for i in idx: for i in idx:
# 过滤低于阈值的结果
if sim[i] < similarity_threshold: if sim[i] < similarity_threshold:
break break
# 控制返回结果数量
if len(ranks["chunks"]) >= page_size: if len(ranks["chunks"]) >= page_size:
if aggs: if aggs:
continue continue
@ -468,7 +474,6 @@ class Dealer:
dnm = chunk.get("docnm_kwd", "") dnm = chunk.get("docnm_kwd", "")
did = chunk.get("doc_id", "") did = chunk.get("doc_id", "")
position_int = chunk.get("position_int", []) position_int = chunk.get("position_int", [])
# 构建结果字典
d = { d = {
"chunk_id": id, "chunk_id": id,
"content_ltks": chunk["content_ltks"], "content_ltks": chunk["content_ltks"],
@@ -483,9 +488,8 @@ class Dealer:
                "term_similarity": tsim[i],
                "vector": chunk.get(vector_column, zero_vector),
                "positions": position_int,
+                "doc_type_kwd": chunk.get("doc_type_kwd", ""),
            }
-            # Handle highlighted content
            if highlight and sres.highlight:
                if id in sres.highlight:
                    d["highlight"] = rmSpace(sres.highlight[id])
@@ -495,12 +499,7 @@ class Dealer:
            if dnm not in ranks["doc_aggs"]:
                ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
            ranks["doc_aggs"][dnm]["count"] += 1
-        # Convert the document aggregation into a list sorted by count, descending
-        ranks["doc_aggs"] = [{"doc_name": k,
-                              "doc_id": v["doc_id"],
-                              "count": v["count"]} for k,
-                             v in sorted(ranks["doc_aggs"].items(),
-                                         key=lambda x: x[1]["count"] * -1)]
+        ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k, v in sorted(ranks["doc_aggs"].items(), key=lambda x: x[1]["count"] * -1)]
        ranks["chunks"] = ranks["chunks"][:page_size]

        return ranks
@@ -509,16 +508,12 @@ class Dealer:
        tbl = self.dataStore.sql(sql, fetch_size, format)
        return tbl

-    def chunk_list(self, doc_id: str, tenant_id: str,
-                   kb_ids: list[str], max_count=1024,
-                   offset=0,
-                   fields=["docnm_kwd", "content_with_weight", "img_id"]):
+    def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, offset=0, fields=["docnm_kwd", "content_with_weight", "img_id"]):
        condition = {"doc_id": doc_id}
        res = []
        bs = 128
        for p in range(offset, max_count, bs):
-            es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
-                                           kb_ids)
+            es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id), kb_ids)
            dict_chunks = self.dataStore.getFields(es_res, fields)
            for id, doc in dict_chunks.items():
                doc["id"] = id
@@ -548,8 +543,7 @@ class Dealer:
        if not aggs:
            return False
        cnt = np.sum([c for _, c in aggs])
-        tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
-                         key=lambda x: x[1] * -1)[:topn_tags]
+        tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], key=lambda x: x[1] * -1)[:topn_tags]
        doc[TAG_FLD] = {a: c for a, c in tag_fea if c > 0}
        return True

@@ -564,6 +558,5 @@ class Dealer:
        if not aggs:
            return {}
        cnt = np.sum([c for _, c in aggs])
-        tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
-                         key=lambda x: x[1] * -1)[:topn_tags]
-        return {a: max(1, c) for a, c in tag_fea}
+        tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], key=lambda x: x[1] * -1)[:topn_tags]
+        return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
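The tag weight kept in both hunks is `round(0.1 * (c + 1) / (cnt + S) / all_tags[a])`, a smoothed local tag frequency divided by the tag's global frequency. A hedged numeric sketch (all counts, global frequencies, `S` and `topn_tags` values are invented):

```python
aggs = [("finance", 6), ("sports", 1)]            # tag counts aggregated for one document/query
all_tags = {"finance": 0.0005, "sports": 0.0001}  # assumed global tag frequencies
S, topn_tags = 1000, 3

cnt = sum(c for _, c in aggs)  # 7
tag_fea = sorted(
    [(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
    key=lambda x: x[1] * -1,
)[:topn_tags]
print(tag_fea)  # [('sports', 2), ('finance', 1)]: the globally rarer tag outweighs the locally frequent one
```

The second hunk additionally rewrites dots in tag names (`a.replace(".", "_")`), presumably so that tag keys containing dots remain safe as field names in the index.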

View File

@@ -25,20 +25,18 @@ from api.utils.file_utils import get_project_base_directory
class Dealer:
    def __init__(self, redis=None):
        self.lookup_num = 100000000
        self.load_tm = time.time() - 1000000
        self.dictionary = None
        path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
        try:
-            self.dictionary = json.load(open(path, 'r'))
+            self.dictionary = json.load(open(path, "r"))
        except Exception:
            logging.warning("Missing synonym.json")
            self.dictionary = {}

        if not redis:
-            logging.warning(
-                "Realtime synonym is disabled, since no redis connection.")
+            logging.warning("Realtime synonym is disabled, since no redis connection.")
        if not len(self.dictionary.keys()):
            logging.warning("Fail to load synonym")
@@ -67,18 +65,36 @@ class Dealer:
            logging.error("Fail to load synonym!" + str(e))

    def lookup(self, tk, topn=8):
+        """
+        Look up synonyms for the given token (tk); handles mixed English and Chinese input.
+
+        Args:
+            tk (str): token to look up, e.g. "happy" or "苹果"
+            topn (int): maximum number of synonyms to return (default 8)
+
+        Returns:
+            list: list of synonyms (empty when none are found)
+
+        Logic:
+            1. English words: query the WordNet semantic network
+            2. Chinese/other tokens: query the preloaded custom dictionary
+        """
+        # English-word branch
        if re.match(r"[a-z]+$", tk):
            res = list(set([re.sub("_", " ", syn.name().split(".")[0]) for syn in wordnet.synsets(tk)]) - set([tk]))
            return [t for t in res if t]

+        # Chinese / other tokens
        self.lookup_num += 1
-        self.load()
+        self.load()  # custom dictionary
+        # Fetch synonyms from the dictionary, defaulting to an empty list
        res = self.dictionary.get(re.sub(r"[ \t]+", " ", tk.lower()), [])
+        # Compatibility: if the dictionary value is a string, wrap it in a single-element list
        if isinstance(res, str):
            res = [res]
        return res[:topn]
-if __name__ == '__main__':
+if __name__ == "__main__":
    dl = Dealer()
    print(dl.dictionary)
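A quick usage sketch of the `lookup()` flow documented above. The import path is assumed from this repository's layout, the English branch needs the local NLTK WordNet data, and the Chinese result depends on what `rag/res/synonym.json` actually contains:

```python
from rag.nlp.synonym import Dealer  # assumed module path for the class above

dl = Dealer()                  # no redis handle: realtime synonyms disabled, static dict only
print(dl.lookup("happy"))      # English branch via WordNet, e.g. ['felicitous', 'glad']
print(dl.lookup("苹果"))        # Chinese branch: whatever synonym.json maps "苹果" to, else []
```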

View File

@@ -1,4 +1,4 @@
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,37 +26,41 @@ from api.utils.file_utils import get_project_base_directory
class Dealer:
    def __init__(self):
-        self.stop_words = set(["请问", "您", "你", "我", "他", "是", "的", "就", "有", "于", "及", "即", "在", "为", "最",
-                               "有", "从", "以", "了", "将", "与", "吗", "吧", "中", "#", "什么", "怎么", "哪个", "哪些", "啥", "相关"])
+        self.stop_words = set(
+            [
+                "请问", "您", "你", "我", "他", "是", "的", "就", "有", "于", "及", "即",
+                "在", "为", "最", "有", "从", "以", "了", "将", "与", "吗", "吧", "中",
+                "#", "什么", "怎么", "哪个", "哪些", "啥", "相关",
+            ]
+        )

        def load_dict(fnm):
            res = {}
@@ -90,50 +94,45 @@ class Dealer:
            logging.warning("Load term.freq FAIL!")

    def pretoken(self, txt, num=False, stpwd=True):
-        patt = [
-            r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
-        ]
-        rewt = [
-        ]
+        patt = [r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"]
+        rewt = []
        for p, r in rewt:
            txt = re.sub(p, r, txt)

        res = []
        for t in rag_tokenizer.tokenize(txt).split():
            tk = t
-            if (stpwd and tk in self.stop_words) or (
-                    re.match(r"[0-9]$", tk) and not num):
+            if (stpwd and tk in self.stop_words) or (re.match(r"[0-9]$", tk) and not num):
                continue
            for p in patt:
                if re.match(p, t):
                    tk = "#"
                    break
-            #tk = re.sub(r"([\+\\-])", r"\\\1", tk)
+            # tk = re.sub(r"([\+\\-])", r"\\\1", tk)
            if tk != "#" and tk:
                res.append(tk)
        return res

    def tokenMerge(self, tks):
-        def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
+        def oneTerm(t):
+            return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)

        res, i = [], 0
        while i < len(tks):
            j = i
-            if i == 0 and oneTerm(tks[i]) and len(
-                    tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])):  # 多 工位
+            if i == 0 and oneTerm(tks[i]) and len(tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])):  # 多 工位
                res.append(" ".join(tks[0:2]))
                i = 2
                continue
-            while j < len(
-                    tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
+            while j < len(tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
                j += 1
            if j - i > 1:
                if j - i < 5:
                    res.append(" ".join(tks[i:j]))
                    i = j
                else:
-                    res.append(" ".join(tks[i:i + 2]))
+                    res.append(" ".join(tks[i : i + 2]))
                    i = i + 2
            else:
                if len(tks[i]) > 0:
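To make the merging rule above concrete, a hedged usage sketch: it assumes the `rag.nlp.term_weight` module path, that the `rag/res` dictionaries are available locally, and the printed tokens are only indicative since they depend on `rag_tokenizer`'s segmentation:

```python
from rag.nlp import term_weight  # assumed import path for the Dealer above

dealer = term_weight.Dealer()
tks = dealer.pretoken("请问多工位机床的刀库容量是多少", num=False, stpwd=True)
print(tks)                     # stop words such as "请问" / "的" / "是" are dropped
print(dealer.tokenMerge(tks))  # short runs of single-character tokens are re-joined, e.g. a "多 工 位" style group
```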
@@ -159,9 +158,7 @@ class Dealer:
        """
        tks = []
        for t in re.sub(r"[ \t]+", " ", txt).split():
-            if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
-                    re.match(r".*[a-zA-Z]$", t) and tks and \
-                    self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
+            if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t) and tks and self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
                tks[-1] = tks[-1] + " " + t
            else:
                tks.append(t)
@@ -180,8 +177,7 @@ class Dealer:
                return 0.01
            if not self.ne or t not in self.ne:
                return 1
-            m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
-                 "firstnm": 1}
+            m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, "firstnm": 1}
            return m[self.ne[t]]

        def postag(t):
@@ -208,7 +204,7 @@ class Dealer:
            if not s and len(t) >= 4:
                s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
            if len(s) > 1:
-                s = np.min([freq(tt) for tt in s]) / 6.
+                s = np.min([freq(tt) for tt in s]) / 6.0
            else:
                s = 0
@@ -224,18 +220,18 @@ class Dealer:
            elif len(t) >= 4:
                s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
                if len(s) > 1:
-                    return max(3, np.min([df(tt) for tt in s]) / 6.)
+                    return max(3, np.min([df(tt) for tt in s]) / 6.0)
            return 3

-        def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
+        def idf(s, N):
+            return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

        tw = []
        if not preprocess:
            idf1 = np.array([idf(freq(t), 10000000) for t in tks])
            idf2 = np.array([idf(df(t), 1000000000) for t in tks])
-            wts = (0.3 * idf1 + 0.7 * idf2) * \
-                np.array([ner(t) * postag(t) for t in tks])
+            wts = (0.3 * idf1 + 0.7 * idf2) * np.array([ner(t) * postag(t) for t in tks])
            wts = [s for s in wts]
            tw = list(zip(tks, wts))
        else:
@@ -243,8 +239,7 @@ class Dealer:
                tt = self.tokenMerge(self.pretoken(tk, True))
                idf1 = np.array([idf(freq(t), 10000000) for t in tt])
                idf2 = np.array([idf(df(t), 1000000000) for t in tt])
-                wts = (0.3 * idf1 + 0.7 * idf2) * \
-                    np.array([ner(t) * postag(t) for t in tt])
+                wts = (0.3 * idf1 + 0.7 * idf2) * np.array([ner(t) * postag(t) for t in tt])
                wts = [s for s in wts]
                tw.extend(zip(tt, wts))
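The weights kept above blend two smoothed IDF estimates and scale them by NER and part-of-speech factors: `wts = (0.3 * idf1 + 0.7 * idf2) * ner(t) * postag(t)`. A standalone numeric sketch of just that formula (the frequencies are invented and both scaling factors are assumed to be 1):

```python
import math

def idf(s, N):
    return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

# One very common token vs. one rare token (made-up corpus statistics).
for term, freq_t, df_t in [("的", 5_000_000, 500_000_000), ("刀库", 1_200, 8_000)]:
    idf1 = idf(freq_t, 10_000_000)        # term-frequency side, N = 1e7 as in the code above
    idf2 = idf(df_t, 1_000_000_000)       # document-frequency side, N = 1e9 as in the code above
    wt = (0.3 * idf1 + 0.7 * idf2) * 1.0  # ner(t) * postag(t) assumed to be 1
    print(term, round(wt, 2))             # roughly 1.0 for "的", about 4.7 for "刀库"
```

The rarer term ends up with a weight several times larger, which is what lets keyword matching favour informative tokens during recall.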

View File

@@ -42,6 +42,10 @@ def chunks_format(reference):
            "image_id": get_value(chunk, "image_id", "img_id"),
            "positions": get_value(chunk, "positions", "position_int"),
            "url": chunk.get("url"),
+            "similarity": chunk.get("similarity"),
+            "vector_similarity": chunk.get("vector_similarity"),
+            "term_similarity": chunk.get("term_similarity"),
+            "doc_type": chunk.get("doc_type_kwd"),
        }
        for chunk in reference.get("chunks", [])
    ]
@@ -145,15 +149,17 @@ def kb_prompt(kbinfos, max_tokens):
    doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
    for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
-        doc2chunks[ck["docnm_kwd"]]["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + f"ID: {i}\n" + ck["content_with_weight"])
+        cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
+        cnt += ck["content_with_weight"]
+        doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
        doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})

    knowledges = []
    for nm, cks_meta in doc2chunks.items():
-        txt = f"\nDocument: {nm} \n"
+        txt = f"\n文档: {nm} \n"
        for k, v in cks_meta["meta"].items():
            txt += f"{k}: {v}\n"
-        txt += "Relevant fragments as following:\n"
+        txt += "相关片段如下:\n"
        for i, chunk in enumerate(cks_meta["chunks"], 1):
            txt += f"{chunk}\n"
        knowledges.append(txt)
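After this change the knowledge block handed to the LLM is grouped per document, headed in Chinese, and each chunk is introduced by a `---` separator with its ID (and URL when present). A sketch of one assembled block, with an invented document name, metadata, and contents:

```python
txt = "\n文档: 设备手册.pdf \n"
txt += "author: 张三\n"          # document metadata, if any
txt += "相关片段如下:\n"
txt += "---\nID: 0\n刀库容量为 24 把。\n"
txt += "---\nID: 1\nURL: http://example.com/doc\n换刀时间约 1.5 秒。\n"
print(txt)
```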
@@ -388,3 +394,57 @@ Output:
    except Exception as e:
        logging.exception(f"JSON parsing error: {result} -> {e}")
        raise e
def vision_llm_describe_prompt(page=None) -> str:
prompt_en = """
INSTRUCTION:
Transcribe the content from the provided PDF page image into clean Markdown format.
- Only output the content transcribed from the image.
- Do NOT output this instruction or any other explanation.
- If the content is missing or you do not understand the input, return an empty string.
RULES:
1. Do NOT generate examples, demonstrations, or templates.
2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
5. Do NOT explain Markdown or mention that you are using Markdown.
6. Do NOT wrap the output in ```markdown or ``` blocks.
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
8. Preserve the original language, information, and order exactly as shown in the image.
"""
if page is not None:
prompt_en += f"\nAt the end of the transcription, add the page divider: `--- Page {page} ---`."
prompt_en += """
FAILURE HANDLING:
- If you do not detect valid content in the image, return an empty string.
"""
return prompt_en
def vision_llm_figure_describe_prompt() -> str:
prompt = """
You are an expert visual data analyst. Analyze the image and provide a comprehensive description of its content. Focus on identifying the type of visual data representation (e.g., bar chart, pie chart, line graph, table, flowchart), its structure, and any text captions or labels included in the image.
Tasks:
1. Describe the overall structure of the visual representation. Specify if it is a chart, graph, table, or diagram.
2. Identify and extract any axes, legends, titles, or labels present in the image. Provide the exact text where available.
3. Extract the data points from the visual elements (e.g., bar heights, line graph coordinates, pie chart segments, table rows and columns).
4. Analyze and explain any trends, comparisons, or patterns shown in the data.
5. Capture any annotations, captions, or footnotes, and explain their relevance to the image.
6. Only include details that are explicitly present in the image. If an element (e.g., axis, legend, or caption) does not exist or is not visible, do not mention it.
Output format (include only sections relevant to the image content):
- Visual Type: [Type]
- Title: [Title text, if available]
- Axes / Legends / Labels: [Details, if available]
- Data Points: [Extracted data]
- Trends / Insights: [Analysis and interpretation]
- Captions / Annotations: [Text and relevance, if available]
Ensure high accuracy, clarity, and completeness in your analysis, and include only the information present in the image. Avoid unnecessary statements about missing elements. Ensure high accuracy, clarity, and completeness in your analysis, and include only the information present in the image. Avoid unnecessary statements about missing elements.
"""
return prompt
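Both helpers added above only build prompt strings; how they are sent to a vision model is outside this hunk. A minimal hedged usage sketch (the `describe_with_prompt` call is a placeholder, not something this diff defines):

```python
from rag.prompts import vision_llm_describe_prompt, vision_llm_figure_describe_prompt

page_prompt = vision_llm_describe_prompt(page=3)     # appends the "--- Page 3 ---" divider rule
figure_prompt = vision_llm_figure_describe_prompt()  # structured description of charts/tables/figures

print(page_prompt.splitlines()[1])  # "INSTRUCTION:"
# e.g. vision_mdl.describe_with_prompt(image_binary, page_prompt)  # hypothetical VLM call
```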

File diff suppressed because it is too large.