176 lines
5.8 KiB
Python
176 lines
5.8 KiB
Python
|
|
import yaml
|
|||
|
|
from typing import Optional, Dict, Any, List
|
|||
|
|
from contextlib import contextmanager
|
|||
|
|
import pymysql
|
|||
|
|
from pymysql.cursors import DictCursor
|
|||
|
|
import os
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
# Module-wide DatabaseConnection instance; populated by init_database().
db: Optional["DatabaseConnection"] = None
|
|||
|
|
|
|||
|
|
class DatabaseConnectionError(Exception):
    """Raised for database configuration, connection, or initialisation failures."""

    def __init__(self, original_error: Optional[Exception] = None):
        """Initialise the error.

        Args:
            original_error: The underlying exception, if any (optional). Its
                text is appended to the message; when that text is empty, the
                exception's class name is used instead so the message never
                ends with a dangling ": ".
        """
        self.original_error = original_error

        # Build the user-facing message.
        message = "请检查数据库连接"
        if original_error is not None:
            detail = str(original_error) or type(original_error).__name__
            message = f"{message}: {detail}"

        super().__init__(message)
        self.message = message
|
|||
|
|
|
|||
|
|
# Database connection manager
class DatabaseConnection:
    """Load database settings from a YAML file and hand out pymysql connections.

    The YAML file is read with ``yaml.safe_load``. Its ``database`` section
    holds a ``type`` key (default ``'starrocks'``) plus a sub-section named
    after that type with the connection parameters (host, port, user,
    password, database, charset).

    With ``strict_mode`` enabled, failures are raised as
    ``DatabaseConnectionError`` (chained to the original exception); otherwise
    the original exception propagates, or ``test_connection`` returns False.
    """

    def __init__(self, config_path: str = "../../config.yaml", strict_mode: bool = True):
        """Load the config file and select the settings for the configured type.

        Args:
            config_path: Absolute path, or path relative to this module's
                directory. If it does not exist, ``config.yaml`` is searched
                for upward from the current working directory.
            strict_mode: Whether failures are wrapped in DatabaseConnectionError.
        """
        self.strict_mode = strict_mode
        # Kept for backward compatibility; not read anywhere in this class.
        self._connection = None
        self.config = self._load_config(config_path)
        # `or {}` tolerates keys that are present in the YAML but left empty (None).
        database_section = self.config.get('database') or {}
        # Database type, defaulting to starrocks.
        self.db_type = database_section.get('type', 'starrocks')
        # Connection settings for the selected type.
        self.db_config = database_section.get(self.db_type) or {}

    @staticmethod
    def _resolve_config_path(config_path: str) -> Path:
        """Return the config file location for *config_path* (it may not exist).

        Absolute paths are used as given; relative paths are resolved against
        this module's directory. If that file does not exist, ``config.yaml``
        is looked up in the current working directory and up to four parents.
        If nothing is found, the primary (missing) path is returned.
        """
        if os.path.isabs(config_path):
            primary = Path(config_path)
        else:
            primary = (Path(__file__).parent / config_path).resolve()
        if primary.exists():
            return primary

        # Fall back to the project root: walk up from the CWD, at most 5 levels.
        search_dir = Path.cwd()
        for _ in range(5):
            candidate = search_dir / "config.yaml"
            if candidate.exists():
                return candidate
            search_dir = search_dir.parent
        return primary

    def _load_config(self, config_path: str) -> dict:
        """Read and parse the YAML config file.

        Returns:
            The parsed mapping; an empty file yields ``{}``.

        Raises:
            DatabaseConnectionError: in strict mode, for any failure.
            FileNotFoundError / yaml / OS errors: in non-strict mode.
        """
        try:
            yaml_path = self._resolve_config_path(config_path)
            if not yaml_path.exists():
                raise FileNotFoundError(f"配置文件未找到: {config_path}")
            with open(yaml_path, 'r', encoding='utf-8') as f:
                # safe_load returns None for an empty document; normalise it.
                return yaml.safe_load(f) or {}
        except Exception as e:
            if self.strict_mode:
                raise DatabaseConnectionError(e) from e
            raise

    def get_connection(self):
        """Open a new pymysql connection using the selected config section.

        Rows are returned as dicts (DictCursor), autocommit is on, and the
        connect timeout is 5 seconds. The caller must close the connection.

        Raises:
            DatabaseConnectionError: in strict mode, if connecting fails.
        """
        cfg = self.db_config
        try:
            return pymysql.connect(
                host=cfg.get('host', 'localhost'),
                port=cfg.get('port', 9030),
                user=cfg.get('user', 'root'),
                password=cfg.get('password', ''),
                database=cfg.get('database', 'smart_home'),
                charset=cfg.get('charset', 'utf8mb4'),
                cursorclass=DictCursor,
                autocommit=True,
                connect_timeout=5
            )
        except Exception as e:
            if self.strict_mode:
                raise DatabaseConnectionError(e) from e
            raise

    def test_connection(self) -> bool:
        """Probe the database with ``SELECT 1``.

        Returns:
            True on success. In non-strict mode, False on any failure.

        Raises:
            DatabaseConnectionError: in strict mode, on any failure. The error
                is raised once, not re-wrapped.
        """
        try:
            connection = self.get_connection()
            try:
                with connection.cursor() as cursor:
                    cursor.execute("SELECT 1")
                    result = cursor.fetchone()
            finally:
                # Always release the connection, even if the probe query fails.
                connection.close()
        except Exception as e:
            if not self.strict_mode:
                return False
            if isinstance(e, DatabaseConnectionError):
                # Already wrapped by get_connection(); don't nest it again.
                raise
            raise DatabaseConnectionError(e) from e

        if result:
            return True
        if self.strict_mode:
            raise DatabaseConnectionError()
        return False

    @contextmanager
    def get_cursor(self):
        """Yield a cursor on a fresh connection; close both afterwards."""
        connection = self.get_connection()
        try:
            # Created inside the try so the connection is closed even if this fails.
            cursor = connection.cursor()
            try:
                yield cursor
            finally:
                cursor.close()
        finally:
            connection.close()

    def execute_query(self, sql: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
        """Execute a read query and return all fetched rows (as dicts)."""
        with self.get_cursor() as cursor:
            cursor.execute(sql, params)
            return cursor.fetchall()

    def execute_update(self, sql: str, params: Optional[tuple] = None) -> int:
        """Execute an INSERT/UPDATE/DELETE statement.

        Returns the value of ``cursor.execute`` (the affected-row count in
        pymysql). Changes persist because the connection uses autocommit.
        """
        with self.get_cursor() as cursor:
            affected = cursor.execute(sql, params)
            return affected

    def execute_many(self, sql: str, params_list: List[tuple]) -> int:
        """Execute *sql* once per parameter tuple via ``cursor.executemany``.

        Returns the value of ``cursor.executemany`` (the affected-row count in
        pymysql).
        """
        with self.get_cursor() as cursor:
            affected = cursor.executemany(sql, params_list)
            return affected
|
|||
|
|
|
|||
|
|
def init_database(strict_mode: bool = True) -> DatabaseConnection:
    """Create the module-wide connection manager on first call and return it.

    The new manager is probed with ``test_connection`` before being stored.
    In strict mode a failed probe raises ``DatabaseConnectionError`` and the
    global ``db`` stays None, so a later call can retry instead of silently
    reusing an unverified instance. In non-strict mode the probe returns
    False and the instance is stored anyway. Once ``db`` is set, later calls
    return it unchanged and ignore ``strict_mode``.
    """
    global db
    if db is None:
        candidate = DatabaseConnection(strict_mode=strict_mode)
        candidate.test_connection()
        # Publish only after the probe has not raised.
        db = candidate
    return db
|
|||
|
|
|
|||
|
|
# Query helper using the module-level connection manager
def query(sql: str, params: tuple = None) -> List[Dict[str, Any]]:
    """Run a read query through ``db`` and return all fetched rows.

    Raises:
        DatabaseConnectionError: if ``db`` has not been initialised.
    """
    manager = db
    if manager is not None:
        return manager.execute_query(sql, params)
    raise DatabaseConnectionError()
|
|||
|
|
|
|||
|
|
# Update helper using the module-level connection manager
def update(sql: str, params: tuple = None) -> int:
    """Run a data-modifying statement through ``db``.

    Returns:
        The result of ``db.execute_update`` (the affected-row count).

    Raises:
        DatabaseConnectionError: if ``db`` has not been initialised.
    """
    manager = db
    if manager is not None:
        return manager.execute_update(sql, params)
    raise DatabaseConnectionError()
|
|||
|
|
|
|||
|
|
# Insert helper using the module-level connection manager
def insert(sql: str, params: tuple = None) -> int:
    """Run an INSERT statement through ``db``.

    This delegates to ``db.execute_update``, exactly like ``update``, so it
    returns the affected-row count, not the id of the inserted row.

    Raises:
        DatabaseConnectionError: if ``db`` has not been initialised.
    """
    manager = db
    if manager is not None:
        return manager.execute_update(sql, params)
    raise DatabaseConnectionError()
|
|||
|
|
|
|||
|
|
# Accessor for the configured database type
def get_db_type() -> str:
    """Return ``db.db_type`` (e.g. the default ``'starrocks'``).

    Raises:
        DatabaseConnectionError: if ``db`` has not been initialised.
    """
    if db is not None:
        return db.db_type
    raise DatabaseConnectionError()
|