fix: 修复Ollama嵌入模型接口和前台不一致的问题 issue#65
This commit is contained in:
parent
c1c517c4c0
commit
3496fff83c
|
@ -52,16 +52,8 @@ class TenantLLMService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [
|
||||
cls.model.llm_factory,
|
||||
LLMFactories.logo,
|
||||
LLMFactories.tags,
|
||||
cls.model.model_type,
|
||||
cls.model.llm_name,
|
||||
cls.model.used_tokens
|
||||
]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
|
||||
cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
|
||||
return list(objs)
|
||||
|
||||
|
@ -117,8 +109,7 @@ class TenantLLMService(CommonService):
|
|||
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
|
||||
if not model_config:
|
||||
if mdlnm == "flag-embedding":
|
||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
|
||||
"llm_name": llm_name, "api_base": ""}
|
||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
|
||||
else:
|
||||
if not mdlnm:
|
||||
raise LookupError(f"Type of {llm_type} model is not set.")
|
||||
|
@ -127,43 +118,32 @@ class TenantLLMService(CommonService):
|
|||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type,
|
||||
llm_name=None, lang="Chinese"):
|
||||
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"):
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if model_config["llm_factory"] not in EmbeddingModel:
|
||||
return
|
||||
return EmbeddingModel[model_config["llm_factory"]](
|
||||
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.RERANK:
|
||||
if model_config["llm_factory"] not in RerankModel:
|
||||
return
|
||||
return RerankModel[model_config["llm_factory"]](
|
||||
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.IMAGE2TEXT.value:
|
||||
if model_config["llm_factory"] not in CvModel:
|
||||
return
|
||||
return CvModel[model_config["llm_factory"]](
|
||||
model_config["api_key"], model_config["llm_name"], lang,
|
||||
base_url=model_config["api_base"]
|
||||
)
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.CHAT.value:
|
||||
if model_config["llm_factory"] not in ChatModel:
|
||||
return
|
||||
return ChatModel[model_config["llm_factory"]](
|
||||
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.SPEECH2TEXT:
|
||||
if model_config["llm_factory"] not in Seq2txtModel:
|
||||
return
|
||||
return Seq2txtModel[model_config["llm_factory"]](
|
||||
key=model_config["api_key"], model_name=model_config["llm_name"],
|
||||
lang=lang,
|
||||
base_url=model_config["api_base"]
|
||||
)
|
||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
||||
if llm_type == LLMType.TTS:
|
||||
if model_config["llm_factory"] not in TTSModel:
|
||||
return
|
||||
|
@ -187,7 +167,7 @@ class TenantLLMService(CommonService):
|
|||
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
|
||||
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
|
||||
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
|
||||
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name
|
||||
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
|
||||
}
|
||||
|
||||
mdlnm = llm_map.get(llm_type)
|
||||
|
@ -198,17 +178,13 @@ class TenantLLMService(CommonService):
|
|||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||
|
||||
try:
|
||||
num = cls.model.update(
|
||||
used_tokens=cls.model.used_tokens + used_tokens
|
||||
).where(
|
||||
cls.model.tenant_id == tenant_id,
|
||||
cls.model.llm_name == llm_name,
|
||||
cls.model.llm_factory == llm_factory if llm_factory else True
|
||||
).execute()
|
||||
num = (
|
||||
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
|
||||
.execute()
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s",
|
||||
tenant_id, llm_name)
|
||||
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
|
||||
return 0
|
||||
|
||||
return num
|
||||
|
@ -216,11 +192,7 @@ class TenantLLMService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_openai_models(cls):
|
||||
objs = cls.model.select().where(
|
||||
(cls.model.llm_factory == "OpenAI"),
|
||||
~(cls.model.llm_name == "text-embedding-3-small"),
|
||||
~(cls.model.llm_name == "text-embedding-3-large")
|
||||
).dicts()
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
return list(objs)
|
||||
|
||||
|
||||
|
@ -229,79 +201,59 @@ class LLMBundle:
|
|||
self.tenant_id = tenant_id
|
||||
self.llm_type = llm_type
|
||||
self.llm_name = llm_name
|
||||
self.mdl = TenantLLMService.model_instance(
|
||||
tenant_id, llm_type, llm_name, lang=lang)
|
||||
assert self.mdl, "Can't find model for {}/{}/{}".format(
|
||||
tenant_id, llm_type, llm_name)
|
||||
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang)
|
||||
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
self.max_length = model_config.get("max_tokens", 8192)
|
||||
|
||||
def encode(self, texts: list):
|
||||
embeddings, used_tokens = self.mdl.encode(texts)
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error(
|
||||
"LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
return embeddings, used_tokens
|
||||
|
||||
def encode_queries(self, query: str):
|
||||
emd, used_tokens = self.mdl.encode_queries(query)
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error(
|
||||
"LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
return emd, used_tokens
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
sim, used_tokens = self.mdl.similarity(query, texts)
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error(
|
||||
"LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
return sim, used_tokens
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
txt, used_tokens = self.mdl.describe(image, max_tokens)
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error(
|
||||
"LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
return txt
|
||||
|
||||
def transcription(self, audio):
|
||||
txt, used_tokens = self.mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error(
|
||||
"LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
return txt
|
||||
|
||||
def tts(self, text):
|
||||
for chunk in self.mdl.tts(text):
|
||||
if isinstance(chunk, int):
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, chunk, self.llm_name):
|
||||
logging.error(
|
||||
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
|
||||
logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
|
||||
return
|
||||
yield chunk
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
|
||||
if isinstance(txt, int) and not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error(
|
||||
"LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
|
||||
used_tokens))
|
||||
if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||
return txt
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
for txt in self.mdl.chat_streamly(system, history, gen_conf):
|
||||
if isinstance(txt, int):
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, txt, self.llm_name):
|
||||
logging.error(
|
||||
"LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
|
||||
txt))
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
|
||||
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
|
||||
return
|
||||
yield txt
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
import requests
|
||||
import time
|
||||
|
||||
# Ollama配置
|
||||
OLLAMA_HOST = "http://localhost:11434" # 默认Ollama地址
|
||||
MODEL_NAME = "bge-m3" # 使用的embedding模型
|
||||
TEXT_TO_EMBED = "测试文本"
|
||||
|
||||
# 定义接口URL和对应的请求体结构
|
||||
ENDPOINTS = {
|
||||
"api/embeddings": {
|
||||
"url": f"{OLLAMA_HOST}/api/embeddings", # 原生API路径
|
||||
"payload": {"model": MODEL_NAME, "prompt": TEXT_TO_EMBED}, # 原生API用prompt字段
|
||||
},
|
||||
"v1/embeddings": {
|
||||
"url": f"{OLLAMA_HOST}/v1/embeddings", # OpenAI兼容API路径
|
||||
"payload": {"model": MODEL_NAME, "input": TEXT_TO_EMBED}, # OpenAI兼容API用input字段
|
||||
},
|
||||
}
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
def test_endpoint(endpoint_name, endpoint_info):
|
||||
"""测试单个端点并返回结果"""
|
||||
print(f"\n测试接口: {endpoint_name}")
|
||||
url = endpoint_info["url"]
|
||||
payload = endpoint_info["payload"]
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应时间: {response_time:.3f}秒")
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
|
||||
# 处理不同接口的响应结构差异
|
||||
embedding = None
|
||||
if endpoint_name == "api/embeddings":
|
||||
embedding = data.get("embedding") # 原生API返回embedding字段
|
||||
elif endpoint_name == "v1/embeddings":
|
||||
embedding = data.get("data", [{}])[0].get("embedding") # OpenAI兼容API返回data数组中的embedding
|
||||
|
||||
if embedding:
|
||||
print(f"Embedding向量长度: {len(embedding)}")
|
||||
return {
|
||||
"endpoint": endpoint_name,
|
||||
"status_code": response.status_code,
|
||||
"response_time": response_time,
|
||||
"embedding_length": len(embedding),
|
||||
"embedding": embedding[:5],
|
||||
}
|
||||
else:
|
||||
print("响应中未找到'embedding'字段")
|
||||
return {"endpoint": endpoint_name, "status_code": response.status_code, "error": "No embedding field in response"}
|
||||
|
||||
except ValueError:
|
||||
print("响应不是有效的JSON格式")
|
||||
return {"endpoint": endpoint_name, "status_code": response.status_code, "error": "Invalid JSON response"}
|
||||
|
||||
except Exception as e:
|
||||
print(f"请求失败: {str(e)}")
|
||||
return {"endpoint": endpoint_name, "error": str(e)}
|
||||
|
||||
|
||||
def compare_endpoints():
|
||||
"""比较两个端点的性能"""
|
||||
results = []
|
||||
|
||||
print("=" * 50)
|
||||
print(f"开始比较Ollama的embeddings接口,使用模型: {MODEL_NAME}")
|
||||
print("=" * 50)
|
||||
|
||||
for endpoint_name, endpoint_info in ENDPOINTS.items():
|
||||
results.append(test_endpoint(endpoint_name, endpoint_info))
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("比较结果摘要:")
|
||||
print("=" * 50)
|
||||
|
||||
successful_results = [res for res in results if "embedding_length" in res]
|
||||
|
||||
if len(successful_results) == 2:
|
||||
if successful_results[0]["embedding_length"] == successful_results[1]["embedding_length"]:
|
||||
print(f"两个接口返回的embedding维度相同: {successful_results[0]['embedding_length']}")
|
||||
else:
|
||||
print("两个接口返回的embedding维度不同:")
|
||||
for result in successful_results:
|
||||
print(f"- {result['endpoint']}: {result['embedding_length']}")
|
||||
|
||||
print("\nEmbedding前5个元素示例:")
|
||||
for result in successful_results:
|
||||
print(f"- {result['endpoint']}: {result['embedding']}")
|
||||
|
||||
faster = min(successful_results, key=lambda x: x["response_time"])
|
||||
slower = max(successful_results, key=lambda x: x["response_time"])
|
||||
print(f"\n更快的接口: {faster['endpoint']} ({faster['response_time']:.3f}秒 vs {slower['response_time']:.3f}秒)")
|
||||
else:
|
||||
print("至少有一个接口未返回有效的embedding数据")
|
||||
for result in results:
|
||||
if "error" in result:
|
||||
print(f"- {result['endpoint']} 错误: {result['error']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
compare_endpoints()
|
|
@ -0,0 +1,64 @@
|
|||
import numpy as np
|
||||
from abc import ABC
|
||||
from ollama import Client
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name):
|
||||
pass
|
||||
|
||||
def encode(self, texts: list):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def encode_queries(self, text: str):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def total_token_count(self, resp):
|
||||
try:
|
||||
return resp.usage.total_tokens
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return resp["usage"]["total_tokens"]
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
|
||||
|
||||
class OllamaEmbed(Base):
|
||||
def __init__(self, model_name, **kwargs):
|
||||
self.client = Client(host="http://localhost:11434", **kwargs)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
arr = []
|
||||
tks_num = 0
|
||||
for txt in texts:
|
||||
res = self.client.embeddings(prompt=txt, model=self.model_name)
|
||||
arr.append(res["embedding"])
|
||||
tks_num += 128
|
||||
return np.array(arr), tks_num
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 初始化嵌入模型
|
||||
embedder = OllamaEmbed(model_name="bge-m3")
|
||||
|
||||
# 测试文本
|
||||
test_texts = ["测试文本"]
|
||||
|
||||
# 获取嵌入向量和token计数
|
||||
embeddings, total_tokens = embedder.encode(test_texts)
|
||||
|
||||
# 打印结果
|
||||
print(f"Total tokens used: {total_tokens}")
|
||||
print("\nEmbedding vectors:")
|
||||
for i, (text, embedding) in enumerate(zip(test_texts, embeddings)):
|
||||
print(f"\nText {i + 1}: '{text}'")
|
||||
print(f"Embedding shape: {embedding.shape}")
|
||||
print(f"First 5 values: {embedding[:5]}")
|
||||
print(f"Embedding dtype: {embedding.dtype}")
|
||||
|
||||
# 打印完整的第一个embedding向量
|
||||
print("\nComplete first embedding vector:")
|
||||
print(embeddings[0])
|
|
@ -5,23 +5,17 @@ import requests
|
|||
url = "https://api.siliconflow.cn/v1/embeddings"
|
||||
api_key = "你的API密钥" # 替换为你的API密钥
|
||||
|
||||
payload = {
|
||||
"model": "BAAI/bge-m3",
|
||||
"input": "Silicon flow embedding online: fast, affordable, and high-quality embedding services. come try it out!"
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
payload = {"model": "BAAI/bge-m3", "input": "测试文本"}
|
||||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
|
||||
response = requests.request("POST", url, json=payload, headers=headers)
|
||||
|
||||
print(response.text)
|
||||
|
||||
# print(response.text.data)
|
||||
# print(response.text.data)
|
||||
|
||||
# embedding_resp = response
|
||||
# embedding_data = embedding_resp.json()
|
||||
# q_1024_vec = embedding_data["data"][0]["embedding"]
|
||||
|
||||
# print("q_1024_vec", q_1024_vec)
|
||||
# print("q_1024_vec", q_1024_vec)
|
||||
|
|
|
@ -248,21 +248,20 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
|
|||
if not embedding_api_base.startswith(("http://", "https://")):
|
||||
embedding_api_base = "http://" + embedding_api_base
|
||||
|
||||
# --- URL 拼接优化 (处理 /v1) ---
|
||||
endpoint_segment = "embeddings"
|
||||
full_endpoint_path = "v1/embeddings"
|
||||
# 移除末尾斜杠以方便判断
|
||||
normalized_base_url = embedding_api_base.rstrip("/")
|
||||
|
||||
if normalized_base_url.endswith("/v1"):
|
||||
# 如果 base_url 已经是 http://host/v1 形式
|
||||
embedding_url = normalized_base_url + "/" + endpoint_segment
|
||||
# 如果请求url端口号为11434,则认为是ollama模型,采用ollama特定的api
|
||||
is_ollama = "11434" in normalized_base_url
|
||||
if is_ollama:
|
||||
# Ollama 的特殊接口路径
|
||||
embedding_url = normalized_base_url + "/api/embeddings"
|
||||
elif normalized_base_url.endswith("/v1"):
|
||||
embedding_url = normalized_base_url + "/embeddings"
|
||||
elif normalized_base_url.endswith("/embeddings"):
|
||||
# 如果 base_url 已经是 http://host/embeddings 形式(比如硅基流动API,无需再进行处理)
|
||||
embedding_url = normalized_base_url
|
||||
else:
|
||||
# 如果 base_url 是 http://host 或 http://host/api 等其他形式
|
||||
embedding_url = normalized_base_url + "/" + full_endpoint_path
|
||||
embedding_url = normalized_base_url + "/v1/embeddings"
|
||||
|
||||
print(f"[Parser-INFO] 使用 Embedding 配置: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")
|
||||
|
||||
|
@ -535,8 +534,14 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
|
|||
|
||||
embedding_resp.raise_for_status()
|
||||
embedding_data = embedding_resp.json()
|
||||
q_1024_vec = embedding_data["data"][0]["embedding"]
|
||||
|
||||
# 对ollama嵌入模型的接口返回值进行特殊处理
|
||||
if is_ollama:
|
||||
q_1024_vec = embedding_data.get("embedding")
|
||||
else:
|
||||
q_1024_vec = embedding_data["data"][0]["embedding"]
|
||||
print(f"[Parser-INFO] 获取embedding成功,长度: {len(q_1024_vec)}")
|
||||
|
||||
# 检查向量维度是否为1024
|
||||
if len(q_1024_vec) != 1024:
|
||||
error_msg = f"[Parser-ERROR] Embedding向量维度不是1024,实际维度: {len(q_1024_vec)}, 建议使用bge-m3模型"
|
||||
|
|
Loading…
Reference in New Issue