fix: resolve the inconsistency between the Ollama embedding model API and the frontend (issue #65)
parent c1c517c4c0
commit 3496fff83c
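
Background for issue #65: Ollama exposes two embedding endpoints with different request and response shapes, and the parser and the frontend were not talking to the same one. A minimal sketch of the difference, assuming a local Ollama on the default port with bge-m3 pulled:

import requests

# Native endpoint: request field "prompt", flat response {"embedding": [...]}.
native = requests.post("http://localhost:11434/api/embeddings", json={"model": "bge-m3", "prompt": "hello"}).json()

# OpenAI-compatible endpoint: request field "input", response nested under "data".
compat = requests.post("http://localhost:11434/v1/embeddings", json={"model": "bge-m3", "input": "hello"}).json()

assert len(native["embedding"]) == len(compat["data"][0]["embedding"])  # same model, same dimension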
@@ -52,16 +52,8 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_my_llms(cls, tenant_id):
-        fields = [
-            cls.model.llm_factory,
-            LLMFactories.logo,
-            LLMFactories.tags,
-            cls.model.model_type,
-            cls.model.llm_name,
-            cls.model.used_tokens
-        ]
-        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
-            cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
+        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
+        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
 
         return list(objs)

@@ -117,8 +109,7 @@ class TenantLLMService(CommonService):
             model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
         if not model_config:
             if mdlnm == "flag-embedding":
-                model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
-                                "llm_name": llm_name, "api_base": ""}
+                model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
             else:
                 if not mdlnm:
                     raise LookupError(f"Type of {llm_type} model is not set.")

@@ -127,43 +118,32 @@ class TenantLLMService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def model_instance(cls, tenant_id, llm_type,
-                       llm_name=None, lang="Chinese"):
+    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"):
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
                 return
-            return EmbeddingModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
 
         if llm_type == LLMType.RERANK:
             if model_config["llm_factory"] not in RerankModel:
                 return
-            return RerankModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
 
         if llm_type == LLMType.IMAGE2TEXT.value:
             if model_config["llm_factory"] not in CvModel:
                 return
-            return CvModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], lang,
-                base_url=model_config["api_base"]
-            )
+            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"])
 
         if llm_type == LLMType.CHAT.value:
             if model_config["llm_factory"] not in ChatModel:
                 return
-            return ChatModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
 
         if llm_type == LLMType.SPEECH2TEXT:
             if model_config["llm_factory"] not in Seq2txtModel:
                 return
-            return Seq2txtModel[model_config["llm_factory"]](
-                key=model_config["api_key"], model_name=model_config["llm_name"],
-                lang=lang,
-                base_url=model_config["api_base"]
-            )
+            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
         if llm_type == LLMType.TTS:
             if model_config["llm_factory"] not in TTSModel:
                 return

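
model_instance dispatches through per-type registries (EmbeddingModel, RerankModel, CvModel, ChatModel, ...) keyed by llm_factory. A minimal sketch of the pattern, with a hypothetical one-entry registry standing in for the real imports:

# Hypothetical registry entry illustrating the dispatch above.
class FakeOllamaEmbed:
    def __init__(self, key, model_name, base_url=""):
        self.model_name, self.base_url = model_name, base_url

EmbeddingModel = {"Ollama": FakeOllamaEmbed}

config = {"llm_factory": "Ollama", "api_key": "", "llm_name": "bge-m3", "api_base": "http://localhost:11434"}
if config["llm_factory"] in EmbeddingModel:
    mdl = EmbeddingModel[config["llm_factory"]](config["api_key"], config["llm_name"], base_url=config["api_base"])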
@@ -187,7 +167,7 @@ class TenantLLMService(CommonService):
             LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
             LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
             LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
-            LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name
+            LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
         }
 
         mdlnm = llm_map.get(llm_type)

@@ -198,17 +178,13 @@ class TenantLLMService(CommonService):
         llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
 
         try:
-            num = cls.model.update(
-                used_tokens=cls.model.used_tokens + used_tokens
-            ).where(
-                cls.model.tenant_id == tenant_id,
-                cls.model.llm_name == llm_name,
-                cls.model.llm_factory == llm_factory if llm_factory else True
-            ).execute()
+            num = (
+                cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
+                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
+                .execute()
+            )
         except Exception:
-            logging.exception(
-                "TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s",
-                tenant_id, llm_name)
+            logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
             return 0
 
         return num

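
The third where() argument leans on Python operator precedence: `==` binds tighter than the conditional expression, so peewee receives either a real comparison or the literal True, which acts as a no-op filter when no factory was parsed out of the model name. A plain-Python stand-in for the expression:

def factory_filter(column_value, llm_factory):
    # Parsed as: (column_value == llm_factory) if llm_factory else True
    return column_value == llm_factory if llm_factory else True

print(factory_filter("Ollama", "Ollama"))  # True  -> row matches
print(factory_filter("Ollama", ""))        # True  -> filter disabled
print(factory_filter("Ollama", "OpenAI"))  # False -> row excluded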
@@ -216,11 +192,7 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_openai_models(cls):
-        objs = cls.model.select().where(
-            (cls.model.llm_factory == "OpenAI"),
-            ~(cls.model.llm_name == "text-embedding-3-small"),
-            ~(cls.model.llm_name == "text-embedding-3-large")
-        ).dicts()
+        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
         return list(objs)
 
 

@@ -229,79 +201,59 @@ class LLMBundle:
         self.tenant_id = tenant_id
         self.llm_type = llm_type
         self.llm_name = llm_name
-        self.mdl = TenantLLMService.model_instance(
-            tenant_id, llm_type, llm_name, lang=lang)
-        assert self.mdl, "Can't find model for {}/{}/{}".format(
-            tenant_id, llm_type, llm_name)
+        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang)
+        assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         self.max_length = model_config.get("max_tokens", 8192)
 
     def encode(self, texts: list):
         embeddings, used_tokens = self.mdl.encode(texts)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
         return embeddings, used_tokens
 
     def encode_queries(self, query: str):
         emd, used_tokens = self.mdl.encode_queries(query)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
         return emd, used_tokens
 
     def similarity(self, query: str, texts: list):
         sim, used_tokens = self.mdl.similarity(query, texts)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
         return sim, used_tokens
 
     def describe(self, image, max_tokens=300):
         txt, used_tokens = self.mdl.describe(image, max_tokens)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
         return txt
 
     def transcription(self, audio):
         txt, used_tokens = self.mdl.transcription(audio)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
         return txt
 
     def tts(self, text):
         for chunk in self.mdl.tts(text):
             if isinstance(chunk, int):
-                if not TenantLLMService.increase_usage(
-                        self.tenant_id, self.llm_type, chunk, self.llm_name):
-                    logging.error(
-                        "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
+                if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
+                    logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                 return
             yield chunk
 
     def chat(self, system, history, gen_conf):
         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
-        if isinstance(txt, int) and not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens, self.llm_name):
-            logging.error(
-                "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
-                                                                                                           used_tokens))
+        if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+            logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
         return txt
 
     def chat_streamly(self, system, history, gen_conf):
         for txt in self.mdl.chat_streamly(system, history, gen_conf):
             if isinstance(txt, int):
-                if not TenantLLMService.increase_usage(
-                        self.tenant_id, self.llm_type, txt, self.llm_name):
-                    logging.error(
-                        "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
-                                                                                                                        txt))
+                if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
+                    logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
                 return
             yield txt

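
A hypothetical caller sketch for LLMBundle (the module paths and exact constructor signature are assumptions inferred from the body above): construction resolves the tenant's configured model, and every wrapper method both delegates to it and books token usage through TenantLLMService.increase_usage.

# Hypothetical usage; import paths are assumptions.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle

bundle = LLMBundle("some-tenant-id", LLMType.EMBEDDING)
vectors, used_tokens = bundle.encode(["hello", "world"])  # usage is recorded as a side effect
print(len(vectors), used_tokens)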
@@ -0,0 +1,110 @@
+import requests
+import time
+
+# Ollama configuration
+OLLAMA_HOST = "http://localhost:11434"  # default Ollama address
+MODEL_NAME = "bge-m3"  # embedding model to use
+TEXT_TO_EMBED = "测试文本"
+
+# Endpoint URLs and the request body each one expects
+ENDPOINTS = {
+    "api/embeddings": {
+        "url": f"{OLLAMA_HOST}/api/embeddings",  # native API path
+        "payload": {"model": MODEL_NAME, "prompt": TEXT_TO_EMBED},  # native API uses the "prompt" field
+    },
+    "v1/embeddings": {
+        "url": f"{OLLAMA_HOST}/v1/embeddings",  # OpenAI-compatible API path
+        "payload": {"model": MODEL_NAME, "input": TEXT_TO_EMBED},  # OpenAI-compatible API uses the "input" field
+    },
+}
+
+headers = {"Content-Type": "application/json"}
+
+
+def test_endpoint(endpoint_name, endpoint_info):
+    """Test a single endpoint and return the result."""
+    print(f"\nTesting endpoint: {endpoint_name}")
+    url = endpoint_info["url"]
+    payload = endpoint_info["payload"]
+
+    try:
+        start_time = time.time()
+        response = requests.post(url, headers=headers, json=payload)
+        response_time = time.time() - start_time
+
+        print(f"Status code: {response.status_code}")
+        print(f"Response time: {response_time:.3f}s")
+
+        try:
+            data = response.json()
+
+            # Handle the differing response structures of the two endpoints
+            embedding = None
+            if endpoint_name == "api/embeddings":
+                embedding = data.get("embedding")  # native API returns a top-level "embedding" field
+            elif endpoint_name == "v1/embeddings":
+                embedding = data.get("data", [{}])[0].get("embedding")  # OpenAI-compatible API nests it in the "data" array
+
+            if embedding:
+                print(f"Embedding vector length: {len(embedding)}")
+                return {
+                    "endpoint": endpoint_name,
+                    "status_code": response.status_code,
+                    "response_time": response_time,
+                    "embedding_length": len(embedding),
+                    "embedding": embedding[:5],
+                }
+            else:
+                print("No 'embedding' field found in the response")
+                return {"endpoint": endpoint_name, "status_code": response.status_code, "error": "No embedding field in response"}
+
+        except ValueError:
+            print("Response is not valid JSON")
+            return {"endpoint": endpoint_name, "status_code": response.status_code, "error": "Invalid JSON response"}
+
+    except Exception as e:
+        print(f"Request failed: {str(e)}")
+        return {"endpoint": endpoint_name, "error": str(e)}
+
+
+def compare_endpoints():
+    """Compare the two endpoints."""
+    results = []
+
+    print("=" * 50)
+    print(f"Comparing Ollama embeddings endpoints with model: {MODEL_NAME}")
+    print("=" * 50)
+
+    for endpoint_name, endpoint_info in ENDPOINTS.items():
+        results.append(test_endpoint(endpoint_name, endpoint_info))
+
+    print("\n" + "=" * 50)
+    print("Comparison summary:")
+    print("=" * 50)
+
+    successful_results = [res for res in results if "embedding_length" in res]
+
+    if len(successful_results) == 2:
+        if successful_results[0]["embedding_length"] == successful_results[1]["embedding_length"]:
+            print(f"Both endpoints return embeddings of the same dimension: {successful_results[0]['embedding_length']}")
+        else:
+            print("The two endpoints return embeddings of different dimensions:")
+            for result in successful_results:
+                print(f"- {result['endpoint']}: {result['embedding_length']}")
+
+        print("\nFirst 5 elements of each embedding:")
+        for result in successful_results:
+            print(f"- {result['endpoint']}: {result['embedding']}")
+
+        faster = min(successful_results, key=lambda x: x["response_time"])
+        slower = max(successful_results, key=lambda x: x["response_time"])
+        print(f"\nFaster endpoint: {faster['endpoint']} ({faster['response_time']:.3f}s vs {slower['response_time']:.3f}s)")
+    else:
+        print("At least one endpoint did not return valid embedding data")
+        for result in results:
+            if "error" in result:
+                print(f"- {result['endpoint']} error: {result['error']}")
+
+
+if __name__ == "__main__":
+    compare_endpoints()

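
The two response shapes this script probes can be folded into one helper; a sketch (helper name illustrative) mirroring what the parser change further down does:

def extract_embedding(data: dict) -> list:
    """Return the vector from either Ollama response shape.

    /api/embeddings (native):           {"embedding": [...]}
    /v1/embeddings (OpenAI-compatible): {"data": [{"embedding": [...]}], ...}
    """
    if "embedding" in data:
        return data["embedding"]
    return data.get("data", [{}])[0].get("embedding", [])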
@@ -0,0 +1,64 @@
+import numpy as np
+from abc import ABC
+from ollama import Client
+
+
+class Base(ABC):
+    def __init__(self, key, model_name):
+        pass
+
+    def encode(self, texts: list):
+        raise NotImplementedError("Please implement encode method!")
+
+    def encode_queries(self, text: str):
+        raise NotImplementedError("Please implement encode method!")
+
+    def total_token_count(self, resp):
+        try:
+            return resp.usage.total_tokens
+        except Exception:
+            pass
+        try:
+            return resp["usage"]["total_tokens"]
+        except Exception:
+            pass
+        return 0
+
+
+class OllamaEmbed(Base):
+    def __init__(self, model_name, **kwargs):
+        self.client = Client(host="http://localhost:11434", **kwargs)
+        self.model_name = model_name
+
+    def encode(self, texts: list):
+        arr = []
+        tks_num = 0
+        for txt in texts:
+            res = self.client.embeddings(prompt=txt, model=self.model_name)
+            arr.append(res["embedding"])
+            tks_num += 128
+        return np.array(arr), tks_num
+
+
+if __name__ == "__main__":
+    # Initialize the embedding model
+    embedder = OllamaEmbed(model_name="bge-m3")
+
+    # Test text
+    test_texts = ["测试文本"]
+
+    # Get the embedding vectors and token count
+    embeddings, total_tokens = embedder.encode(test_texts)
+
+    # Print the results
+    print(f"Total tokens used: {total_tokens}")
+    print("\nEmbedding vectors:")
+    for i, (text, embedding) in enumerate(zip(test_texts, embeddings)):
+        print(f"\nText {i + 1}: '{text}'")
+        print(f"Embedding shape: {embedding.shape}")
+        print(f"First 5 values: {embedding[:5]}")
+        print(f"Embedding dtype: {embedding.dtype}")
+
+    # Print the complete first embedding vector
+    print("\nComplete first embedding vector:")
+    print(embeddings[0])

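
OllamaEmbed implements only encode(); a possible encode_queries counterpart (not part of this commit, sketched to match the Base interface) would mirror it for a single string. The flat 128-token increment is kept as a rough estimate, since the native endpoint reports no usage counts:

    def encode_queries(self, text: str):
        # One request per query string; same flat token estimate as encode().
        res = self.client.embeddings(prompt=text, model=self.model_name)
        return np.array(res["embedding"]), 128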
@@ -5,23 +5,17 @@ import requests
 url = "https://api.siliconflow.cn/v1/embeddings"
 api_key = "your_api_key"  # replace with your API key
 
-payload = {
-    "model": "BAAI/bge-m3",
-    "input": "Silicon flow embedding online: fast, affordable, and high-quality embedding services. come try it out!"
-}
-headers = {
-    "Authorization": f"Bearer {api_key}",
-    "Content-Type": "application/json"
-}
+payload = {"model": "BAAI/bge-m3", "input": "测试文本"}
+headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
 
 response = requests.request("POST", url, json=payload, headers=headers)
 
 print(response.text)
 
 # print(response.text.data)
 
 # embedding_resp = response
 # embedding_data = embedding_resp.json()
 # q_1024_vec = embedding_data["data"][0]["embedding"]
 
 # print("q_1024_vec", q_1024_vec)

@@ -248,21 +248,20 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
     if not embedding_api_base.startswith(("http://", "https://")):
         embedding_api_base = "http://" + embedding_api_base
 
-    # --- URL concatenation tweak (handle /v1) ---
-    endpoint_segment = "embeddings"
-    full_endpoint_path = "v1/embeddings"
     # Strip the trailing slash to simplify the checks below
     normalized_base_url = embedding_api_base.rstrip("/")
 
-    if normalized_base_url.endswith("/v1"):
-        # base_url is already of the form http://host/v1
-        embedding_url = normalized_base_url + "/" + endpoint_segment
+    # If the URL uses port 11434, assume an Ollama model and use Ollama's own API
+    is_ollama = "11434" in normalized_base_url
+    if is_ollama:
+        # Ollama's dedicated endpoint path
+        embedding_url = normalized_base_url + "/api/embeddings"
+    elif normalized_base_url.endswith("/v1"):
+        embedding_url = normalized_base_url + "/embeddings"
     elif normalized_base_url.endswith("/embeddings"):
-        # base_url is already a full endpoint (e.g. the SiliconFlow API), nothing to append
         embedding_url = normalized_base_url
     else:
-        # base_url is of another form, e.g. http://host or http://host/api
-        embedding_url = normalized_base_url + "/" + full_endpoint_path
+        embedding_url = normalized_base_url + "/v1/embeddings"
 
     print(f"[Parser-INFO] Using embedding config: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")

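
The branch order above, extracted into a standalone helper for illustration (the function name is hypothetical):

def resolve_embedding_url(base: str) -> str:
    base = base.rstrip("/")
    if "11434" in base:                # Ollama default port -> native API
        return base + "/api/embeddings"
    if base.endswith("/v1"):           # OpenAI-style base URL
        return base + "/embeddings"
    if base.endswith("/embeddings"):   # already a full endpoint (e.g. SiliconFlow)
        return base
    return base + "/v1/embeddings"     # bare host

assert resolve_embedding_url("http://localhost:11434") == "http://localhost:11434/api/embeddings"
assert resolve_embedding_url("https://api.siliconflow.cn/v1") == "https://api.siliconflow.cn/v1/embeddings"
assert resolve_embedding_url("http://my-gateway") == "http://my-gateway/v1/embeddings"

Detecting Ollama by the substring "11434" is a heuristic: an Ollama instance served through a proxy on another port would fall through to the OpenAI-compatible /v1/embeddings path, which Ollama also accepts.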
@@ -535,8 +534,14 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
 
         embedding_resp.raise_for_status()
         embedding_data = embedding_resp.json()
-        q_1024_vec = embedding_data["data"][0]["embedding"]
+        # Special-case the response shape of the Ollama embedding endpoint
+        if is_ollama:
+            q_1024_vec = embedding_data.get("embedding")
+        else:
+            q_1024_vec = embedding_data["data"][0]["embedding"]
         print(f"[Parser-INFO] Embedding fetched successfully, length: {len(q_1024_vec)}")
 
         # Check that the vector dimension is 1024
         if len(q_1024_vec) != 1024:
             error_msg = f"[Parser-ERROR] Embedding dimension is not 1024, got {len(q_1024_vec)}; the bge-m3 model is recommended"