Files
moss-ai/agents/config_loader.py
雷雨 8635b84b2d init
2025-12-15 22:05:56 +08:00

347 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys
import os
import yaml
import pymysql
from pymysql.cursors import DictCursor
import logging
from typing import Optional, Dict, Any
from pathlib import Path
# Module logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(__name__)
class DatabaseConnectionError(Exception):
    """Raised when a connection to the configuration database cannot be opened."""
class ConfigLoadError(Exception):
    """Raised when a required configuration record is missing or cannot be read."""
class AgentConfigLoader:
    """Agent configuration loader.

    Static settings (database connection, agent host/port, backend and logging
    sections) come from ``config.yaml``. Dynamic settings (AI model, agent
    prompts, devices, Xiaomi account) are queried from the database described
    by the ``database.starrocks`` section of that file.
    """

    def __init__(self, config_path: Optional[str] = None, strict_mode: bool = True):
        """Initialise the loader.

        Args:
            config_path: Path to the YAML config file. When ``None`` the file is
                located automatically with ``_find_config_file``.
            strict_mode: When ``True``, ``get_default_ai_model_config`` and
                ``get_agent_prompt`` raise on failure instead of returning
                ``None``. The other database getters always return ``None``
                on failure.

        Raises:
            FileNotFoundError: ``config_path`` is ``None`` and no
                ``config.yaml`` could be found.
        """
        if config_path is None:
            config_path = self._find_config_file()
        self.config = self._load_yaml_config(config_path)
        # ``or {}``: a section written as ``key:`` with no value parses to None,
        # which has no ``.get``; treat it the same as an absent section.
        database_section = self.config.get('database') or {}
        self.db_config = database_section.get('starrocks') or {}
        self.agents_config = self.config.get('agents') or {}
        self.backend_config = self.config.get('backend') or {}
        self.logging_config = self.config.get('logging') or {}
        self.strict_mode = strict_mode
        # Flipped to True once a database connection has been opened successfully.
        self._connection_tested = False

    def _find_config_file(self) -> str:
        """Locate ``config.yaml`` by walking up from this module's directory.

        The module's own directory and its two nearest ancestors are checked,
        nearest first.

        Returns:
            Path of the first ``config.yaml`` found, as a string.

        Raises:
            FileNotFoundError: No searched directory contains ``config.yaml``.
        """
        search_dir = Path(__file__).parent
        # Module directory + two ancestors; this already includes the project
        # root (the module directory's parent), so no separate fallback probe.
        for _ in range(3):
            candidate = search_dir / "config.yaml"
            if candidate.exists():
                logger.info(f"找到配置文件: {candidate}")
                return str(candidate)
            search_dir = search_dir.parent
        raise FileNotFoundError("未找到 config.yaml 配置文件")

    def _load_yaml_config(self, config_path: str) -> dict:
        """Parse the YAML file at ``config_path`` into a dict.

        Best-effort: read or parse errors are logged and yield ``{}``. An empty
        file, or a document whose top level is not a mapping, also yields
        ``{}`` so callers can always call ``.get`` on the result.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                data = yaml.safe_load(f)
        except Exception as e:
            logger.error(f"加载配置文件失败: {e}")
            return {}
        if data is None:
            # safe_load returns None for an empty document.
            return {}
        if not isinstance(data, dict):
            logger.error(f"加载配置文件失败: 顶层内容不是字典 ({type(data).__name__})")
            return {}
        return data

    def _get_db_connection(self):
        """Open a new connection to the configured database.

        Connection settings come from ``database.starrocks`` in config.yaml,
        with a default for each missing key. The optional ``connect_timeout``
        key (seconds) defaults to 5.

        Returns:
            A new pymysql connection whose cursors return rows as dicts
            (``DictCursor``). The caller is responsible for closing it.

        Raises:
            DatabaseConnectionError: The connection failed and ``strict_mode``
                is on.
            Exception: The original driver error when ``strict_mode`` is off.
        """
        try:
            connection = pymysql.connect(
                host=self.db_config.get('host', 'localhost'),
                port=self.db_config.get('port', 9030),
                user=self.db_config.get('user', 'root'),
                password=self.db_config.get('password', ''),
                database=self.db_config.get('database', 'smart_home'),
                charset=self.db_config.get('charset', 'utf8mb4'),
                cursorclass=DictCursor,
                connect_timeout=self.db_config.get('connect_timeout', 5),
            )
        except Exception as e:
            error_msg = f"数据库连接失败: {e}"
            logger.error(error_msg)
            if self.strict_mode:
                raise DatabaseConnectionError(error_msg) from e
            raise
        self._connection_tested = True
        return connection

    def _fetch_one(self, sql: str, params: Optional[tuple] = None) -> Optional[Dict[str, Any]]:
        """Run ``sql`` on a fresh connection and return its first row, or None.

        The cursor and the connection are both closed in ``finally`` blocks,
        so a query that raises does not leak a database connection.
        """
        conn = self._get_db_connection()
        try:
            cursor = conn.cursor()
            try:
                # params=None leaves the SQL untouched, exactly like execute(sql).
                cursor.execute(sql, params)
                return cursor.fetchone()
            finally:
                cursor.close()
        finally:
            conn.close()

    def get_default_ai_model_config(self) -> Optional[Dict[str, Any]]:
        """Load the default, active AI model settings from ``ai_model_config``.

        Returns:
            A dict with keys ``model``, ``api_key``, ``api_base``,
            ``temperature`` (converted to float) and ``max_tokens``. Returns
            ``None`` when ``strict_mode`` is off and no row is found or loading
            fails.

        Raises:
            DatabaseConnectionError: The connection failed in strict mode.
            ConfigLoadError: In strict mode, no default row exists, or the query
                or value conversion failed.
        """
        sql = """
            SELECT model_name, provider, api_key, api_base, temperature, max_tokens
            FROM ai_model_config
            WHERE is_default = TRUE AND is_active = TRUE
            LIMIT 1
        """
        try:
            result = self._fetch_one(sql)
            if not result:
                error_msg = "数据库中未找到默认AI模型配置"
                logger.error(error_msg)
                if self.strict_mode:
                    raise ConfigLoadError(error_msg)
                return None
            logger.info(f"成功从数据库加载AI模型配置: {result['model_name']}")
            return {
                'model': result['model_name'],
                'api_key': result['api_key'],
                'api_base': result['api_base'],
                'temperature': float(result['temperature']),
                'max_tokens': result['max_tokens'],
            }
        except (DatabaseConnectionError, ConfigLoadError):
            # Already logged and typed; let them through unchanged.
            raise
        except Exception as e:
            error_msg = f"获取AI模型配置失败: {e}"
            logger.error(error_msg)
            if self.strict_mode:
                raise ConfigLoadError(error_msg) from e
            return None

    def get_agent_config(self, agent_code: str) -> Optional[Dict[str, Any]]:
        """Load the enabled ``agent_config`` row for ``agent_code``.

        Args:
            agent_code: Agent identifier.

        Returns:
            Row dict with ``agent_code``, ``agent_name``, ``host``, ``port``,
            ``description`` and ``is_enabled``; ``None`` if no enabled row
            matches.

        NOTE(review): any error (including a connection failure) is logged and
        returns ``None`` regardless of ``strict_mode``.
        """
        sql = """
            SELECT agent_code, agent_name, host, port, description, is_enabled
            FROM agent_config
            WHERE agent_code = %s AND is_enabled = TRUE
            LIMIT 1
        """
        try:
            return self._fetch_one(sql, (agent_code,))
        except Exception as e:
            logger.error(f"获取Agent配置失败: {e}")
            return None

    def get_agent_prompt(self, agent_code: str) -> Optional[str]:
        """Load the newest active system prompt for ``agent_code``.

        The newest row is the one with the highest ``id`` among active rows.

        Args:
            agent_code: Agent identifier.

        Returns:
            The prompt text, or ``None`` when ``strict_mode`` is off and no
            prompt is found or loading fails.

        Raises:
            DatabaseConnectionError: The connection failed in strict mode.
            ConfigLoadError: In strict mode, no prompt exists or the query failed.
        """
        sql = """
            SELECT prompt_text
            FROM agent_prompt
            WHERE agent_code = %s AND is_active = TRUE
            ORDER BY id DESC
            LIMIT 1
        """
        try:
            result = self._fetch_one(sql, (agent_code,))
            if not result:
                error_msg = f"数据库中未找到 {agent_code} 的系统提示词"
                logger.error(error_msg)
                if self.strict_mode:
                    raise ConfigLoadError(error_msg)
                return None
            logger.info(f"成功从数据库加载 {agent_code} 的系统提示词")
            return result['prompt_text']
        except (DatabaseConnectionError, ConfigLoadError):
            raise
        except Exception as e:
            error_msg = f"获取Agent提示词失败: {e}"
            logger.error(error_msg)
            if self.strict_mode:
                raise ConfigLoadError(error_msg) from e
            return None

    def get_device_config(self, device_code: str) -> Optional[Dict[str, Any]]:
        """Load the active ``device_config`` row for ``device_code``.

        Args:
            device_code: Device identifier.

        Returns:
            Row dict with ``device_code``, ``device_name``, ``device_type``,
            ``agent_code``, ``ip_address``, ``token``, ``model``,
            ``extra_config`` and ``is_active``; ``None`` if no active row
            matches.

        NOTE(review): any error is logged and returns ``None`` regardless of
        ``strict_mode``.
        """
        sql = """
            SELECT device_code, device_name, device_type, agent_code,
                   ip_address, token, model, extra_config, is_active
            FROM device_config
            WHERE device_code = %s AND is_active = TRUE
            LIMIT 1
        """
        try:
            return self._fetch_one(sql, (device_code,))
        except Exception as e:
            logger.error(f"获取设备配置失败: {e}")
            return None

    def get_xiaomi_account(self) -> Optional[Dict[str, Any]]:
        """Load the default, active Xiaomi account from ``xiaomi_account``.

        Returns:
            Row dict with ``username``, ``password`` and ``region`` (contains
            credentials, so avoid logging it); ``None`` if no such row exists.

        NOTE(review): any error is logged and returns ``None`` regardless of
        ``strict_mode``.
        """
        sql = """
            SELECT username, password, region
            FROM xiaomi_account
            WHERE is_default = TRUE AND is_active = TRUE
            LIMIT 1
        """
        try:
            return self._fetch_one(sql)
        except Exception as e:
            logger.error(f"获取小米账号配置失败: {e}")
            return None

    # ==================== Accessors for config.yaml settings ====================

    def get_agent_host_port(self, agent_name: str) -> tuple[str, int]:
        """Return ``(host, port)`` for ``agent_name`` from the ``agents`` section.

        Args:
            agent_name: Agent key in config.yaml (e.g. ``'conductor'``,
                ``'air_conditioner'``).

        Returns:
            ``(host, port)``. Defaults are ``'localhost'`` and ``12000``; they
            also apply when the agent's section is missing or empty.
        """
        # An agent key with no children parses to None; treat it as empty.
        agent_cfg = self.agents_config.get(agent_name) or {}
        host = agent_cfg.get('host', 'localhost')
        port = agent_cfg.get('port', 12000)
        return host, port

    def get_backend_config_value(self, key: str, default: Any = None) -> Any:
        """Look up a value in the ``backend`` section by dotted path.

        Args:
            key: Dot-separated path such as ``'python.host'``.
            default: Returned when the path is missing, passes through a
                non-dict value, or resolves to ``None``.

        Returns:
            The configured value, or ``default``.
        """
        value = self.backend_config
        for part in key.split('.'):
            if not isinstance(value, dict):
                return default
            value = value.get(part)
        return default if value is None else value

    def get_logging_config(self) -> Dict[str, Any]:
        """Return the ``logging`` section (the live dict, not a copy)."""
        return self.logging_config

    def get_all_agents_config(self) -> Dict[str, Dict[str, Any]]:
        """Return the ``agents`` section (the live dict, not a copy)."""
        return self.agents_config
# Shared AgentConfigLoader instance, created lazily by get_config_loader().
_config_loader = None


def get_config_loader(strict_mode: bool = True) -> AgentConfigLoader:
    """Return the shared :class:`AgentConfigLoader`, creating it on first use.

    Args:
        strict_mode: Used only when the loader is first created. Later calls
            return the existing instance unchanged; if they request a different
            ``strict_mode``, a warning is logged and the request is ignored.

    Returns:
        The shared loader instance.
    """
    global _config_loader
    if _config_loader is None:
        _config_loader = AgentConfigLoader(strict_mode=strict_mode)
    elif _config_loader.strict_mode != strict_mode:
        # Surface the mismatch instead of silently dropping the argument.
        logger.warning(
            f"全局配置加载器已以 strict_mode={_config_loader.strict_mode} 创建, "
            f"忽略请求的 strict_mode={strict_mode}"
        )
    return _config_loader