From 3496fff83ccede3a81a712d1192961369e87f1d8 Mon Sep 17 00:00:00 2001
From: zstar <65890619+zstar1003@users.noreply.github.com>
Date: Thu, 5 Jun 2025 14:35:23 +0800
Subject: [PATCH] fix: resolve the inconsistency between the Ollama embedding
 model API and the frontend issue#65
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 api/db/services/llm_service.py              | 118 ++++++------------
 management/server/scripts/embedding_test.py | 110 ++++++++++++++++
 management/server/scripts/ollama_test.py    |  64 ++++++++++
 .../server/scripts/siliconflow_emb_test.py  |  14 +--
 .../knowledgebases/document_parser.py       |  25 ++--
 5 files changed, 228 insertions(+), 103 deletions(-)
 create mode 100644 management/server/scripts/embedding_test.py
 create mode 100644 management/server/scripts/ollama_test.py

diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 039d459..5f93e8a 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -52,16 +52,8 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_my_llms(cls, tenant_id):
-        fields = [
-            cls.model.llm_factory,
-            LLMFactories.logo,
-            LLMFactories.tags,
-            cls.model.model_type,
-            cls.model.llm_name,
-            cls.model.used_tokens
-        ]
-        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
-            cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
+        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
+        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()

         return list(objs)

@@ -117,8 +109,7 @@ class TenantLLMService(CommonService):
                 model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
         if not model_config:
             if mdlnm == "flag-embedding":
-                model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
-                                "llm_name": llm_name, "api_base": ""}
+                model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
             else:
                 if not mdlnm:
                     raise LookupError(f"Type of {llm_type} model is not set.")
@@ -127,43 +118,32 @@ class TenantLLMService(CommonService):

     @classmethod
     @DB.connection_context()
-    def model_instance(cls, tenant_id, llm_type,
-                       llm_name=None, lang="Chinese"):
+    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"):
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
                 return
-            return EmbeddingModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

         if llm_type == LLMType.RERANK:
             if model_config["llm_factory"] not in RerankModel:
                 return
-            return RerankModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
         if llm_type == LLMType.IMAGE2TEXT.value:
             if model_config["llm_factory"] not in CvModel:
                 return
-            return CvModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], lang,
-                base_url=model_config["api_base"]
-            )
+            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"])

         if llm_type == LLMType.CHAT.value:
             if model_config["llm_factory"] not in ChatModel:
                 return
-            return ChatModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

         if llm_type == LLMType.SPEECH2TEXT:
             if model_config["llm_factory"] not in Seq2txtModel:
                 return
-            return Seq2txtModel[model_config["llm_factory"]](
-                key=model_config["api_key"], model_name=model_config["llm_name"],
-                lang=lang,
-                base_url=model_config["api_base"]
-            )
+            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])

         if llm_type == LLMType.TTS:
             if model_config["llm_factory"] not in TTSModel:
                 return
@@ -187,7 +167,7 @@ class TenantLLMService(CommonService):
             LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
             LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
             LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
-            LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name
+            LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
         }

         mdlnm = llm_map.get(llm_type)
@@ -198,17 +178,13 @@ class TenantLLMService(CommonService):
         llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)

         try:
-            num = cls.model.update(
-                used_tokens=cls.model.used_tokens + used_tokens
-            ).where(
-                cls.model.tenant_id == tenant_id,
-                cls.model.llm_name == llm_name,
-                cls.model.llm_factory == llm_factory if llm_factory else True
-            ).execute()
+            num = (
+                cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
+                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
+                .execute()
+            )
         except Exception:
-            logging.exception(
-                "TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s",
-                tenant_id, llm_name)
+            logging.exception("TenantLLMService.increase_usage got exception, failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
             return 0

         return num

@@ -216,11 +192,7 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_openai_models(cls):
-        objs = cls.model.select().where(
-            (cls.model.llm_factory == "OpenAI"),
-            ~(cls.model.llm_name == "text-embedding-3-small"),
-            ~(cls.model.llm_name == "text-embedding-3-large")
-        ).dicts()
+        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()

         return list(objs)


@@ -229,79 +201,59 @@ class LLMBundle:
         self.tenant_id = tenant_id
         self.llm_type = llm_type
         self.llm_name = llm_name
-        self.mdl = TenantLLMService.model_instance(
-            tenant_id, llm_type, llm_name, lang=lang)
-        assert self.mdl, "Can't find model for {}/{}/{}".format(
-            tenant_id, llm_type, llm_name)
+        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang)
+        assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         self.max_length = model_config.get("max_tokens", 8192)

     def encode(self, texts: list):
         embeddings, used_tokens = self.mdl.encode(texts)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

         return embeddings, used_tokens

     def encode_queries(self, query: str):
         emd, used_tokens = self.mdl.encode_queries(query)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

         return emd, used_tokens

     def similarity(self, query: str, texts: list):
         sim, used_tokens = self.mdl.similarity(query, texts)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))

         return sim, used_tokens

     def describe(self, image, max_tokens=300):
         txt, used_tokens = self.mdl.describe(image, max_tokens)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))

         return txt

     def transcription(self, audio):
         txt, used_tokens = self.mdl.transcription(audio)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))

         return txt

     def tts(self, text):
         for chunk in self.mdl.tts(text):
             if isinstance(chunk, int):
-                if not TenantLLMService.increase_usage(
-                        self.tenant_id, self.llm_type, chunk, self.llm_name):
-                    logging.error(
-                        "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
+                if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
+                    logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                 return
             yield chunk

     def chat(self, system, history, gen_conf):
         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
-        if isinstance(txt, int) and not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens, self.llm_name):
-            logging.error(
-                "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
-                                                                                                           used_tokens))
+        if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+            logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
         return txt

     def chat_streamly(self, system, history, gen_conf):
         for txt in self.mdl.chat_streamly(system, history, gen_conf):
             if isinstance(txt, int):
-                if not TenantLLMService.increase_usage(
-                        self.tenant_id, self.llm_type, txt, self.llm_name):
-                    logging.error(
-                        "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
-                                                                                                                        txt))
+                if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
+                    logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
                 return

             yield txt
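Note: the reflowed model_instance() above dispatches on the factory name through per-type lookup tables (EmbeddingModel, RerankModel, CvModel, ChatModel, ...), returning None for unknown factories. A minimal sketch of that pattern, assuming a hypothetical stand-in client class (the real tables map factory names such as "Ollama" or "OpenAI" to the project's client implementations):

    # Sketch of the lookup-table dispatch used by model_instance() above.
    # _StubEmbedClient is a hypothetical stand-in for a real client class.
    class _StubEmbedClient:
        def __init__(self, api_key, model_name, base_url=""):
            self.api_key, self.model_name, self.base_url = api_key, model_name, base_url

    EmbeddingModel = {"Ollama": _StubEmbedClient}

    def embedding_instance(factory, api_key, model_name, base_url):
        if factory not in EmbeddingModel:  # unknown factory -> None, as in the diff
            return None
        return EmbeddingModel[factory](api_key, model_name, base_url=base_url)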
diff --git a/management/server/scripts/embedding_test.py b/management/server/scripts/embedding_test.py
new file mode 100644
index 0000000..477899d
--- /dev/null
+++ b/management/server/scripts/embedding_test.py
@@ -0,0 +1,110 @@
+import requests
+import time
+
+# Ollama configuration
+OLLAMA_HOST = "http://localhost:11434"  # default Ollama address
+MODEL_NAME = "bge-m3"  # embedding model under test
+TEXT_TO_EMBED = "测试文本"
+
+# Endpoint URLs and their corresponding request payload shapes
+ENDPOINTS = {
+    "api/embeddings": {
+        "url": f"{OLLAMA_HOST}/api/embeddings",  # native API path
+        "payload": {"model": MODEL_NAME, "prompt": TEXT_TO_EMBED},  # the native API uses a "prompt" field
+    },
+    "v1/embeddings": {
+        "url": f"{OLLAMA_HOST}/v1/embeddings",  # OpenAI-compatible API path
+        "payload": {"model": MODEL_NAME, "input": TEXT_TO_EMBED},  # the OpenAI-compatible API uses an "input" field
+    },
+}
+
+headers = {"Content-Type": "application/json"}
+
+
+def test_endpoint(endpoint_name, endpoint_info):
+    """Test a single endpoint and return the result."""
+    print(f"\nTesting endpoint: {endpoint_name}")
+    url = endpoint_info["url"]
+    payload = endpoint_info["payload"]
+
+    try:
+        start_time = time.time()
+        response = requests.post(url, headers=headers, json=payload)
+        response_time = time.time() - start_time
+
+        print(f"Status code: {response.status_code}")
+        print(f"Response time: {response_time:.3f}s")
+
+        try:
+            data = response.json()
+
+            # Handle the structural differences between the two responses
+            embedding = None
+            if endpoint_name == "api/embeddings":
+                embedding = data.get("embedding")  # the native API returns a top-level "embedding" field
+            elif endpoint_name == "v1/embeddings":
+                embedding = data.get("data", [{}])[0].get("embedding")  # the OpenAI-compatible API nests it in the "data" array
+
+            if embedding:
+                print(f"Embedding vector length: {len(embedding)}")
+                return {
+                    "endpoint": endpoint_name,
+                    "status_code": response.status_code,
+                    "response_time": response_time,
+                    "embedding_length": len(embedding),
+                    "embedding": embedding[:5],
+                }
+            else:
+                print("No 'embedding' field found in the response")
+                return {"endpoint": endpoint_name, "status_code": response.status_code, "error": "No embedding field in response"}
+
+        except ValueError:
+            print("Response is not valid JSON")
+            return {"endpoint": endpoint_name, "status_code": response.status_code, "error": "Invalid JSON response"}
+
+    except Exception as e:
+        print(f"Request failed: {str(e)}")
+        return {"endpoint": endpoint_name, "error": str(e)}
+
+
+def compare_endpoints():
+    """Compare the two endpoints."""
+    results = []
+
+    print("=" * 50)
+    print(f"Comparing Ollama embedding endpoints, model: {MODEL_NAME}")
+    print("=" * 50)
+
+    for endpoint_name, endpoint_info in ENDPOINTS.items():
+        results.append(test_endpoint(endpoint_name, endpoint_info))
+
+    print("\n" + "=" * 50)
+    print("Summary of results:")
+    print("=" * 50)
+
+    successful_results = [res for res in results if "embedding_length" in res]
+
+    if len(successful_results) == 2:
+        if successful_results[0]["embedding_length"] == successful_results[1]["embedding_length"]:
+            print(f"Both endpoints return the same embedding dimension: {successful_results[0]['embedding_length']}")
+        else:
+            print("The endpoints return different embedding dimensions:")
+            for result in successful_results:
+                print(f"- {result['endpoint']}: {result['embedding_length']}")
+
+        print("\nFirst 5 elements of each embedding:")
+        for result in successful_results:
+            print(f"- {result['endpoint']}: {result['embedding']}")
+
+        faster = min(successful_results, key=lambda x: x["response_time"])
+        slower = max(successful_results, key=lambda x: x["response_time"])
+        print(f"\nFaster endpoint: {faster['endpoint']} ({faster['response_time']:.3f}s vs {slower['response_time']:.3f}s)")
+    else:
+        print("At least one endpoint did not return valid embedding data")
+        for result in results:
+            if "error" in result:
+                print(f"- {result['endpoint']} error: {result['error']}")
+
+
+if __name__ == "__main__":
+    compare_endpoints()
diff --git a/management/server/scripts/ollama_test.py b/management/server/scripts/ollama_test.py
new file mode 100644
index 0000000..4e9d606
--- /dev/null
+++ b/management/server/scripts/ollama_test.py
@@ -0,0 +1,64 @@
+import numpy as np
+from abc import ABC
+from ollama import Client
+
+
+class Base(ABC):
+    def __init__(self, key, model_name):
+        pass
+
+    def encode(self, texts: list):
+        raise NotImplementedError("Please implement encode method!")
+
+    def encode_queries(self, text: str):
+        raise NotImplementedError("Please implement encode_queries method!")
+
+    def total_token_count(self, resp):
+        try:
+            return resp.usage.total_tokens
+        except Exception:
+            pass
+        try:
+            return resp["usage"]["total_tokens"]
+        except Exception:
+            pass
+        return 0
+
+
+class OllamaEmbed(Base):
+    def __init__(self, model_name, **kwargs):
+        self.client = Client(host="http://localhost:11434", **kwargs)
+        self.model_name = model_name
+
+    def encode(self, texts: list):
+        arr = []
+        tks_num = 0
+        for txt in texts:
+            res = self.client.embeddings(prompt=txt, model=self.model_name)
+            arr.append(res["embedding"])
+            tks_num += 128  # flat per-text estimate; the native endpoint reports no usage
+        return np.array(arr), tks_num
+
+
+if __name__ == "__main__":
+    # Initialize the embedding model
+    embedder = OllamaEmbed(model_name="bge-m3")
+
+    # Test texts
+    test_texts = ["测试文本"]
+
+    # Get embedding vectors and the token count
+    embeddings, total_tokens = embedder.encode(test_texts)
+
+    # Print the results
+    print(f"Total tokens used: {total_tokens}")
+    print("\nEmbedding vectors:")
+    for i, (text, embedding) in enumerate(zip(test_texts, embeddings)):
+        print(f"\nText {i + 1}: '{text}'")
+        print(f"Embedding shape: {embedding.shape}")
+        print(f"First 5 values: {embedding[:5]}")
+        print(f"Embedding dtype: {embedding.dtype}")
+
+    # Print the complete first embedding vector
+    print("\nComplete first embedding vector:")
+    print(embeddings[0])
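The two scripts above pin down the mismatch behind issue#65: Ollama's native /api/embeddings takes a "prompt" field and returns the vector at the top level, while the OpenAI-compatible /v1/embeddings takes "input" and nests the vector under data[0]["embedding"]; only the latter reports token usage, which is why OllamaEmbed.encode() falls back to a flat 128-token estimate per text. A small shape-agnostic reader, sketched on the assumption that the response body has already been parsed as JSON:

    # Normalizes the two response shapes exercised by embedding_test.py.
    def extract_embedding(data: dict):
        if "embedding" in data:  # native /api/embeddings shape
            return data["embedding"]
        if "data" in data:  # OpenAI-compatible /v1/embeddings shape
            return data["data"][0].get("embedding")
        return None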
diff --git a/management/server/scripts/siliconflow_emb_test.py b/management/server/scripts/siliconflow_emb_test.py
index 5cab2d7..b298085 100644
--- a/management/server/scripts/siliconflow_emb_test.py
+++ b/management/server/scripts/siliconflow_emb_test.py
@@ -5,23 +5,17 @@ import requests
 url = "https://api.siliconflow.cn/v1/embeddings"
 api_key = "your_api_key"  # replace with your API key

-payload = {
-    "model": "BAAI/bge-m3",
-    "input": "Silicon flow embedding online: fast, affordable, and high-quality embedding services. come try it out!"
-}
-headers = {
-    "Authorization": f"Bearer {api_key}",
-    "Content-Type": "application/json"
-}
+payload = {"model": "BAAI/bge-m3", "input": "测试文本"}
+headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

 response = requests.request("POST", url, json=payload, headers=headers)

 print(response.text)
-# print(response.text.data)
+# print(response.text.data)

 # embedding_resp = response
 # embedding_data = embedding_resp.json()
 # q_1024_vec = embedding_data["data"][0]["embedding"]
-# print("q_1024_vec", q_1024_vec)
\ No newline at end of file
+# print("q_1024_vec", q_1024_vec)
diff --git a/management/server/services/knowledgebases/document_parser.py b/management/server/services/knowledgebases/document_parser.py
index 3b92a1a..6069fac 100644
--- a/management/server/services/knowledgebases/document_parser.py
+++ b/management/server/services/knowledgebases/document_parser.py
@@ -248,21 +248,20 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
             if not embedding_api_base.startswith(("http://", "https://")):
                 embedding_api_base = "http://" + embedding_api_base

-            # --- URL concatenation tweaks (handle /v1) ---
-            endpoint_segment = "embeddings"
-            full_endpoint_path = "v1/embeddings"
             # Strip the trailing slash to simplify the suffix checks
             normalized_base_url = embedding_api_base.rstrip("/")

-            if normalized_base_url.endswith("/v1"):
-                # base_url is already of the form http://host/v1
-                embedding_url = normalized_base_url + "/" + endpoint_segment
+            # If the URL uses port 11434, assume an Ollama server and use its dedicated API
+            is_ollama = "11434" in normalized_base_url
+            if is_ollama:
+                # Ollama's native endpoint path
+                embedding_url = normalized_base_url + "/api/embeddings"
+            elif normalized_base_url.endswith("/v1"):
+                embedding_url = normalized_base_url + "/embeddings"
             elif normalized_base_url.endswith("/embeddings"):
-                # base_url is already of the form http://host/embeddings (e.g. the SiliconFlow API; no further processing needed)
                 embedding_url = normalized_base_url
             else:
-                # base_url is of another form, e.g. http://host or http://host/api
-                embedding_url = normalized_base_url + "/" + full_endpoint_path
+                embedding_url = normalized_base_url + "/v1/embeddings"

             print(f"[Parser-INFO] Using embedding config: URL='{embedding_url}', Model='{embedding_model_name}', Key={embedding_api_key}")

@@ -535,8 +534,14 @@ def perform_parse(doc_id, doc_info, file_info, embedding_config, kb_info):
                 embedding_resp.raise_for_status()
                 embedding_data = embedding_resp.json()
-                q_1024_vec = embedding_data["data"][0]["embedding"]
+
+                # Ollama's native embedding endpoint returns a different shape, so handle it separately
+                if is_ollama:
+                    q_1024_vec = embedding_data.get("embedding")
+                else:
+                    q_1024_vec = embedding_data["data"][0]["embedding"]
                 print(f"[Parser-INFO] Embedding retrieved successfully, length: {len(q_1024_vec)}")
+
+                # Check that the vector dimension is 1024
                 if len(q_1024_vec) != 1024:
                     error_msg = f"[Parser-ERROR] Embedding dimension is not 1024, got {len(q_1024_vec)}; the bge-m3 model is recommended"
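For reference, the routing that document_parser.py ends up with can be condensed into one helper. This is an illustrative sketch, not code from the patch, and it inherits the patch's heuristic: an Ollama server bound to a port other than 11434 would not be detected.

    # Illustrative consolidation of the routing added to perform_parse().
    def resolve_embedding_url(base_url: str) -> tuple[str, bool]:
        base = base_url.rstrip("/")
        is_ollama = "11434" in base  # default Ollama port (heuristic)
        if is_ollama:
            return base + "/api/embeddings", True  # Ollama's native endpoint
        if base.endswith("/v1"):
            return base + "/embeddings", False
        if base.endswith("/embeddings"):
            return base, False  # already a full embeddings URL (e.g. SiliconFlow)
        return base + "/v1/embeddings", False

    # e.g. resolve_embedding_url("http://localhost:11434")
    #   -> ("http://localhost:11434/api/embeddings", True)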