Files
moss-ai/app/backend-python/database.py
雷雨 8635b84b2d init
2025-12-15 22:05:56 +08:00

176 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import yaml
from typing import Optional, Dict, Any, List
from contextlib import contextmanager
import pymysql
from pymysql.cursors import DictCursor
import os
from pathlib import Path
# Module-wide database connection instance. It starts as None and is
# assigned by init_database(); the module-level helpers below raise
# DatabaseConnectionError while it is still None.
db = None
class DatabaseConnectionError(Exception):
    """Error raised when the database cannot be reached or configured.

    Attributes:
        original_error: The underlying exception, or None when there is none.
        message: The final message text, also passed to ``Exception``.
    """

    def __init__(self, original_error: Exception = None):
        """Build the error message, appending the original error when given.

        Args:
            original_error: Optional exception that triggered this error.
        """
        base_text = "请检查数据库连接"
        # Only a truthy original error is appended to the base message.
        full_text = f"{base_text}: {original_error}" if original_error else base_text
        super().__init__(full_text)
        self.original_error = original_error
        self.message = full_text
# Database connection manager.
class DatabaseConnection:
    """Open short-lived pymysql connections using settings from a YAML file.

    Settings are read from ``database.<type>`` in the config, where
    ``database.type`` defaults to ``'starrocks'``. Every helper opens a fresh
    connection and closes it when done; no connection is kept or pooled.

    When ``strict_mode`` is true, configuration and connection failures are
    raised as ``DatabaseConnectionError``; otherwise the original exception
    propagates (and ``test_connection`` returns ``False`` instead of raising).
    """

    def __init__(self, config_path: str = "../../config.yaml", strict_mode: bool = True):
        """Load the configuration and select the settings for the active type.

        Args:
            config_path: Absolute path, or a path relative to this file.
            strict_mode: Wrap failures in ``DatabaseConnectionError`` when true.

        Raises:
            DatabaseConnectionError: In strict mode, if the config cannot be
                loaded. In non-strict mode the underlying exception is raised.
        """
        # strict_mode must be set first: _load_config reads it on failure.
        self.strict_mode = strict_mode
        self._connection = None
        self.config = self._load_config(config_path)
        # ``or {}`` also covers a section that is present but null in the YAML.
        database_section = self.config.get('database') or {}
        # Database type, defaulting to starrocks.
        self.db_type = database_section.get('type', 'starrocks')
        # Settings block matching the selected type.
        self.db_config = database_section.get(self.db_type) or {}

    def _resolve_config_path(self, config_path: str) -> Path:
        """Locate the config file, falling back to a search above the cwd.

        Raises:
            FileNotFoundError: If no candidate file exists.
        """
        if os.path.isabs(config_path):
            yaml_path = Path(config_path)
        else:
            # Relative paths are resolved against this file's directory.
            yaml_path = (Path(__file__).parent / config_path).resolve()
        if yaml_path.exists():
            return yaml_path
        # Fallback: look for "config.yaml" in the working directory and up to
        # four of its ancestors (five directories in total).
        search_dir = Path.cwd()
        for _ in range(5):
            candidate = search_dir / "config.yaml"
            if candidate.exists():
                return candidate
            search_dir = search_dir.parent
        raise FileNotFoundError(f"配置文件未找到: {config_path}")

    def _load_config(self, config_path: str) -> dict:
        """Read and parse the YAML configuration file.

        Returns:
            The parsed mapping; an empty dict when the file is empty.

        Raises:
            DatabaseConnectionError: In strict mode, for any failure.
            Exception: In non-strict mode, the underlying error unchanged.
        """
        try:
            yaml_path = self._resolve_config_path(config_path)
            with open(yaml_path, 'r', encoding='utf-8') as f:
                # safe_load yields None for an empty document.
                loaded = yaml.safe_load(f) or {}
            if not isinstance(loaded, dict):
                raise ValueError(f"配置文件格式错误: {yaml_path}")
            return loaded
        except Exception as e:
            if self.strict_mode:
                raise DatabaseConnectionError(e) from e
            raise

    def get_connection(self):
        """Open a new autocommit pymysql connection that uses ``DictCursor``.

        Missing settings fall back to the defaults written in the call below.

        Raises:
            DatabaseConnectionError: In strict mode, if connecting fails.
            Exception: In non-strict mode, the underlying error unchanged.
        """
        try:
            return pymysql.connect(
                host=self.db_config.get('host', 'localhost'),
                port=self.db_config.get('port', 9030),
                user=self.db_config.get('user', 'root'),
                password=self.db_config.get('password', ''),
                database=self.db_config.get('database', 'smart_home'),
                charset=self.db_config.get('charset', 'utf8mb4'),
                cursorclass=DictCursor,
                autocommit=True,
                connect_timeout=5
            )
        except Exception as e:
            if self.strict_mode:
                raise DatabaseConnectionError(e) from e
            raise

    def test_connection(self) -> bool:
        """Check connectivity by running ``SELECT 1`` on a fresh connection.

        Returns:
            True when the probe returns a row. In non-strict mode, False on
            any failure.

        Raises:
            DatabaseConnectionError: In strict mode, when the probe fails.
        """
        try:
            connection = self.get_connection()
            try:
                with connection.cursor() as cursor:
                    cursor.execute("SELECT 1")
                    result = cursor.fetchone()
            finally:
                # Close even when the probe query raises.
                connection.close()
            if result:
                return True
            if self.strict_mode:
                raise DatabaseConnectionError()
            return False
        except Exception as e:
            if not self.strict_mode:
                return False
            # Already wrapped (e.g. by get_connection): re-raise unchanged so
            # the message prefix is not duplicated.
            if isinstance(e, DatabaseConnectionError):
                raise
            raise DatabaseConnectionError(e) from e

    @contextmanager
    def get_cursor(self):
        """Yield a cursor on a fresh connection; close both on exit."""
        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            try:
                yield cursor
            finally:
                cursor.close()
        finally:
            # Runs even if creating or closing the cursor fails.
            connection.close()

    def execute_query(self, sql: str, params: tuple = None) -> List[Dict[str, Any]]:
        """Execute a query and return all fetched rows as a list."""
        with self.get_cursor() as cursor:
            cursor.execute(sql, params)
            # Normalise to a list; the driver may hand back a tuple when the
            # result is empty.
            return list(cursor.fetchall())

    def execute_update(self, sql: str, params: tuple = None) -> int:
        """Execute an INSERT/UPDATE/DELETE statement.

        Returns:
            The value returned by ``cursor.execute`` (the affected-row count).
        """
        with self.get_cursor() as cursor:
            return cursor.execute(sql, params)

    def execute_many(self, sql: str, params_list: List[tuple]) -> int:
        """Execute one statement for every parameter tuple in ``params_list``.

        Returns:
            The value returned by ``cursor.executemany``, or 0 when the
            driver returns None (pymysql does so for an empty list).
        """
        with self.get_cursor() as cursor:
            affected = cursor.executemany(sql, params_list)
            return affected or 0
