97 lines
3.0 KiB
Python
97 lines
3.0 KiB
Python
|
|
import logging
|
|||
|
|
|
|||
|
|
from a2a.server.agent_execution import AgentExecutor, RequestContext
|
|||
|
|
from a2a.server.events import EventQueue
|
|||
|
|
from a2a.server.tasks import TaskUpdater
|
|||
|
|
from a2a.types import (
|
|||
|
|
InternalError,
|
|||
|
|
InvalidParamsError,
|
|||
|
|
Part,
|
|||
|
|
TaskState,
|
|||
|
|
TextPart,
|
|||
|
|
UnsupportedOperationError,
|
|||
|
|
)
|
|||
|
|
from a2a.utils import (
|
|||
|
|
new_agent_text_message,
|
|||
|
|
new_task,
|
|||
|
|
)
|
|||
|
|
from a2a.utils.errors import ServerError
|
|||
|
|
|
|||
|
|
from agent import ConductorAgent
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Configure root logging at import time so INFO-level (and above) records are emitted.
logging.basicConfig(level=logging.INFO)
|
|||
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ConductorAgentExecutor(AgentExecutor):
    """Conductor AgentExecutor - the top-level managing agent executor.

    Forwards each incoming request to ``ConductorAgent.invoke`` (non-streaming)
    and translates the returned mapping into task status / artifact events
    published through a ``TaskUpdater``.
    """

    def __init__(self):
        """Create the executor together with its own ``ConductorAgent``."""
        self.agent = ConductorAgent()

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Run the conductor agent for one request and publish the outcome.

        The agent result is read with ``.get`` using these keys:

        * ``require_user_input`` (default ``False``) - when truthy the task is
          moved to ``TaskState.input_required`` with ``final=True``.
        * ``is_task_complete`` (default ``True``) - when truthy (and no input
          is required) the content is attached as the ``conductor_result``
          artifact and the task is completed.
        * ``content`` (default ``'处理完成'``) - text sent back in the status
          message or artifact.

        Args:
            context: Request context providing the user input, the current
                task (if any) and the incoming message.
            event_queue: Queue on which the new task and its updates are
                published.

        Raises:
            ServerError: With ``InvalidParamsError`` when request validation
                fails or when a new task must be created but the request
                carries no message; with ``InternalError`` when invoking the
                agent or publishing its result raises.
        """
        if self._validate_request(context):
            raise ServerError(error=InvalidParamsError())

        query = context.get_user_input()
        task = context.current_task
        if not task:
            # A new task can only be derived from the incoming message; reject
            # the request cleanly instead of failing inside ``new_task``.
            message = context.message
            if message is None:
                raise ServerError(error=InvalidParamsError())
            task = new_task(message)
            await event_queue.enqueue_event(task)
        updater = TaskUpdater(event_queue, task.id, task.context_id)
        try:
            # Use the non-streaming invoke method.
            result = await self.agent.invoke(query, task.context_id)

            is_task_complete = result.get('is_task_complete', True)
            require_user_input = result.get('require_user_input', False)
            content = result.get('content', '处理完成')

            if require_user_input:
                await updater.update_status(
                    TaskState.input_required,
                    new_agent_text_message(
                        content,
                        task.context_id,
                        task.id,
                    ),
                    final=True,
                )
            elif is_task_complete:
                await updater.add_artifact(
                    [Part(root=TextPart(text=content))],
                    name='conductor_result',
                )
                await updater.complete()
            else:
                # Neither awaiting input nor complete: report the working state.
                # NOTE(review): no terminal event is emitted on this path even
                # though invoke is non-streaming - confirm this is intended.
                await updater.update_status(
                    TaskState.working,
                    new_agent_text_message(
                        content,
                        task.context_id,
                        task.id,
                    ),
                )

        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(
                'An error occurred while processing the request: %s', e
            )
            raise ServerError(error=InternalError()) from e

    def _validate_request(self, context: RequestContext) -> bool:
        """Validate the incoming request.

        Request validation logic can be added here; currently every request
        is accepted.

        Returns:
            ``True`` when the request is invalid, ``False`` when it passes
            validation.
        """
        return False

    async def cancel(
        self, context: RequestContext, event_queue: EventQueue
    ) -> None:
        """Reject cancellation requests.

        Raises:
            ServerError: Always, wrapping ``UnsupportedOperationError``.
        """
        raise ServerError(error=UnsupportedOperationError())
|
|||
|
|
|