# RAGflow/rag/nlp/query.py


#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import json
import re
from rag.utils.doc_store_conn import MatchTextExpr
from rag.nlp import rag_tokenizer, term_weight, synonym


class FulltextQueryer:
    def __init__(self):
        self.tw = term_weight.Dealer()
        self.syn = synonym.Dealer()
        self.query_fields = [
            "title_tks^10",
            "title_sm_tks^5",
            "important_kwd^30",
            "important_tks^20",
            "question_tks^20",
            "content_ltks^2",
            "content_sm_ltks",
        ]

    @staticmethod
    def subSpecialChar(line):
        return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()

    @staticmethod
    def isChinese(line):
        arr = re.split(r"[ \t]+", line)
        if len(arr) <= 3:
            return True
        e = 0
        for t in arr:
            if not re.match(r"[a-zA-Z]+$", t):
                e += 1
        return e * 1.0 / len(arr) >= 0.7

    @staticmethod
    def rmWWW(txt):
        patts = [
            (
                r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
                "",
            ),
            (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
            (
                r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
                " ")
        ]
        otxt = txt
        for r, p in patts:
            txt = re.sub(r, p, txt, flags=re.IGNORECASE)
        if not txt:
            txt = otxt
        return txt
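
    # Illustrative examples (added for documentation, not part of the upstream
    # module): rmWWW strips common interrogatives and function words, e.g.
    # "who is Einstein" is reduced to roughly "Einstein" and "什么是RAG" to "RAG";
    # if stripping would leave an empty string, the original text is returned
    # unchanged.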

    def question(self, txt, tbl="qa", min_match: float = 0.6):
        """
        Process a user question and build a full-text search expression.

        Args:
            txt: raw question text
            tbl: table name to query (default "qa")
            min_match: minimum match threshold (default 0.6)

        Returns:
            MatchTextExpr: full-text search expression object
            list: extracted keywords
        """
        # 1. Text preprocessing: strip special characters, convert traditional
        #    to simplified Chinese, full-width to half-width, lowercase.
        txt = re.sub(
            r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
            " ",
            rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
        ).strip()
        txt = FulltextQueryer.rmWWW(txt)  # remove interrogatives / stop words

        # 2. Non-Chinese text handling
        if not self.isChinese(txt):
            txt = FulltextQueryer.rmWWW(txt)
            tks = rag_tokenizer.tokenize(txt).split()
            keywords = [t for t in tks if t]
            tks_w = self.tw.weights(tks, preprocess=False)
            tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
            tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
            tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
            tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
            syns = []
            for tk, w in tks_w[:256]:
                syn = self.syn.lookup(tk)
                syn = rag_tokenizer.tokenize(" ".join(syn)).split()
                keywords.extend(syn)
                syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
                syns.append(" ".join(syn))
            q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
                 tk and not re.match(r"[.^+\(\)-]", tk)]
            for i in range(1, len(tks_w)):
                left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
                if not left or not right:
                    continue
                q.append(
                    '"%s %s"^%.4f'
                    % (
                        tks_w[i - 1][0],
                        tks_w[i][0],
                        max(tks_w[i - 1][1], tks_w[i][1]) * 2,
                    )
                )
            if not q:
                q.append(txt)
            query = " ".join(q)
            return MatchTextExpr(
                self.query_fields, query, 100
            ), keywords

        def need_fine_grained_tokenize(tk):
            """
            Decide whether a token needs fine-grained tokenization.

            Args:
                tk: token to check

            Returns:
                bool: True if fine-grained tokenization is needed
            """
            if len(tk) < 3:
                return False
            if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
                return False
            return True
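
        # Illustrative examples (added for documentation, not part of the
        # upstream module): "ai" is too short (len < 3) and "gpt-4" matches the
        # ASCII token pattern, so both return False; a longer non-ASCII term
        # such as "数据库系统" returns True.
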
        txt = FulltextQueryer.rmWWW(txt)  # remove interrogatives / stop words again
        qs, keywords = [], []  # query expressions and keyword list

        # 3. Chinese text handling: process at most 256 terms
        for tt in self.tw.split(txt)[:256]:  # .split():
            if not tt:
                continue
            # 3.1 Collect the base keyword
            keywords.append(tt)
            twts = self.tw.weights([tt])  # term weights
            syns = self.syn.lookup(tt)  # synonym lookup
            # 3.2 Synonym expansion (keyword list is capped at 32 entries)
            if syns and len(keywords) < 32:
                keywords.extend(syns)
            logging.debug(json.dumps(twts, ensure_ascii=False))
            tms = []
            # 3.3 Process each term together with its weight
            for tk, w in sorted(twts, key=lambda x: x[1] * -1):
                # 3.3.1 Fine-grained tokenization
                sm = (
                    rag_tokenizer.fine_grained_tokenize(tk).split()
                    if need_fine_grained_tokenize(tk)
                    else []
                )
                # 3.3.2 Clean the sub-tokens
                sm = [
                    re.sub(
                        r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
                        "",
                        m,
                    )
                    for m in sm
                ]
                sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
                sm = [m for m in sm if len(m) > 1]
                # 3.3.3 Collect keywords (at most 32)
                if len(keywords) < 32:
                    keywords.append(re.sub(r"[ \\\"']+", "", tk))
                    keywords.extend(sm)
                # 3.3.4 Synonym handling
                tk_syns = self.syn.lookup(tk)
                tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
                if len(keywords) < 32:
                    keywords.extend([s for s in tk_syns if s])
                tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
                tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]

                # keyword count limit
                if len(keywords) >= 32:
                    break

                # 3.3.5 Build the query expression for this term
                tk = FulltextQueryer.subSpecialChar(tk)
                if tk.find(" ") > 0:
                    tk = '"%s"' % tk  # phrase query
                if tk_syns:
                    tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)  # add synonym query
                if sm:
                    tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))  # add fine-grained token query
                if tk.strip():
                    tms.append((tk, w))  # keep the weighted query expression
            # 3.4 Merge query expressions for the current term
            tms = " ".join([f"({t})^{w}" for t, w in tms])
            # 3.5 Add an adjacent-term phrase query (boost phrase matches)
            if len(twts) > 1:
                tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
            # 3.6 Build the synonym query expression
            syns = " OR ".join(
                [
                    '"%s"'
                    % rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
                    for s in syns
                ]
            )
            if syns and tms:
                tms = f"({tms})^5 OR ({syns})^0.7"
            qs.append(tms)  # append to the final query list

        # 4. Generate the final query expression
        if qs:
            query = " OR ".join([f"({t})" for t in qs if t])
            return MatchTextExpr(
                self.query_fields, query, 100, {"minimum_should_match": min_match}
            ), keywords
        return None, keywords

    def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
        from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
        import numpy as np

        sims = CosineSimilarity([avec], bvecs)
        tksim = self.token_similarity(atks, btkss)
        if np.sum(sims[0]) == 0:
            return np.array(tksim), tksim, sims[0]
        return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
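
    # Note (added for documentation): hybrid_similarity blends vector and token
    # scores as vtweight * cosine + tkweight * token_overlap (0.7 / 0.3 by
    # default); if the cosine similarities sum to zero, it falls back to the
    # pure token-based similarity.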

    def token_similarity(self, atks, btkss):
        def toDict(tks):
            d = {}
            if isinstance(tks, str):
                tks = tks.split()
            for t, c in self.tw.weights(tks, preprocess=False):
                if t not in d:
                    d[t] = 0
                d[t] += c
            return d

        atks = toDict(atks)
        btkss = [toDict(tks) for tks in btkss]
        return [self.similarity(atks, btks) for btks in btkss]

    def similarity(self, qtwt, dtwt):
        if isinstance(dtwt, type("")):
            dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt), preprocess=False)}
        if isinstance(qtwt, type("")):
            qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt), preprocess=False)}
        s = 1e-9
        for k, v in qtwt.items():
            if k in dtwt:
                s += v  # * dtwt[k]
        q = 1e-9
        for k, v in qtwt.items():
            q += v
        return s / q
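
    # Illustrative example (added for documentation): with
    #   qtwt = {"vector": 0.6, "database": 0.4} and dtwt = {"vector": 1.0},
    # only "vector" overlaps, so s ≈ 0.6 and q ≈ 1.0, giving a similarity of
    # roughly 0.6 (the weight-covered fraction of the query terms).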

    def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
        if isinstance(content_tks, str):
            content_tks = [c.strip() for c in content_tks.strip() if c.strip()]
        tks_w = self.tw.weights(content_tks, preprocess=False)
        keywords = [f'"{k.strip()}"' for k in keywords]
        for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
            tk_syns = self.syn.lookup(tk)
            tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
            tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
            tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
            tk = FulltextQueryer.subSpecialChar(tk)
            if tk.find(" ") > 0:
                tk = '"%s"' % tk
            if tk_syns:
                tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
            if tk:
                keywords.append(f"{tk}^{w}")
        return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
                             {"minimum_should_match": min(3, len(keywords) / 10)})