feat: add model download and test scripts
Add `download_model.py` to download models from Hugging Face with resumable downloads. Add `model_test.py` to test the downloaded embedding and text-generation models and verify they work correctly.
parent 5a72b69d7f
commit 0e213aaa09
download_model.py (new file):

```python
import os

from huggingface_hub import snapshot_download

# 1. Set the mirror endpoint (faster downloads from mainland China)
# os.environ["HF_ENDPOINT"] = "https://mirrors.tuna.tsinghua.edu.cn/hugging-face/"

# 2. Define the model list (repo id + local download path)
models_to_download = [
    {
        "repo_id": "BAAI/bge-m3",  # embedding model
        "local_dir": os.path.expanduser("./models/bge-m3"),
    },
    {
        "repo_id": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",  # LLM
        "local_dir": os.path.expanduser("./models/DeepSeek-R1-1.5B"),
    },
]

# 3. Download every model in the list
for model in models_to_download:
    while True:  # retry loop; downloads resume where they left off
        try:
            print(f"Downloading model {model['repo_id']} to {model['local_dir']}")
            snapshot_download(
                repo_id=model["repo_id"],
                local_dir=model["local_dir"],
                resume_download=True,   # enable resumable downloads
                force_download=False,   # skip files that are already present
                token=None,             # replace with your token for private models
            )
            print(f"Model {model['repo_id']} downloaded!")
            break
        except Exception as e:
            print(f"Download failed: {e}, retrying...")
```
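The `while True` loop above retries forever on any error. A minimal sketch of a bounded alternative with exponential backoff (not part of the commit; `MAX_RETRIES` and the delay values are illustrative). Recent `huggingface_hub` releases resume interrupted downloads by default and have deprecated the `resume_download` flag, so the sketch omits it:

```python
import time

from huggingface_hub import snapshot_download

MAX_RETRIES = 5  # illustrative cap, tune to taste


def download_with_backoff(repo_id: str, local_dir: str) -> None:
    """Retry snapshot_download a bounded number of times with exponential backoff."""
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            snapshot_download(repo_id=repo_id, local_dir=local_dir)
            print(f"Model {repo_id} downloaded!")
            return
        except Exception as e:
            delay = 2 ** attempt  # 2s, 4s, 8s, ...
            print(f"Attempt {attempt}/{MAX_RETRIES} failed: {e}; retrying in {delay}s")
            time.sleep(delay)
    raise RuntimeError(f"Giving up on {repo_id} after {MAX_RETRIES} attempts")
```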
model_test.py (new file):

```python
from openai import OpenAI


# Test the embedding model (vllm-bge)
def test_embedding(model, text):
    """Test the embedding model."""
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="1")

    response = client.embeddings.create(
        model=model,  # a model that supports embeddings
        input=text,   # the text to embed
    )

    # Print the full embedding response (disabled by default)
    # print(f"Embedding response: {response}")

    if response and response.data:
        result = response.data[0].embedding
        print(len(result))
    else:
        print("Failed to get embedding.")


# Test the text-generation model (vllm-deepseek)
def test_chat(model, prompt):
    """Test the text-generation model."""
    client = OpenAI(base_url="http://localhost:8001/v1", api_key="1")

    response = client.completions.create(
        model=model,
        prompt=prompt,
    )

    # Print the generated text
    print(f"Chat response: {response.choices[0].text}")


def main():
    # Test the text-generation model deepseek-r1
    prompt = "Hello, how is the weather today?"
    print("Testing vllm-deepseek model for chat...")
    test_chat("deepseek-r1", prompt)

    # Test the embedding model bge-m3
    embedding_text = "I like programming, especially building AI models."
    print("\nTesting vllm-bge model for embedding...")
    test_embedding("bge-m3", embedding_text)


if __name__ == "__main__":
    main()
```
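Printing the vector length only confirms the embedding endpoint responds. A natural follow-up is comparing two embeddings; the sketch below (not part of the commit) assumes the same `bge-m3` server on port 8000 and sends a batched request:

```python
import math

from openai import OpenAI


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


client = OpenAI(base_url="http://localhost:8000/v1", api_key="1")
vecs = client.embeddings.create(
    model="bge-m3",
    input=["I like programming.", "Writing code is fun."],  # batched request
)
a, b = vecs.data[0].embedding, vecs.data[1].embedding
print(f"cosine similarity: {cosine_similarity(a, b):.4f}")
```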
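`test_chat` uses the legacy `/v1/completions` endpoint, which sends the prompt as raw text. Since DeepSeek-R1-Distill-Qwen-1.5B is an instruction-tuned model, the chat endpoint lets the server apply the model's chat template instead. A sketch under the same assumptions as the script (server on port 8001, served model name `deepseek-r1`):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8001/v1", api_key="1")

# Same question as in main(), but sent through the chat API so the
# server formats it with the model's chat template.
response = client.chat.completions.create(
    model="deepseek-r1",
    messages=[{"role": "user", "content": "Hello, how is the weather today?"}],
)
print(response.choices[0].message.content)
```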