# Pasted file-viewer header: 111 lines, 4.1 KiB, Python
import requests
|
||
import time
|
||
|
||
# Ollama connection settings
OLLAMA_HOST = "http://localhost:11434"  # default local Ollama address
MODEL_NAME = "bge-m3"  # embedding model requested from every endpoint
TEXT_TO_EMBED = "测试文本"

# Endpoint label -> {"url": ..., "payload": ...}.
# The two APIs differ only in the name of the field that carries the text:
# the native API (/api/embeddings) uses "prompt", while the OpenAI-compatible
# API (/v1/embeddings) uses "input".
# NOTE: Ollama's API docs mark /api/embeddings as superseded by /api/embed;
# it is kept here because it is one of the two endpoints being compared.
ENDPOINTS = {
    route: {
        "url": f"{OLLAMA_HOST}/{route}",
        "payload": {"model": MODEL_NAME, text_field: TEXT_TO_EMBED},
    }
    for route, text_field in (
        ("api/embeddings", "prompt"),
        ("v1/embeddings", "input"),
    )
}

# Request headers shared by every POST in this script.
headers = {"Content-Type": "application/json"}


def test_endpoint(endpoint_name, endpoint_info, timeout=60):
    """Send one embedding request to an endpoint and summarize the outcome.

    Args:
        endpoint_name: Label for the endpoint, used in printed output and in
            the returned dict.
        endpoint_info: Mapping with a "url" to POST to and a JSON-serializable
            "payload" sent as the request body.
        timeout: Seconds to wait for the server (forwarded to
            ``requests.post``). Without it, an unresponsive server would block
            the script indefinitely.

    Returns:
        On success: a dict with "endpoint", "status_code", "response_time"
        (seconds), "embedding_length" and "embedding" (first 5 values).
        On failure: a dict with "endpoint" and "error", plus "status_code"
        whenever an HTTP response was received.
    """
    print(f"\n测试接口: {endpoint_name}")
    url = endpoint_info["url"]
    payload = endpoint_info["payload"]

    try:
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock, which can be adjusted mid-measurement.
        start_time = time.perf_counter()
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
        response_time = time.perf_counter() - start_time
    except requests.RequestException as e:
        print(f"请求失败: {str(e)}")
        return {"endpoint": endpoint_name, "error": str(e)}

    print(f"状态码: {response.status_code}")
    print(f"响应时间: {response_time:.3f}秒")

    if not response.ok:
        # Report the server's own message (e.g. an unknown model) instead of
        # the misleading "no embedding field" result.
        error = f"HTTP {response.status_code}: {response.text.strip()[:200]}"
        print(f"请求失败: {error}")
        return {"endpoint": endpoint_name, "status_code": response.status_code, "error": error}

    try:
        data = response.json()
    except ValueError:
        print("响应不是有效的JSON格式")
        return {"endpoint": endpoint_name, "status_code": response.status_code, "error": "Invalid JSON response"}

    embedding = _extract_embedding(data)
    if embedding:
        print(f"Embedding向量长度: {len(embedding)}")
        return {
            "endpoint": endpoint_name,
            "status_code": response.status_code,
            "response_time": response_time,
            "embedding_length": len(embedding),
            "embedding": embedding[:5],
        }

    print("响应中未找到'embedding'字段")
    return {"endpoint": endpoint_name, "status_code": response.status_code, "error": "No embedding field in response"}


def _extract_embedding(data):
    """Return the embedding list from a native or OpenAI-style body, else None."""
    if not isinstance(data, dict):
        return None
    if "embedding" in data:
        # Native API shape: {"embedding": [...]}
        vector = data["embedding"]
    else:
        # OpenAI-compatible shape: {"data": [{"embedding": [...]}, ...]}
        items = data.get("data")
        if not (isinstance(items, list) and items and isinstance(items[0], dict)):
            return None
        vector = items[0].get("embedding")
    # Only a list is a usable vector; anything else is treated as missing.
    return vector if isinstance(vector, list) else None


def compare_endpoints(warmup=True, warmup_timeout=120):
    """Test every endpoint in ENDPOINTS and print a comparison summary.

    Args:
        warmup: Send one untimed request before measuring. Ollama loads a
            model into memory on first use, so without a warm-up the first
            endpoint tested also pays the model-load time and the latency
            comparison is biased against it.
        warmup_timeout: Seconds to wait for the warm-up request; generous
            because it may include loading the model.

    Returns:
        The list of per-endpoint result dicts from test_endpoint, in
        ENDPOINTS order.
    """
    print("=" * 50)
    print(f"开始比较Ollama的embeddings接口,使用模型: {MODEL_NAME}")
    print("=" * 50)

    if warmup and ENDPOINTS:
        # A single request suffices as long as every endpoint requests the
        # same model (all payloads in ENDPOINTS use MODEL_NAME).
        _warm_up(next(iter(ENDPOINTS.values())), warmup_timeout)

    results = [test_endpoint(name, info) for name, info in ENDPOINTS.items()]
    _print_summary(results)
    return results


def _warm_up(endpoint_info, timeout):
    """Send one untimed request so the model is resident before timing starts."""
    print(f"\n预热模型: {MODEL_NAME}")
    try:
        requests.post(endpoint_info["url"], headers=headers, json=endpoint_info["payload"], timeout=timeout)
    except requests.RequestException as e:
        # Best-effort only: the timed tests will report any connection problem.
        print(f"预热请求失败: {e}")


def _print_summary(results):
    """Print dimension, sample-value and latency comparisons of the results."""
    print("\n" + "=" * 50)
    print("比较结果摘要:")
    print("=" * 50)

    successful_results = [res for res in results if "embedding_length" in res]

    if len(successful_results) == 2:
        if successful_results[0]["embedding_length"] == successful_results[1]["embedding_length"]:
            print(f"两个接口返回的embedding维度相同: {successful_results[0]['embedding_length']}")
        else:
            print("两个接口返回的embedding维度不同:")
            for result in successful_results:
                print(f"- {result['endpoint']}: {result['embedding_length']}")

        print("\nEmbedding前5个元素示例:")
        for result in successful_results:
            print(f"- {result['endpoint']}: {result['embedding']}")

        faster = min(successful_results, key=lambda x: x["response_time"])
        slower = max(successful_results, key=lambda x: x["response_time"])
        print(f"\n更快的接口: {faster['endpoint']} ({faster['response_time']:.3f}秒 vs {slower['response_time']:.3f}秒)")
    else:
        print("至少有一个接口未返回有效的embedding数据")
        for result in results:
            if "error" in result:
                print(f"- {result['endpoint']} 错误: {result['error']}")


if __name__ == "__main__":
    # Entry point: run the comparison only when executed as a script, so
    # importing this module has no network side effects.
    compare_endpoints()