refactor(database): 重新调整docker容器的连接配置

This commit is contained in:
zstar 2025-05-17 15:55:37 +08:00
parent fd7f1140cd
commit 84603765cb
1 changed files with 42 additions and 33 deletions

View File

@ -1,34 +1,45 @@
import mysql.connector import mysql.connector
import os import os
from utils import generate_uuid, encrypt_password
from datetime import datetime
from minio import Minio from minio import Minio
from dotenv import load_dotenv from dotenv import load_dotenv
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
# 加载环境变量 # 加载环境变量
load_dotenv("../../docker/.env") load_dotenv("../../docker/.env")
# 检测是否在Docker容器中运行 # 检测是否在Docker容器中运行
def is_running_in_docker(): def is_running_in_docker():
# 检查是否存在/.dockerenv文件 # 检查是否存在/.dockerenv文件
docker_env = os.path.exists('/.dockerenv') docker_env = os.path.exists("/.dockerenv")
# 或者检查cgroup中是否包含docker字符串 # 或者检查cgroup中是否包含docker字符串
try: try:
with open('/proc/self/cgroup', 'r') as f: with open("/proc/self/cgroup", "r") as f:
return docker_env or 'docker' in f.read() return docker_env or "docker" in f.read()
except: except: # noqa: E722
return docker_env return docker_env
# 根据运行环境选择合适的主机地址
DB_HOST = 'host.docker.internal' if is_running_in_docker() else 'localhost' # 根据运行环境选择合适的主机地址和端口
MINIO_HOST = 'host.docker.internal' if is_running_in_docker() else 'localhost' if is_running_in_docker():
ES_HOST = 'es01' if is_running_in_docker() else 'localhost' MYSQL_HOST = "mysql"
MYSQL_PORT = 3306
MINIO_HOST = "minio"
MINIO_PORT = 9000
ES_HOST = "es01"
ES_PORT = 9200
else:
MYSQL_HOST = "host.docker.internal"
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "5455"))
MINIO_HOST = "host.docker.internal"
MINIO_PORT = int(os.getenv("MINIO_PORT", "9000"))
ES_HOST = "host.docker.internal"
ES_PORT = int(os.getenv("ES_PORT", "9200"))
# 数据库连接配置 # 数据库连接配置
DB_CONFIG = { DB_CONFIG = {
"host": DB_HOST, "host": MYSQL_HOST,
"port": int(os.getenv("MYSQL_PORT", "5455")), "port": MYSQL_PORT,
"user": "root", "user": "root",
"password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"), "password": os.getenv("MYSQL_PASSWORD", "infini_rag_flow"),
"database": "rag_flow", "database": "rag_flow",
@ -36,20 +47,21 @@ DB_CONFIG = {
# MinIO连接配置 # MinIO连接配置
MINIO_CONFIG = { MINIO_CONFIG = {
"endpoint": f"{MINIO_HOST}:{os.getenv('MINIO_PORT', '9000')}", "endpoint": f"{MINIO_HOST}:{MINIO_PORT}",
"access_key": os.getenv("MINIO_USER", "rag_flow"), "access_key": os.getenv("MINIO_USER", "rag_flow"),
"secret_key": os.getenv("MINIO_PASSWORD", "infini_rag_flow"), "secret_key": os.getenv("MINIO_PASSWORD", "infini_rag_flow"),
"secure": False "secure": False,
} }
# Elasticsearch连接配置 # Elasticsearch连接配置
ES_CONFIG = { ES_CONFIG = {
"host": f"http://{ES_HOST}:{os.getenv('ES_PORT', '9200')}", "host": f"http://{ES_HOST}:{ES_PORT}",
"user": os.getenv("ELASTIC_USER", "elastic"), "user": os.getenv("ELASTIC_USER", "elastic"),
"password": os.getenv("ELASTIC_PASSWORD", "infini_rag_flow"), "password": os.getenv("ELASTIC_PASSWORD", "infini_rag_flow"),
"use_ssl": os.getenv("ES_USE_SSL", "false").lower() == "true" "use_ssl": os.getenv("ES_USE_SSL", "false").lower() == "true",
} }
def get_db_connection(): def get_db_connection():
"""创建MySQL数据库连接""" """创建MySQL数据库连接"""
try: try:
@ -59,43 +71,39 @@ def get_db_connection():
print(f"MySQL连接失败: {str(e)}") print(f"MySQL连接失败: {str(e)}")
raise e raise e
def get_minio_client(): def get_minio_client():
"""创建MinIO客户端连接""" """创建MinIO客户端连接"""
try: try:
minio_client = Minio( minio_client = Minio(endpoint=MINIO_CONFIG["endpoint"], access_key=MINIO_CONFIG["access_key"], secret_key=MINIO_CONFIG["secret_key"], secure=MINIO_CONFIG["secure"])
endpoint=MINIO_CONFIG["endpoint"],
access_key=MINIO_CONFIG["access_key"],
secret_key=MINIO_CONFIG["secret_key"],
secure=MINIO_CONFIG["secure"]
)
return minio_client return minio_client
except Exception as e: except Exception as e:
print(f"MinIO连接失败: {str(e)}") print(f"MinIO连接失败: {str(e)}")
raise e raise e
def get_es_client(): def get_es_client():
"""创建Elasticsearch客户端连接""" """创建Elasticsearch客户端连接"""
try: try:
# 构建连接参数 # 构建连接参数
es_params = { es_params = {"hosts": [ES_CONFIG["host"]]}
"hosts": [ES_CONFIG["host"]]
}
# 如果提供了用户名和密码,添加认证信息 # 如果提供了用户名和密码,添加认证信息
if ES_CONFIG["user"] and ES_CONFIG["password"]: if ES_CONFIG["user"] and ES_CONFIG["password"]:
es_params["basic_auth"] = (ES_CONFIG["user"], ES_CONFIG["password"]) es_params["basic_auth"] = (ES_CONFIG["user"], ES_CONFIG["password"])
# 如果需要SSL添加SSL配置 # 如果需要SSL添加SSL配置
if ES_CONFIG["use_ssl"]: if ES_CONFIG["use_ssl"]:
es_params["use_ssl"] = True es_params["use_ssl"] = True
es_params["verify_certs"] = False # 在开发环境中可以设置为False生产环境应该设置为True es_params["verify_certs"] = False # 在开发环境中可以设置为False生产环境应该设置为True
es_client = Elasticsearch(**es_params) es_client = Elasticsearch(**es_params)
return es_client return es_client
except Exception as e: except Exception as e:
print(f"Elasticsearch连接失败: {str(e)}") print(f"Elasticsearch连接失败: {str(e)}")
raise e raise e
def test_connections(): def test_connections():
"""测试数据库和MinIO连接""" """测试数据库和MinIO连接"""
try: try:
@ -107,12 +115,12 @@ def test_connections():
cursor.close() cursor.close()
db_conn.close() db_conn.close()
print("MySQL连接测试成功") print("MySQL连接测试成功")
# 测试MinIO连接 # 测试MinIO连接
minio_client = get_minio_client() minio_client = get_minio_client()
buckets = minio_client.list_buckets() buckets = minio_client.list_buckets()
print(f"MinIO连接测试成功共有 {len(buckets)} 个存储桶") print(f"MinIO连接测试成功共有 {len(buckets)} 个存储桶")
# 测试Elasticsearch连接 # 测试Elasticsearch连接
try: try:
es_client = get_es_client() es_client = get_es_client()
@ -120,12 +128,13 @@ def test_connections():
print(f"Elasticsearch连接测试成功版本: {es_info.get('version', {}).get('number', '未知')}") print(f"Elasticsearch连接测试成功版本: {es_info.get('version', {}).get('number', '未知')}")
except Exception as e: except Exception as e:
print(f"Elasticsearch连接测试失败: {str(e)}") print(f"Elasticsearch连接测试失败: {str(e)}")
return True return True
except Exception as e: except Exception as e:
print(f"连接测试失败: {str(e)}") print(f"连接测试失败: {str(e)}")
return False return False
if __name__ == "__main__": if __name__ == "__main__":
# 如果直接运行此文件,则测试连接 # 如果直接运行此文件,则测试连接
test_connections() test_connections()