fix: resolve the inconsistency between the Ollama embedding model API and the frontend (issue#65)

zstar 2025-06-05 14:35:23 +08:00
parent c1c517c4c0
commit 3496fff83c
5 changed files with 228 additions and 103 deletions

View File

@@ -52,16 +52,8 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [
cls.model.llm_factory,
LLMFactories.logo,
LLMFactories.tags,
cls.model.model_type,
cls.model.llm_name,
cls.model.used_tokens
]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
return list(objs)
@@ -117,8 +109,7 @@ class TenantLLMService(CommonService):
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
if not model_config:
if mdlnm == "flag-embedding":
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
"llm_name": llm_name, "api_base": ""}
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
else:
if not mdlnm:
raise LookupError(f"Type of {llm_type} model is not set.")
@@ -127,43 +118,32 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type,
llm_name=None, lang="Chinese"):
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"):
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return
return EmbeddingModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return
return RerankModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return
return CvModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"], lang,
base_url=model_config["api_base"]
)
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"])
if llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return
return ChatModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
return
return Seq2txtModel[model_config["llm_factory"]](
key=model_config["api_key"], model_name=model_config["llm_name"],
lang=lang,
base_url=model_config["api_base"]
)
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
if llm_type == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return
@@ -187,7 +167,7 @@ class TenantLLMService(CommonService):
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
}
mdlnm = llm_map.get(llm_type)
@@ -198,17 +178,13 @@ class TenantLLMService(CommonService):
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
try:
num = cls.model.update(
used_tokens=cls.model.used_tokens + used_tokens
).where(
cls.model.tenant_id == tenant_id,
cls.model.llm_name == llm_name,
cls.model.llm_factory == llm_factory if llm_factory else True
).execute()
num = (
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
.execute()
)
except Exception:
logging.exception(
"TenantLLMService.increase_usage got exception, failed to update used_tokens for tenant_id=%s, llm_name=%s",
tenant_id, llm_name)
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
return 0
return num
@@ -216,11 +192,7 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_openai_models(cls):
objs = cls.model.select().where(
(cls.model.llm_factory == "OpenAI"),
~(cls.model.llm_name == "text-embedding-3-small"),
~(cls.model.llm_name == "text-embedding-3-large")
).dicts()
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs)
@@ -229,79 +201,59 @@ class LLMBundle:
self.tenant_id = tenant_id
self.llm_type = llm_type
self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance(
tenant_id, llm_type, llm_name, lang=lang)
assert self.mdl, "Can't find model for {}/{}/{}".format(
tenant_id, llm_type, llm_name)
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang)
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
self.max_length = model_config.get("max_tokens", 8192)
def encode(self, texts: list):
embeddings, used_tokens = self.mdl.encode(texts)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
logging.error(
"LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
return embeddings, used_tokens
def encode_queries(self, query: str):
emd, used_tokens = self.mdl.encode_queries(query)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
logging.error(
"LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
return emd, used_tokens
def similarity(self, query: str, texts: list):
sim, used_tokens = self.mdl.similarity(query, texts)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
logging.error(
"LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
return sim, used_tokens
def describe(self, image, max_tokens=300):
txt, used_tokens = self.mdl.describe(image, max_tokens)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
logging.error(
"LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
return txt
def transcription(self, audio):
txt, used_tokens = self.mdl.transcription(audio)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
logging.error(
"LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
return txt
def tts(self, text):
for chunk in self.mdl.tts(text):
if isinstance(chunk, int):
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, chunk, self.llm_name):
logging.error(
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
return
yield chunk
def chat(self, system, history, gen_conf):
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
if isinstance(txt, int) and not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error(
"LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
used_tokens))
if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
return txt
def chat_streamly(self, system, history, gen_conf):
for txt in self.mdl.chat_streamly(system, history, gen_conf):
if isinstance(txt, int):
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, txt, self.llm_name):
logging.error(
"LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
txt))
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
return
yield txt
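For orientation, a minimal usage sketch of the refactored LLMBundle; the tenant id is illustrative and the import paths are assumptions inferred from the names above:

from api.db import LLMType  # assumed import path
from api.db.services.llm_service import LLMBundle  # assumed module path

# lang defaults to "Chinese", matching the constructor above.
embd_mdl = LLMBundle("tenant-0001", LLMType.EMBEDDING, llm_name="bge-m3")
vectors, used_tokens = embd_mdl.encode(["测试文本"])
print(len(vectors), used_tokens)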

View File

@@ -0,0 +1,110 @@
import requests
import time
# Ollama configuration
OLLAMA_HOST = "http://localhost:11434" # default Ollama address
MODEL_NAME = "bge-m3" # embedding model under test
TEXT_TO_EMBED = "测试文本" # sample text to embed ("test text")
# Endpoint URLs and their corresponding request payload structures
ENDPOINTS = {
"api/embeddings": {
"url": f"{OLLAMA_HOST}/api/embeddings", # 原生API路径
"payload": {"model": MODEL_NAME, "prompt": TEXT_TO_EMBED}, # 原生API用prompt字段
},
"v1/embeddings": {
"url": f"{OLLAMA_HOST}/v1/embeddings", # OpenAI兼容API路径
"payload": {"model": MODEL_NAME, "input": TEXT_TO_EMBED}, # OpenAI兼容API用input字段
},
}
headers = {"Content-Type": "application/json"}
def test_endpoint(endpoint_name, endpoint_info):
"""测试单个端点并返回结果"""
print(f"\n测试接口: {endpoint_name}")
url = endpoint_info["url"]
payload = endpoint_info["payload"]
try:
start_time = time.time()
response = requests.post(url, headers=headers, json=payload)
response_time = time.time() - start_time
print(f"状态码: {response.status_code}")
print(f"响应时间: {response_time:.3f}")
try:
data = response.json()
# Handle the difference in response structure between the two endpoints
embedding = None
if endpoint_name == "api/embeddings":
embedding = data.get("embedding") # the native API returns a top-level "embedding" field
elif endpoint_name == "v1/embeddings":
embedding = data.get("data", [{}])[0].get("embedding") # the OpenAI-compatible API nests it in the "data" array
if embedding:
print(f"Embedding向量长度: {len(embedding)}")
return {
"endpoint": endpoint_name,
"status_code": response.status_code,
"response_time": response_time,
"embedding_length": len(embedding),
"embedding": embedding[:5],
}
else:
print("响应中未找到'embedding'字段")
return {"endpoint": endpoint_name, "status_code": response.status_code, "error": "No embedding field in response"}
except ValueError:
print("响应不是有效的JSON格式")
return {"endpoint": endpoint_name, "status_code": response.status_code, "error": "Invalid JSON response"}
except Exception as e:
print(f"请求失败: {str(e)}")
return {"endpoint": endpoint_name, "error": str(e)}
def compare_endpoints():
"""比较两个端点的性能"""
results = []
print("=" * 50)
print(f"开始比较Ollama的embeddings接口使用模型: {MODEL_NAME}")
print("=" * 50)
for endpoint_name, endpoint_info in ENDPOINTS.items():
results.append(test_endpoint(endpoint_name, endpoint_info))
print("\n" + "=" * 50)
print("比较结果摘要:")
print("=" * 50)
successful_results = [res for res in results if "embedding_length" in res]
if len(successful_results) == 2:
if successful_results[0]["embedding_length"] == successful_results[1]["embedding_length"]:
print(f"两个接口返回的embedding维度相同: {successful_results[0]['embedding_length']}")
else:
print("两个接口返回的embedding维度不同:")
for result in successful_results:
print(f"- {result['endpoint']}: {result['embedding_length']}")
print("\nEmbedding前5个元素示例:")
for result in successful_results:
print(f"- {result['endpoint']}: {result['embedding']}")
faster = min(successful_results, key=lambda x: x["response_time"])
slower = max(successful_results, key=lambda x: x["response_time"])
print(f"\n更快的接口: {faster['endpoint']} ({faster['response_time']:.3f}秒 vs {slower['response_time']:.3f}秒)")
else:
print("至少有一个接口未返回有效的embedding数据")
for result in results:
if "error" in result:
print(f"- {result['endpoint']} 错误: {result['error']}")
if __name__ == "__main__":
compare_endpoints()
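The two endpoints also differ in response shape, which is the root of the inconsistency this commit fixes. A sketch of the two payloads as the parsing code above expects them (vector values are illustrative):

# /api/embeddings (native): the vector sits at the top level
native_resp = {"embedding": [0.12, -0.08, 0.33]}

# /v1/embeddings (OpenAI-compatible): the vector is nested in data[0]
openai_resp = {"data": [{"embedding": [0.12, -0.08, 0.33]}], "model": "bge-m3"}

assert native_resp["embedding"] == openai_resp["data"][0]["embedding"]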

View File

@@ -0,0 +1,64 @@
import numpy as np
from abc import ABC
from ollama import Client
class Base(ABC):
def __init__(self, key, model_name):
pass
def encode(self, texts: list):
raise NotImplementedError("Please implement encode method!")
def encode_queries(self, text: str):
raise NotImplementedError("Please implement encode_queries method!")
def total_token_count(self, resp):
try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
class OllamaEmbed(Base):
def __init__(self, model_name, **kwargs):
self.client = Client(host="http://localhost:11434", **kwargs)  # hardcoded local Ollama host for this test
self.model_name = model_name
def encode(self, texts: list):
arr = []
tks_num = 0
for txt in texts:
res = self.client.embeddings(prompt=txt, model=self.model_name)
arr.append(res["embedding"])
tks_num += 128  # Ollama's native API reports no token usage; use a fixed estimate
return np.array(arr), tks_num
if __name__ == "__main__":
# Initialize the embedding model
embedder = OllamaEmbed(model_name="bge-m3")
# Test texts
test_texts = ["测试文本"]
# Get the embedding vectors and the token count
embeddings, total_tokens = embedder.encode(test_texts)
# Print the results
print(f"Total tokens used: {total_tokens}")
print("\nEmbedding vectors:")
for i, (text, embedding) in enumerate(zip(test_texts, embeddings)):
print(f"\nText {i + 1}: '{text}'")
print(f"Embedding shape: {embedding.shape}")
print(f"First 5 values: {embedding[:5]}")
print(f"Embedding dtype: {embedding.dtype}")
# Print the complete first embedding vector
print("\nComplete first embedding vector:")
print(embeddings[0])
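OllamaEmbed overrides only encode, so encode_queries would still raise NotImplementedError. A minimal sketch of the missing override, assuming a single query maps to one native-API call exactly as encode does:

class OllamaEmbedQueries(OllamaEmbed):
    def encode_queries(self, text: str):
        res = self.client.embeddings(prompt=text, model=self.model_name)
        # Same fixed token estimate as encode, since the native API reports no usage.
        return np.array(res["embedding"]), 128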

View File

@@ -5,23 +5,17 @@ import requests
url = "https://api.siliconflow.cn/v1/embeddings"
api_key = "your_api_key" # replace with your API key
payload = {
"model": "BAAI/bge-m3",
"input": "Silicon flow embedding online: fast, affordable, and high-quality embedding services. come try it out!"
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
payload = {"model": "BAAI/bge-m3", "input": "测试文本"}
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
response = requests.request("POST", url, json=payload, headers=headers)
print(response.text)
# print(response.text.data)
# embedding_resp = response
# embedding_data = embedding_resp.json()
# q_1024_vec = embedding_data["data"][0]["embedding"]
# print("q_1024_vec", q_1024_vec)
# print("q_1024_vec", q_1024_vec)

View File

@@ -248,21 +248,20 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
if not embedding_api_base.startswith(("http://", "https://")):
embedding_api_base = "http://" + embedding_api_base
# --- URL joining logic (handling /v1) ---
endpoint_segment = "embeddings"
full_endpoint_path = "v1/embeddings"
# Strip the trailing slash to simplify the checks below
normalized_base_url = embedding_api_base.rstrip("/")
if normalized_base_url.endswith("/v1"):
# base_url is already of the form http://host/v1
embedding_url = normalized_base_url + "/" + endpoint_segment
# If the URL uses port 11434, assume an Ollama server and use Ollama's own API
is_ollama = "11434" in normalized_base_url
if is_ollama:
# Ollama's dedicated endpoint path
embedding_url = normalized_base_url + "/api/embeddings"
elif normalized_base_url.endswith("/v1"):
embedding_url = normalized_base_url + "/embeddings"
elif normalized_base_url.endswith("/embeddings"):
# base_url already ends with /embeddings (e.g. the SiliconFlow API); no further processing is needed
embedding_url = normalized_base_url
else:
# base_url has another form, e.g. http://host or http://host/api
embedding_url = normalized_base_url + "/" + full_endpoint_path
embedding_url = normalized_base_url + "/v1/embeddings"
print(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")
@@ -535,8 +534,14 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
embedding_resp.raise_for_status()
embedding_data = embedding_resp.json()
q_1024_vec = embedding_data["data"][0]["embedding"]
# Special-case the response shape of the Ollama embedding API
if is_ollama:
q_1024_vec = embedding_data.get("embedding")
else:
q_1024_vec = embedding_data["data"][0]["embedding"]
print(f"[Parser-INFO] 获取embedding成功长度: {len(q_1024_vec)}")
# Check that the vector dimension is 1024
if len(q_1024_vec) != 1024:
error_msg = f"[Parser-ERROR] Embedding vector dimension is not 1024 (actual: {len(q_1024_vec)}); the bge-m3 model is recommended"
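The same special-casing plus the dimension check can be folded into one helper, shown here as a sketch (name and signature are illustrative):

def extract_embedding(embedding_data: dict, is_ollama: bool, expected_dim: int = 1024) -> list:
    # Normalize the two response shapes to a flat vector.
    vec = embedding_data.get("embedding") if is_ollama else embedding_data["data"][0]["embedding"]
    if not vec or len(vec) != expected_dim:
        raise ValueError(f"expected a {expected_dim}-dim embedding, got {len(vec or [])}")
    return vec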