init
This commit is contained in:
7
agents/data_mining_agent/__init__.py
Normal file
7
agents/data_mining_agent/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Data Mining Agent - Moss AI 用户行为数据挖掘与场景分析系统
|
||||
使用GMM算法分析用户智能家居使用习惯,提供个性化场景推荐
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
|
||||
118
agents/data_mining_agent/__main__.py
Normal file
118
agents/data_mining_agent/__main__.py
Normal file
@@ -0,0 +1,118 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
"""
|
||||
主入口点 - 用于 `uv run .` 或 `python -m data_mining_agent` 启动服务
|
||||
数据挖掘 Agent - 用户行为分析与场景推荐
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import click
|
||||
import logging
|
||||
import uvicorn
|
||||
from a2a.types import (
|
||||
AgentCapabilities,
|
||||
AgentCard,
|
||||
AgentSkill,
|
||||
)
|
||||
from a2a.server.apps import A2AStarletteApplication
|
||||
from a2a.server.request_handlers import DefaultRequestHandler
|
||||
import httpx
|
||||
from a2a.server.tasks import (
|
||||
BasePushNotificationSender,
|
||||
InMemoryPushNotificationConfigStore,
|
||||
InMemoryTaskStore,
|
||||
)
|
||||
|
||||
# 确保当前目录和父目录在 Python 路径中
|
||||
current_dir = Path(__file__).parent
|
||||
parent_dir = current_dir.parent
|
||||
if str(current_dir) not in sys.path:
|
||||
sys.path.insert(0, str(current_dir))
|
||||
if str(parent_dir) not in sys.path:
|
||||
sys.path.insert(0, str(parent_dir))
|
||||
|
||||
from executor import DataMiningAgentExecutor
|
||||
from agent import DataMiningAgent
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@click.command()
@click.option("--host", "host", default=None, help="服务主机地址(默认从config.yaml读取)")
@click.option("--port", "port", default=None, type=int, help="服务端口(默认从config.yaml读取)")
def main(host, port):
    """Starts the Data Mining Agent server."""
    try:
        # Fall back to config.yaml for host/port when not given on the CLI.
        if host is None or port is None:
            from config_loader import get_config_loader

            loader = get_config_loader(strict_mode=False)
            cfg_host, cfg_port = loader.get_agent_host_port('data_mining')
            host = host or cfg_host
            port = port or cfg_port

        # Advertised agent skill and capabilities for the A2A agent card.
        skill = AgentSkill(
            id="analyze_user_behavior",
            name="User Behavior Analysis & Scene Mining",
            description="使用GMM算法分析用户智能家居使用习惯,提供个性化场景推荐。从StarRocks数据库挖掘设备操作历史,识别用户行为模式。",
            tags=["data mining", "user behavior", "GMM clustering", "scene analysis", "personalization", "smart home"],
            examples=[
                "分析用户睡觉时的习惯",
                "查询用户起床后通常做什么",
                "我要出门了,推荐设备操作",
                "分析用户晚上回家的习惯",
                "查看数据挖掘Agent状态",
            ],
        )
        caps = AgentCapabilities(
            streaming=False,
            push_notifications=False,
            state_transition_history=False,
        )
        card = AgentCard(
            name="Data Mining Agent",
            description="智能家居用户行为数据挖掘与场景分析专家。使用高斯混合模型(GMM)对用户历史操作进行场景聚类,为Conductor Agent提供个性化推荐。",
            url=f"http://{host}:{port}/",
            version="1.0.0",
            default_input_modes=DataMiningAgent.SUPPORTED_CONTENT_TYPES,
            default_output_modes=DataMiningAgent.SUPPORTED_CONTENT_TYPES,
            capabilities=caps,
            skills=[skill],
        )

        # --8<-- [start:DefaultRequestHandler]
        client = httpx.AsyncClient()
        config_store = InMemoryPushNotificationConfigStore()
        notifier = BasePushNotificationSender(
            httpx_client=client, config_store=config_store
        )
        handler = DefaultRequestHandler(
            agent_executor=DataMiningAgentExecutor(),
            task_store=InMemoryTaskStore(),
            push_config_store=config_store,
            push_sender=notifier,
        )
        app = A2AStarletteApplication(
            agent_card=card, http_handler=handler
        )

        logger.info("🚀 数据挖掘 Agent 启动成功")
        logger.info("📊 提供用户行为分析与场景挖掘服务")
        logger.info(f"🔗 服务地址: http://{host}:{port}/")

        uvicorn.run(app.build(), host=host, port=port)
        # --8<-- [end:DefaultRequestHandler]

    except Exception as e:
        logger.error(f"An error occurred during server startup: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
||||
198
agents/data_mining_agent/agent.py
Normal file
198
agents/data_mining_agent/agent.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from typing import Any
from langchain_core.messages import AIMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

import logging
import sys
import os

# Add the parent directory to the path so config_loader can be imported
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config_loader import get_config_loader

from tools import (
    query_user_scene_habits,
    get_data_mining_status,
    submit_user_feedback
)

# Shared in-memory checkpointer for LangGraph conversation state.
memory = MemorySaver()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
||||
class DataMiningAgent:
    """LangGraph ReAct agent that mines user behaviour data for scene recommendations.

    Wraps a ChatOpenAI model plus the data-mining tools in a
    `create_react_agent` graph; configuration (model settings and system
    prompt) is loaded from the database via `config_loader` at construction.
    """

    # Content types this agent accepts/produces over A2A.
    SUPPORTED_CONTENT_TYPES = ['text', 'text/plain']

    # Default system prompt (fallback used when none is found in the database).
    DEFAULT_SYSTEM_PROMPT = (
        '你是一个专业的用户行为数据挖掘助手,负责分析智能家居系统中的用户使用习惯。'
        '你的主要职责是:'
        '1. 从StarRocks数据库中读取用户的设备操作历史'
        '2. 使用高斯混合模型(GMM)对用户行为进行场景聚类分析'
        '3. 识别用户的使用模式和习惯'
        '4. 为Conductor Agent提供个性化的场景推荐'
        ''
        '工具使用指南:'
        '1. 场景习惯查询:当需要分析用户在特定场景下的习惯时,'
        ' 调用 query_user_scene_habits 工具,传入用户查询(如"睡觉"、"起床"、"回家")'
        ' 该工具会:'
        ' - 从数据库读取用户最近N天的设备操作记录'
        ' - 使用GMM算法进行场景聚类'
        ' - 分析每个场景的设备操作特征'
        ' - 匹配与用户查询最相关的场景'
        ' - 返回该场景的设备操作建议'
        ''
        '2. 状态查询:当需要了解数据挖掘Agent的运行状态时,'
        ' 调用 get_data_mining_status 工具获取系统状态和统计信息'
        ''
        '数据分析流程:'
        '第一步:特征提取'
        ' - 从操作时间提取:小时、分钟、星期几、是否周末、时段特征'
        ' - 从设备类型提取:设备类别编码'
        ''
        '第二步:GMM聚类'
        ' - 使用高斯混合模型对特征进行聚类'
        ' - 自动确定最优聚类数量(2-5个场景)'
        ' - 每个聚类代表一个用户使用场景'
        ''
        '第三步:场景分析'
        ' - 分析每个场景的时间特征(早上/下午/晚上/夜晚)'
        ' - 统计每个场景中的设备操作频次'
        ' - 提取最常见的操作和参数'
        ''
        '第四步:场景匹配'
        ' - 根据用户查询中的关键词匹配场景'
        ' - 考虑时间特征和设备类型'
        ' - 返回最相关场景的操作建议'
        ''
        '数据不足处理:'
        '当历史数据不足时(少于10条记录),明确告知调用方:'
        ' - 返回 status: "insufficient_data"'
        ' - 建议使用通用最佳实践'
        ' - Conductor Agent会启用保底方案(AI搜索)'
        ''
        '响应格式:'
        '始终返回JSON格式的分析结果,包含:'
        ' - status: 状态(success/insufficient_data/error)'
        ' - message: 描述信息'
        ' - matched_scene: 匹配的场景信息'
        ' - recommendation: 具体的设备操作建议'
        ' - all_scenes: 所有识别的场景列表'
        ''
        '与Conductor Agent的协作:'
        '你是Conductor Agent的数据支持服务,专注于:'
        ' - 提供基于历史数据的个性化建议'
        ' - 识别用户的使用习惯和偏好'
        ' - 当数据不足时,及时告知以便启用备选方案'
        ''
        '始终以中文回复,提供清晰、结构化的分析结果。'
    )

    def __init__(self):
        """Build the model and ReAct graph from database-backed configuration.

        Raises:
            SystemExit: when configuration loading fails (strict mode).
        """
        # Load configuration from the database (strict mode: exit if loading fails)
        try:
            config_loader = get_config_loader(strict_mode=True)

            # Load the AI model configuration
            ai_config = config_loader.get_default_ai_model_config()
            self.model = ChatOpenAI(
                model=ai_config['model'],
                api_key=ai_config['api_key'],
                base_url=ai_config['api_base'],
                temperature=ai_config['temperature'],
            )

            # Load the system prompt; fall back to the built-in default if absent.
            system_prompt = config_loader.get_agent_prompt('data_mining')
            if system_prompt:
                self.SYSTEM_PROMPT = system_prompt
            else:
                logger.warning("⚠️ 未找到数据挖掘Agent的系统提示词,使用默认提示词")
                self.SYSTEM_PROMPT = self.DEFAULT_SYSTEM_PROMPT

        except Exception as e:
            logger.error(f"❌ 配置加载失败: {e}")
            logger.error("⚠️ 请确保:")
            logger.error(" 1. StarRocks 数据库已启动")
            logger.error(" 2. 已执行数据库初始化脚本: data/init_config.sql 和 data/ai_config.sql")
            logger.error(" 3. config.yaml 中的数据库连接配置正确")
            raise SystemExit(1) from e

        self.tools = [
            query_user_scene_habits,
            get_data_mining_status,
            submit_user_feedback
        ]

        self.graph = create_react_agent(
            self.model,
            tools=self.tools,
            checkpointer=memory,
            prompt=self.SYSTEM_PROMPT,
        )

    async def invoke(self, query, context_id) -> dict[str, Any]:
        """Non-streaming invocation; returns the final result directly.

        Args:
            query: user input text forwarded to the graph.
            context_id: conversation/thread identifier used as the checkpointer key.

        Returns:
            A dict with 'is_task_complete', 'require_user_input' and 'content'.
        """
        inputs = {'messages': [('user', query)]}
        config = {'configurable': {'thread_id': context_id}}

        # Call invoke directly; no streaming.
        # NOTE(review): this synchronous graph.invoke runs inside an async def and
        # will block the event loop for the duration of the call — consider
        # graph.ainvoke; confirm against the langgraph API in use.
        # NOTE(review): `result` is unused; the final state is read back via
        # get_agent_response(config) below.
        result = self.graph.invoke(inputs, config)

        return self.get_agent_response(config)

    def _extract_text_from_message(self, msg: AIMessage | ToolMessage | Any) -> str:
        """Best-effort extraction of plain text from a LangChain message.

        Handles both string content and the list-of-parts content shape
        (dicts carrying a 'text' key). Returns '' when nothing extractable.
        """
        try:
            content = getattr(msg, 'content', None)
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                parts = []
                for part in content:
                    if isinstance(part, dict) and 'text' in part:
                        parts.append(part['text'])
                if parts:
                    return '\n'.join(parts)
        except Exception:
            # Deliberately swallow extraction errors and fall through to ''.
            pass
        return ''

    def get_agent_response(self, config):
        """Read the graph's checkpointed state and shape the final response dict.

        Returns a dict with 'is_task_complete', 'require_user_input' and
        'content'; asks for user input when no usable text could be extracted.
        """
        current_state = self.graph.get_state(config)
        messages = current_state.values.get('messages') if hasattr(current_state, 'values') else None

        # Prefer the most recent AI message (it contains tool-call results).
        if isinstance(messages, list) and messages:
            for msg in reversed(messages):
                if isinstance(msg, AIMessage):
                    ai_text = self._extract_text_from_message(msg)
                    if ai_text:
                        return {
                            'is_task_complete': True,
                            'require_user_input': False,
                            'content': ai_text,
                        }

        # Fall back to the very last message of any type.
        final_text = ''
        if isinstance(messages, list) and messages:
            last_msg = messages[-1]
            final_text = self._extract_text_from_message(last_msg)

        if not final_text:
            return {
                'is_task_complete': False,
                'require_user_input': True,
                'content': '当前无法处理您的请求,请稍后重试。',
            }

        return {
            'is_task_complete': True,
            'require_user_input': False,
            'content': final_text,
        }
|
||||
96
agents/data_mining_agent/executor.py
Normal file
96
agents/data_mining_agent/executor.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Executor module: adapts DataMiningAgent to the A2A server runtime.
import logging

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.tasks import TaskUpdater
from a2a.types import (
    InternalError,
    InvalidParamsError,
    Part,
    TaskState,
    TextPart,
    UnsupportedOperationError,
)
from a2a.utils import (
    new_agent_text_message,
    new_task,
)
from a2a.utils.errors import ServerError

from agent import DataMiningAgent


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
||||
class DataMiningAgentExecutor(AgentExecutor):
    """Data Mining AgentExecutor."""

    def __init__(self):
        self.agent = DataMiningAgent()

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Run one request through the agent and publish the outcome as task events."""
        # Guard clause: reject invalid requests up front.
        if self._validate_request(context):
            raise ServerError(error=InvalidParamsError())

        query = context.get_user_input()
        task = context.current_task
        if not task:
            task = new_task(context.message)  # type: ignore
            await event_queue.enqueue_event(task)

        updater = TaskUpdater(event_queue, task.id, task.context_id)
        try:
            # Non-streaming call into the underlying agent.
            outcome = await self.agent.invoke(query, task.context_id)

            done = outcome.get('is_task_complete', True)
            needs_input = outcome.get('require_user_input', False)
            text = outcome.get('content', '处理完成')

            if needs_input:
                # The agent wants more information from the user.
                reply = new_agent_text_message(
                    text,
                    task.context_id,
                    task.id,
                )
                await updater.update_status(
                    TaskState.input_required,
                    reply,
                    final=True,
                )
            elif done:
                # Publish the result as an artifact and finish the task.
                await updater.add_artifact(
                    [Part(root=TextPart(text=text))],
                    name='data_mining_result',
                )
                await updater.complete()
            else:
                # Neither finished nor waiting on input: report as still working.
                progress = new_agent_text_message(
                    text,
                    task.context_id,
                    task.id,
                )
                await updater.update_status(
                    TaskState.working,
                    progress,
                )

        except Exception as e:
            logger.error(f'An error occurred while processing the request: {e}')
            raise ServerError(error=InternalError()) from e

    def _validate_request(self, context: RequestContext) -> bool:
        """Request-validation hook; True means the request is invalid."""
        # No validation rules yet - every request is accepted.
        return False

    async def cancel(
        self, context: RequestContext, event_queue: EventQueue
    ) -> None:
        # Cancellation is not supported by this executor.
        raise ServerError(error=UnsupportedOperationError())
|
||||
28
agents/data_mining_agent/pyproject.toml
Normal file
28
agents/data_mining_agent/pyproject.toml
Normal file
@@ -0,0 +1,28 @@
|
||||
[project]
|
||||
name = "data-mining-agent"
|
||||
version = "1.0.0"
|
||||
description = "Moss AI 用户行为数据挖掘与场景分析 Agent"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
    "a2a>=0.1.0",
    "click>=8.0.0",
    "uvicorn[standard]>=0.24.0",
    "httpx>=0.25.0",
    "PyYAML>=6.0.0",
    "starlette>=0.27.0",
    "pymysql>=1.0.0",
    "numpy>=1.24.0",
    "scikit-learn>=1.3.0",
    # Required by agent.py (ChatOpenAI, MemorySaver, create_react_agent):
    "langchain-core>=0.2.0",
    "langchain-openai>=0.1.0",
    "langgraph>=0.1.0",
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = []
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["."]
|
||||
|
||||
1169
agents/data_mining_agent/tools.py
Normal file
1169
agents/data_mining_agent/tools.py
Normal file
File diff suppressed because it is too large
Load Diff
1679
agents/data_mining_agent/uv.lock
generated
Normal file
1679
agents/data_mining_agent/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user