feat(文档解析): 适配硅基流动平台并优化Embedding配置处理 (#97) (#97)

This commit is contained in:
zstar 2025-05-16 13:48:16 +08:00 committed by GitHub
parent d0d7a24297
commit 8ce493003b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 65 additions and 6 deletions

View File

@ -96,6 +96,10 @@ python -m api.ragflow_server
pnpm dev
```
> [!NOTE]
> 源码部署需要注意:如果用到 MinerU 后台解析,需要参考 MinerU 的文档下载模型文件,并安装 LibreOffice、配置环境变量,以支持除 PDF 之外的其他类型文件。
## 📝 常见问题
参见[常见问题](docs/faq.md)

View File

@ -22,7 +22,7 @@ from utils import generate_uuid
# 自定义tokenizer和文本处理函数替代rag.nlp中的功能
def tokenize_text(text):
    """Split text into tokens, standing in for the rag_tokenizer functionality."""
    # Naive whitespace split for now; a more sophisticated tokenizer may
    # replace this in the future.
    tokens = text.split()
    return tokens
@ -146,8 +146,8 @@ def _create_task_record(doc_id, chunk_ids_list):
INSERT INTO task (
id, create_time, create_date, update_time, update_date,
doc_id, from_page, to_page, begin_at, process_duation,
progress, progress_msg, retry_count, digest, chunk_ids, task_type
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
progress, progress_msg, retry_count, digest, chunk_ids, task_type, priority
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
task_params = [
task_id,
@ -159,13 +159,14 @@ def _create_task_record(doc_id, chunk_ids_list):
0,
1,
None,
0.0, # begin_at, process_duration
0.0,
1.0,
"MinerU解析完成",
1,
digest,
chunk_ids_str,
"", # progress, msg, retry, digest, chunks, type
"",
0
]
cursor.execute(task_insert, task_params)
conn.commit()
@ -274,7 +275,18 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
# 对模型名称进行处理
if embedding_model_name and "___" in embedding_model_name:
embedding_model_name = embedding_model_name.split("___")[0]
# 替换特定模型名称(对硅基流动平台进行特异性处理)
if embedding_model_name == "netease-youdao/bce-embedding-base_v1":
embedding_model_name = "BAAI/bge-m3"
embedding_api_base = embedding_config.get("api_base") if embedding_config and embedding_config.get("api_base") else "http://localhost:8000" # 默认基础 URL
# 如果 API 基础地址为空字符串,设置为硅基流动的 API 地址
if embedding_api_base == "":
embedding_api_base = "https://api.siliconflow.cn/v1/embeddings"
print(f"[Parser-INFO] API 基础地址为空,已设置为硅基流动的 API 地址: {embedding_api_base}")
embedding_api_key = embedding_config.get("api_key") if embedding_config else None # 可能为 None 或空字符串
# 构建完整的 Embedding API URL
@ -293,6 +305,9 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config):
if normalized_base_url.endswith("/v1"):
# 如果 base_url 已经是 http://host/v1 形式
embedding_url = normalized_base_url + "/" + endpoint_segment
elif normalized_base_url.endswith('/embeddings'):
# 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API无需再进行处理)
embedding_url = normalized_base_url
else:
# 如果 base_url 是 http://host 或 http://host/api 等其他形式
embedding_url = normalized_base_url + "/" + full_endpoint_path

View File

@ -915,6 +915,9 @@ class KnowledgebaseService:
if normalized_base_url.endswith('/v1'):
# 如果 base_url 已经是 http://host/v1 形式
current_test_url = normalized_base_url + '/' + endpoint_segment
elif normalized_base_url.endswith('/embeddings'):
# 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API无需再进行处理)
current_test_url = normalized_base_url
else:
# 如果 base_url 是 http://host 或 http://host/api 形式
current_test_url = normalized_base_url + '/' + full_endpoint_path
@ -991,6 +994,15 @@ class KnowledgebaseService:
# 对模型名称进行处理 (可选,根据需要保留或移除)
if llm_name and '___' in llm_name:
llm_name = llm_name.split('___')[0]
# (对硅基流动平台进行特异性处理)
if llm_name == "netease-youdao/bce-embedding-base_v1":
llm_name = "BAAI/bge-m3"
# 如果 API 基础地址为空字符串,设置为硅基流动嵌入模型的 API 地址
if api_base == "":
api_base = "https://api.siliconflow.cn/v1/embeddings"
# 如果有配置,返回
return {
"llm_name": llm_name,
@ -1023,6 +1035,7 @@ class KnowledgebaseService:
if not tenant_id:
raise Exception("无法找到系统基础用户")
print(f"开始设置系统 Embedding 配置: {llm_name}, {api_base}, {api_key}")
# 执行连接测试
is_connected, message = cls._test_embedding_connection(
base_url=api_base,

View File

@ -0,0 +1,27 @@
# Connectivity check for the SiliconFlow embedding model endpoint.
import os

import requests

url = "https://api.siliconflow.cn/v1/embeddings"
# Set SILICONFLOW_API_KEY in the environment, or replace the placeholder below
# with your own API key (avoid committing a real key to the repository).
api_key = os.environ.get("SILICONFLOW_API_KEY", "你的API密钥")
payload = {
    "model": "BAAI/bge-m3",
    "input": "Silicon flow embedding online: fast, affordable, and high-quality embedding services. come try it out!",
}
headers = {
    "Authorization": f"Bearer {api_key}",
    "Content-Type": "application/json",
}
# requests applies no timeout by default; without one, a stalled connection
# would make this check hang indefinitely instead of failing.
response = requests.post(url, json=payload, headers=headers, timeout=30)
print(response.text)
# To inspect the returned vector, parse the body with response.json() and read
# ["data"][0]["embedding"].