512 lines
18 KiB
Python
512 lines
18 KiB
Python
#
|
||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
#
|
||
|
||
import logging
|
||
import copy
|
||
import datrie
|
||
import math
|
||
import os
|
||
import re
|
||
import string
|
||
import sys
|
||
from pathlib import Path
|
||
from hanziconv import HanziConv
|
||
from nltk import word_tokenize
|
||
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
||
|
||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||
from api.utils.file_utils import get_project_base_directory
|
||
|
||
|
||
class RagTokenizer:
|
||
def key_(self, line):
|
||
return str(line.lower().encode("utf-8"))[2:-1]
|
||
|
||
def rkey_(self, line):
|
||
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
|
||
|
||
def loadDict_(self, fnm):
|
||
logging.info(f"[HUQIE]:Build trie from {fnm}")
|
||
try:
|
||
of = open(fnm, "r", encoding="utf-8")
|
||
while True:
|
||
line = of.readline()
|
||
if not line:
|
||
break
|
||
line = re.sub(r"[\r\n]+", "", line)
|
||
line = re.split(r"[ \t]", line)
|
||
k = self.key_(line[0])
|
||
F = int(math.log(float(line[1]) / self.DENOMINATOR) + 0.5)
|
||
if k not in self.trie_ or self.trie_[k][0] < F:
|
||
self.trie_[self.key_(line[0])] = (F, line[2])
|
||
self.trie_[self.rkey_(line[0])] = 1
|
||
|
||
dict_file_cache = fnm + ".trie"
|
||
logging.info(f"[HUQIE]:Build trie cache to {dict_file_cache}")
|
||
self.trie_.save(dict_file_cache)
|
||
of.close()
|
||
except Exception:
|
||
logging.exception(f"[HUQIE]:Build trie {fnm} failed")
|
||
|
||
def __init__(self, debug=False):
|
||
self.DEBUG = debug
|
||
self.DENOMINATOR = 1000000
|
||
self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
|
||
|
||
self.stemmer = PorterStemmer()
|
||
self.lemmatizer = WordNetLemmatizer()
|
||
|
||
self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"
|
||
|
||
trie_file_name = self.DIR_ + ".txt.trie"
|
||
# check if trie file existence
|
||
if os.path.exists(trie_file_name):
|
||
try:
|
||
# load trie from file
|
||
self.trie_ = datrie.Trie.load(trie_file_name)
|
||
return
|
||
except Exception:
|
||
# fail to load trie from file, build default trie
|
||
logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
|
||
self.trie_ = datrie.Trie(string.printable)
|
||
else:
|
||
# file not exist, build default trie
|
||
logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
|
||
self.trie_ = datrie.Trie(string.printable)
|
||
|
||
# load data from dict file and save to trie file
|
||
self.loadDict_(self.DIR_ + ".txt")
|
||
|
||
def loadUserDict(self, fnm):
|
||
try:
|
||
self.trie_ = datrie.Trie.load(fnm + ".trie")
|
||
return
|
||
except Exception:
|
||
self.trie_ = datrie.Trie(string.printable)
|
||
self.loadDict_(fnm)
|
||
|
||
def addUserDict(self, fnm):
|
||
self.loadDict_(fnm)
|
||
|
||
def _strQ2B(self, ustring):
|
||
"""Convert full-width characters to half-width characters"""
|
||
rstring = ""
|
||
for uchar in ustring:
|
||
inside_code = ord(uchar)
|
||
if inside_code == 0x3000:
|
||
inside_code = 0x0020
|
||
else:
|
||
inside_code -= 0xFEE0
|
||
if inside_code < 0x0020 or inside_code > 0x7E: # After the conversion, if it's not a half-width character, return the original character.
|
||
rstring += uchar
|
||
else:
|
||
rstring += chr(inside_code)
|
||
return rstring
|
||
|
||
def _tradi2simp(self, line):
|
||
return HanziConv.toSimplified(line)
|
||
|
||
def dfs_(self, chars, s, preTks, tkslist):
|
||
res = s
|
||
# if s > MAX_L or s>= len(chars):
|
||
if s >= len(chars):
|
||
tkslist.append(preTks)
|
||
return res
|
||
|
||
# pruning
|
||
S = s + 1
|
||
if s + 2 <= len(chars):
|
||
t1, t2 = "".join(chars[s : s + 1]), "".join(chars[s : s + 2])
|
||
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
|
||
S = s + 2
|
||
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
|
||
t1 = preTks[-1][0] + "".join(chars[s : s + 1])
|
||
if self.trie_.has_keys_with_prefix(self.key_(t1)):
|
||
S = s + 2
|
||
|
||
################
|
||
for e in range(S, len(chars) + 1):
|
||
t = "".join(chars[s:e])
|
||
k = self.key_(t)
|
||
|
||
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
|
||
break
|
||
|
||
if k in self.trie_:
|
||
pretks = copy.deepcopy(preTks)
|
||
if k in self.trie_:
|
||
pretks.append((t, self.trie_[k]))
|
||
else:
|
||
pretks.append((t, (-12, "")))
|
||
res = max(res, self.dfs_(chars, e, pretks, tkslist))
|
||
|
||
if res > s:
|
||
return res
|
||
|
||
t = "".join(chars[s : s + 1])
|
||
k = self.key_(t)
|
||
if k in self.trie_:
|
||
preTks.append((t, self.trie_[k]))
|
||
else:
|
||
preTks.append((t, (-12, "")))
|
||
|
||
return self.dfs_(chars, s + 1, preTks, tkslist)
|
||
|
||
def freq(self, tk):
|
||
k = self.key_(tk)
|
||
if k not in self.trie_:
|
||
return 0
|
||
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
|
||
|
||
def tag(self, tk):
|
||
k = self.key_(tk)
|
||
if k not in self.trie_:
|
||
return ""
|
||
return self.trie_[k][1]
|
||
|
||
def score_(self, tfts):
|
||
B = 30
|
||
F, L, tks = 0, 0, []
|
||
for tk, (freq, tag) in tfts:
|
||
F += freq
|
||
L += 0 if len(tk) < 2 else 1
|
||
tks.append(tk)
|
||
# F /= len(tks)
|
||
L /= len(tks)
|
||
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
|
||
return tks, B / len(tks) + L + F
|
||
|
||
def sortTks_(self, tkslist):
|
||
res = []
|
||
for tfts in tkslist:
|
||
tks, s = self.score_(tfts)
|
||
res.append((tks, s))
|
||
return sorted(res, key=lambda x: x[1], reverse=True)
|
||
|
||
def merge_(self, tks):
|
||
# if split chars is part of token
|
||
res = []
|
||
tks = re.sub(r"[ ]+", " ", tks).split()
|
||
s = 0
|
||
while True:
|
||
if s >= len(tks):
|
||
break
|
||
E = s + 1
|
||
for e in range(s + 2, min(len(tks) + 2, s + 6)):
|
||
tk = "".join(tks[s:e])
|
||
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
|
||
E = e
|
||
res.append("".join(tks[s:E]))
|
||
s = E
|
||
|
||
return " ".join(res)
|
||
|
||
def maxForward_(self, line):
|
||
res = []
|
||
s = 0
|
||
while s < len(line):
|
||
e = s + 1
|
||
t = line[s:e]
|
||
while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
|
||
e += 1
|
||
t = line[s:e]
|
||
|
||
while e - 1 > s and self.key_(t) not in self.trie_:
|
||
e -= 1
|
||
t = line[s:e]
|
||
|
||
if self.key_(t) in self.trie_:
|
||
res.append((t, self.trie_[self.key_(t)]))
|
||
else:
|
||
res.append((t, (0, "")))
|
||
|
||
s = e
|
||
|
||
return self.score_(res)
|
||
|
||
def maxBackward_(self, line):
|
||
res = []
|
||
s = len(line) - 1
|
||
while s >= 0:
|
||
e = s + 1
|
||
t = line[s:e]
|
||
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
|
||
s -= 1
|
||
t = line[s:e]
|
||
|
||
while s + 1 < e and self.key_(t) not in self.trie_:
|
||
s += 1
|
||
t = line[s:e]
|
||
|
||
if self.key_(t) in self.trie_:
|
||
res.append((t, self.trie_[self.key_(t)]))
|
||
else:
|
||
res.append((t, (0, "")))
|
||
|
||
s -= 1
|
||
|
||
return self.score_(res[::-1])
|
||
|
||
def english_normalize_(self, tks):
|
||
return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
|
||
|
||
def _split_by_lang(self, line):
|
||
txt_lang_pairs = []
|
||
arr = re.split(self.SPLIT_CHAR, line)
|
||
for a in arr:
|
||
if not a:
|
||
continue
|
||
s = 0
|
||
e = s + 1
|
||
zh = is_chinese(a[s])
|
||
while e < len(a):
|
||
_zh = is_chinese(a[e])
|
||
if _zh == zh:
|
||
e += 1
|
||
continue
|
||
txt_lang_pairs.append((a[s:e], zh))
|
||
s = e
|
||
e = s + 1
|
||
zh = _zh
|
||
if s >= len(a):
|
||
continue
|
||
txt_lang_pairs.append((a[s:e], zh))
|
||
return txt_lang_pairs
|
||
|
||
def tokenize(self, line):
|
||
line = re.sub(r"\W+", " ", line)
|
||
line = self._strQ2B(line).lower()
|
||
line = self._tradi2simp(line)
|
||
|
||
arr = self._split_by_lang(line)
|
||
res = []
|
||
for L, lang in arr:
|
||
if not lang:
|
||
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
|
||
continue
|
||
if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
|
||
res.append(L)
|
||
continue
|
||
|
||
# use maxforward for the first time
|
||
tks, s = self.maxForward_(L)
|
||
tks1, s1 = self.maxBackward_(L)
|
||
if self.DEBUG:
|
||
logging.debug("[FW] {} {}".format(tks, s))
|
||
logging.debug("[BW] {} {}".format(tks1, s1))
|
||
|
||
i, j, _i, _j = 0, 0, 0, 0
|
||
same = 0
|
||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||
same += 1
|
||
if same > 0:
|
||
res.append(" ".join(tks[j : j + same]))
|
||
_i = i + same
|
||
_j = j + same
|
||
j = _j + 1
|
||
i = _i + 1
|
||
|
||
while i < len(tks1) and j < len(tks):
|
||
tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
|
||
if tk1 != tk:
|
||
if len(tk1) > len(tk):
|
||
j += 1
|
||
else:
|
||
i += 1
|
||
continue
|
||
|
||
if tks1[i] != tks[j]:
|
||
i += 1
|
||
j += 1
|
||
continue
|
||
# backward tokens from_i to i are different from forward tokens from _j to j.
|
||
tkslist = []
|
||
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
|
||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||
|
||
same = 1
|
||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||
same += 1
|
||
res.append(" ".join(tks[j : j + same]))
|
||
_i = i + same
|
||
_j = j + same
|
||
j = _j + 1
|
||
i = _i + 1
|
||
|
||
if _i < len(tks1):
|
||
assert _j < len(tks)
|
||
assert "".join(tks1[_i:]) == "".join(tks[_j:])
|
||
tkslist = []
|
||
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
|
||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||
|
||
res = " ".join(res)
|
||
logging.debug("[TKS] {}".format(self.merge_(res)))
|
||
return self.merge_(res)
|
||
|
||
def fine_grained_tokenize(self, tks):
|
||
"""
|
||
细粒度分词方法,根据文本特征(中英文比例、数字符号等)动态选择分词策略
|
||
|
||
参数:
|
||
tks (str): 待分词的文本字符串
|
||
|
||
返回:
|
||
str: 分词后的结果(用空格连接的词序列)
|
||
|
||
处理逻辑:
|
||
1. 先按空格初步切分文本
|
||
2. 根据中文占比决定是否启用细粒度分词
|
||
3. 对特殊格式(短词、纯数字等)直接保留原样
|
||
4. 对长词或复杂词使用DFS回溯算法寻找最优切分
|
||
5. 对英文词进行额外校验和规范化处理
|
||
"""
|
||
# 初始切分:按空格分割输入文本
|
||
tks = tks.split()
|
||
# 计算中文词占比(判断是否主要包含中文内容)
|
||
zh_num = len([1 for c in tks if c and is_chinese(c[0])])
|
||
# 如果中文占比低于20%,则按简单规则处理(主要处理英文混合文本)
|
||
if zh_num < len(tks) * 0.2:
|
||
res = []
|
||
for tk in tks:
|
||
res.extend(tk.split("/"))
|
||
return " ".join(res)
|
||
|
||
# 中文或复杂文本处理流程
|
||
res = []
|
||
for tk in tks:
|
||
# 规则1:跳过短词(长度<3)或纯数字/符号组合(如"3.14")
|
||
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
|
||
res.append(tk)
|
||
continue
|
||
|
||
# 初始化候选分词列表
|
||
tkslist = []
|
||
|
||
# 规则2:超长词(长度>10)直接保留不切分
|
||
if len(tk) > 10:
|
||
tkslist.append(tk)
|
||
else:
|
||
# 使用DFS回溯算法寻找所有可能的分词组合
|
||
self.dfs_(tk, 0, [], tkslist)
|
||
|
||
# 规则3:若无有效切分方案则保留原词
|
||
if len(tkslist) < 2:
|
||
res.append(tk)
|
||
continue
|
||
|
||
# 从候选方案中选择最优切分(通过sortTks_排序)
|
||
stk = self.sortTks_(tkslist)[1][0]
|
||
|
||
# 规则4:若切分结果与原词长度相同则视为无效切分
|
||
if len(stk) == len(tk):
|
||
stk = tk
|
||
else:
|
||
# 英文特殊处理:检查子词长度是否合法
|
||
if re.match(r"[a-z\.-]+$", tk):
|
||
for t in stk:
|
||
if len(t) < 3:
|
||
stk = tk
|
||
break
|
||
else:
|
||
stk = " ".join(stk)
|
||
else:
|
||
stk = " ".join(stk)
|
||
# 中文词直接拼接结果
|
||
res.append(stk)
|
||
|
||
return " ".join(self.english_normalize_(res))
|
||
|
||
|
||
def is_chinese(s):
|
||
if s >= "\u4e00" and s <= "\u9fa5":
|
||
return True
|
||
else:
|
||
return False
|
||
|
||
|
||
def is_number(s):
|
||
if s >= "\u0030" and s <= "\u0039":
|
||
return True
|
||
else:
|
||
return False
|
||
|
||
|
||
def is_alphabet(s):
|
||
if (s >= "\u0041" and s <= "\u005a") or (s >= "\u0061" and s <= "\u007a"):
|
||
return True
|
||
else:
|
||
return False
|
||
|
||
|
||
def naiveQie(txt):
|
||
tks = []
|
||
for t in txt.split():
|
||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t):
|
||
tks.append(" ")
|
||
tks.append(t)
|
||
return tks
|
||
|
||
|
||
tokenizer = RagTokenizer()
|
||
tokenize = tokenizer.tokenize
|
||
fine_grained_tokenize = tokenizer.fine_grained_tokenize
|
||
tag = tokenizer.tag
|
||
freq = tokenizer.freq
|
||
loadUserDict = tokenizer.loadUserDict
|
||
addUserDict = tokenizer.addUserDict
|
||
tradi2simp = tokenizer._tradi2simp
|
||
strQ2B = tokenizer._strQ2B
|
||
|
||
if __name__ == "__main__":
|
||
tknzr = RagTokenizer(debug=True)
|
||
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
|
||
tks = tknzr.tokenize("哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
|
||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||
print(tks)
|
||
tks = tknzr.tokenize(
|
||
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。"
|
||
)
|
||
print(tks)
|
||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||
# tks = tknzr.tokenize("多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
|
||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||
# tks = tknzr.tokenize("实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
|
||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||
# tks = tknzr.tokenize("虽然我不怎么玩")
|
||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||
# tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
|
||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||
# tks = tknzr.tokenize("涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
|
||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||
# tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
|
||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||
# tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
|
||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||
# tks = tknzr.tokenize("数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
|
||
# logging.info(tknzr.fine_grained_tokenize(tks))
|
||
# if len(sys.argv) < 2:
|
||
# sys.exit()
|
||
# tknzr.DEBUG = False
|
||
# tknzr.loadUserDict(sys.argv[1])
|
||
# of = open(sys.argv[2], "r")
|
||
# while True:
|
||
# line = of.readline()
|
||
# if not line:
|
||
# break
|
||
# logging.info(tknzr.tokenize(line))
|
||
# of.close()
|