Files
moss-ai/agents/conductor_agent/executor.py
雷雨 8635b84b2d init
2025-12-15 22:05:56 +08:00

97 lines
3.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.tasks import TaskUpdater
from a2a.types import (
InternalError,
InvalidParamsError,
Part,
TaskState,
TextPart,
UnsupportedOperationError,
)
from a2a.utils import (
new_agent_text_message,
new_task,
)
from a2a.utils.errors import ServerError
from agent import ConductorAgent
# Module-scoped logger, named after this module so its records can be
# filtered independently of other components.
logger = logging.getLogger(__name__)
# Configure the root logger at import time so INFO-level records are emitted.
# NOTE(review): this affects every logger in the process; consider moving it
# to the server entry point instead of an importable module.
logging.basicConfig(level=logging.INFO)
class ConductorAgentExecutor(AgentExecutor):
    """A2A executor that bridges incoming requests to ``ConductorAgent``.

    Conductor AgentExecutor - executor for the top-level managing (conductor)
    agent. Each request is answered with a single, non-streaming call to
    ``ConductorAgent.invoke``. The mapping it returns (read via ``.get``) is
    translated into A2A task events: an input-required status, a completed
    artifact, or a working status.
    """

    def __init__(self):
        # A single agent instance, created once and reused by every call to
        # ``execute`` on this executor.
        self.agent = ConductorAgent()

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Run the conductor agent for one request and publish the outcome.

        Args:
            context: Request context carrying the user input and, when the
                request continues an existing task, that task.
            event_queue: Queue onto which the task and its status/artifact
                events are enqueued.

        Raises:
            ServerError: wrapping ``InvalidParamsError`` when request validation
                fails, or ``InternalError`` when invoking the agent or
                publishing its result raises any exception.
        """
        error = self._validate_request(context)
        if error:
            raise ServerError(error=InvalidParamsError())

        query = context.get_user_input()
        task = context.current_task
        if not task:
            # No task yet: create one from the incoming message and enqueue it
            # so it is published before the status events emitted below.
            task = new_task(context.message)  # type: ignore
            await event_queue.enqueue_event(task)
        updater = TaskUpdater(event_queue, task.id, task.context_id)

        try:
            # Non-streaming: a single invoke call returns the whole result.
            result = await self.agent.invoke(query, task.context_id)
            is_task_complete = result.get('is_task_complete', True)
            require_user_input = result.get('require_user_input', False)
            content = result.get('content', '处理完成')

            if require_user_input:
                # The agent needs more information; surface its message and
                # end this turn in the input-required state.
                await updater.update_status(
                    TaskState.input_required,
                    new_agent_text_message(
                        content,
                        task.context_id,
                        task.id,
                    ),
                    final=True,
                )
            elif is_task_complete:
                # Finished: attach the answer as an artifact, then complete.
                await updater.add_artifact(
                    [Part(root=TextPart(text=content))],
                    name='conductor_result',
                )
                await updater.complete()
            else:
                # Neither waiting for input nor finished: report a working
                # status carrying the agent's message.
                await updater.update_status(
                    TaskState.working,
                    new_agent_text_message(
                        content,
                        task.context_id,
                        task.id,
                    ),
                )
        except Exception as e:
            # logger.exception records the traceback. Without it the root
            # cause is lost, because callers only see a generic InternalError.
            logger.exception(
                'An error occurred while processing the request: %s', e
            )
            raise ServerError(error=InternalError()) from e

    def _validate_request(self, context: RequestContext) -> bool:
        """Report whether ``context`` is invalid.

        Returns:
            True if the request has an error, False if validation passes.
            Placeholder: no checks are implemented yet, so this always
            returns False.
        """
        # Request validation logic can be added here.
        return False

    async def cancel(
        self, context: RequestContext, event_queue: EventQueue
    ) -> None:
        """Reject cancellation; this executor does not support it.

        Raises:
            ServerError: always, wrapping ``UnsupportedOperationError``.
        """
        raise ServerError(error=UnsupportedOperationError())