def init_database(strict_mode: bool = True) -> DatabaseConnection:
    """Create the module-wide connection on first use and return it.

    On the first call a ``DatabaseConnection`` is built, stored in the
    module-level ``db`` and probed with ``test_connection()``. Later calls
    return the stored instance without probing again.

    Args:
        strict_mode: Passed to ``DatabaseConnection``; only used by the call
            that actually creates the instance.

    Returns:
        The shared ``DatabaseConnection`` instance.
    """
    global db
    if db is not None:
        return db
    # The global is assigned before the probe, so it stays set even if the
    # probe raises.
    db = DatabaseConnection(strict_mode=strict_mode)
    db.test_connection()
    return db
# Query
def query(sql: str, params: tuple = None) -> List[Dict[str, Any]]:
    """Run a query through the shared connection and return its rows.

    Args:
        sql: Statement text passed to ``execute_query``.
        params: Optional parameters for the statement.

    Raises:
        DatabaseConnectionError: If the module-level ``db`` is still None.
    """
    if db is not None:
        return db.execute_query(sql, params)
    raise DatabaseConnectionError()
# Update
def update(sql: str, params: tuple = None) -> int:
    """Run an update statement through the shared connection.

    Args:
        sql: Statement text passed to ``execute_update``.
        params: Optional parameters for the statement.

    Returns:
        Whatever ``execute_update`` returns for the statement.

    Raises:
        DatabaseConnectionError: If the module-level ``db`` is still None.
    """
    if db is not None:
        return db.execute_update(sql, params)
    raise DatabaseConnectionError()
# Insert
def insert(sql: str, params: tuple = None) -> int:
    """Run an insert statement through the shared connection.

    Delegates to ``execute_update``, exactly as ``update`` does.

    Args:
        sql: Statement text passed to ``execute_update``.
        params: Optional parameters for the statement.

    Returns:
        Whatever ``execute_update`` returns for the statement.

    Raises:
        DatabaseConnectionError: If the module-level ``db`` is still None.
    """
    if db is not None:
        return db.execute_update(sql, params)
    raise DatabaseConnectionError()
# Get the database type
def get_db_type() -> str:
    """Return the ``db_type`` of the shared connection.

    Raises:
        DatabaseConnectionError: If the module-level ``db`` is still None.
    """
    if db is not None:
        return db.db_type
    raise DatabaseConnectionError()