# -*- coding: utf-8 -*-
"""A2A agent implementation for AgentScope.
This module provides the A2A (Agent-to-Agent) protocol implementation,
enabling AgentScope agents to communicate with remote agents using the
A2A standard protocol.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Type
import httpx
from pydantic import BaseModel
from ._agent_base import AgentBase
from ..message import Msg
from ..formatter import A2AChatFormatter
if TYPE_CHECKING:
    from a2a.client import ClientConfig, Consumer
    from a2a.client.client_factory import TransportProducer
    from a2a.types import AgentCard
else:
    # At runtime the real a2a classes are imported inside the methods that
    # need them, so importing this module does not require the a2a package.
    # These string placeholders only keep the names defined at module level;
    # with postponed annotation evaluation they are never used as types.
    (
        AgentCard,
        ClientConfig,
        Consumer,
        TransportProducer,
    ) = (
        "a2a.types.AgentCard",
        "a2a.client.ClientConfig",
        "a2a.client.Consumer",
        "a2a.client.client_factory.TransportProducer",
    )
class A2AAgent(AgentBase):
    """An A2A agent implementation in AgentScope, which supports

    - Communication with remote agents using the A2A protocol
    - Bidirectional message conversion between AgentScope and A2A formats
    - Task lifecycle management with streaming and polling
    - Artifact handling and status tracking

    .. note:: Due to the limitations of the A2A protocol, the A2AAgent class

     - Only supports chatbot-scenario (a user and an assistant)
       interactions. Supporting multi-agent interactions requires the
       server side to handle the `name` field in the A2A messages properly.
     - Does not support structured output in the `reply()` method, because
       the A2A protocol has no structured output support.
     - Stores observed messages locally and merges them with the input
       messages when `reply()` is called. Observed messages are cleared
       once a response has been received.
    """

    def __init__(
        self,
        agent_card: AgentCard,
        client_config: ClientConfig | None = None,
        consumers: list[Consumer] | None = None,
        additional_transport_producers: dict[str, TransportProducer]
        | None = None,
    ) -> None:
        """Initialize the A2A agent instance by the given agent card.

        Args:
            agent_card (`AgentCard`):
                The agent card that contains the information about the remote
                agent, such as its URL and capabilities. Its `name` becomes
                this agent's name.
            client_config (`ClientConfig | None`, optional):
                The configuration for the A2A client, including transport
                preferences and streaming options. If `None`, a default
                `ClientConfig` wrapping an `httpx.AsyncClient` with a
                600-second timeout is used.
            consumers (`list[Consumer] | None`, optional):
                The list of consumers for handling A2A client events.
                These intercept request/response flows for logging,
                metrics, and security.
            additional_transport_producers (`dict[str, TransportProducer] | \
None`, optional):
                Additional transport producers, keyed by transport label,
                registered on the client factory for creating A2A clients
                with specific transport protocols.

        Raises:
            `ValueError`:
                If `agent_card` is not an instance of `AgentCard`.
        """
        super().__init__()

        # Imported here rather than at module level, so that importing this
        # module does not itself require the a2a package.
        from a2a.types import AgentCard
        from a2a.client import ClientConfig, ClientFactory

        if not isinstance(agent_card, AgentCard):
            raise ValueError(
                f"agent_card must be an instance of AgentCard, "
                f"got {type(agent_card)}",
            )

        self.name: str = agent_card.name
        self.agent_card = agent_card

        # Create the client factory so that clients can be created later
        # in reply().
        # NOTE(review): the default httpx.AsyncClient created below is never
        # closed by this class; pass `client_config` with a managed client
        # if connection cleanup matters.
        self._a2a_client_factory = ClientFactory(
            config=client_config
            or ClientConfig(
                httpx_client=httpx.AsyncClient(
                    timeout=httpx.Timeout(timeout=600),
                ),
            ),
            consumers=consumers,
        )

        # Register additional transport producers if provided
        if additional_transport_producers:
            for label, producer in additional_transport_producers.items():
                self._a2a_client_factory.register(
                    label,
                    producer,
                )

        # Messages received via observe() (or interrupt notices) that will
        # be prepended to the input of the next reply()
        self._observed_msgs: list[Msg] = []

        # The formatter used for AgentScope <-> A2A message conversion
        self.formatter = A2AChatFormatter()

    def state_dict(self) -> dict:
        """Get the state dictionary of the A2A agent.

        Returns:
            `dict`:
                A dictionary with the single key `_observed_msgs`, holding
                the observed messages serialized via `Msg.to_dict()`.
        """
        return {
            "_observed_msgs": [msg.to_dict() for msg in self._observed_msgs],
        }

    def load_state_dict(self, state_dict: dict, strict: bool = True) -> None:
        """Load the state dictionary into the module.

        Args:
            state_dict (`dict`):
                The state dictionary to load, in the format produced by
                `state_dict()`.
            strict (`bool`, defaults to `True`):
                If `True`, raises an error if the `_observed_msgs` key is
                missing from `state_dict` or if `state_dict` contains any
                other key. If `False`, a missing key is skipped (the current
                observed messages are kept) and extra keys are ignored.

        Raises:
            `KeyError`:
                If `strict` is `True` and the `_observed_msgs` key is missing
                or an unexpected key is present.
        """
        if strict:
            # Validate before mutating, so a rejected state_dict leaves the
            # agent's current state untouched.
            if "_observed_msgs" not in state_dict:
                raise KeyError(
                    "_observed_msgs key not found in state_dict",
                )
            for key in state_dict:
                if key != "_observed_msgs":
                    raise KeyError(f"Unexpected key {key} in state_dict")

        if "_observed_msgs" in state_dict:
            self._observed_msgs = [
                Msg.from_dict(d) for d in state_dict["_observed_msgs"]
            ]

    async def observe(self, msg: Msg | list[Msg] | None) -> None:
        """Receive the given message(s) without generating a reply.

        The observed messages are stored and will be merged with the
        input messages when `reply` is called. After `reply` receives a
        response, the stored messages will be cleared.

        Args:
            msg (`Msg | list[Msg] | None`):
                The message(s) to be observed. If None, no action is taken.

        Raises:
            `TypeError`:
                If `msg` is neither a `Msg` nor a list of `Msg` objects.
        """
        if msg is None:
            return

        if isinstance(msg, Msg):
            self._observed_msgs.append(msg)
        elif isinstance(msg, list) and all(isinstance(m, Msg) for m in msg):
            self._observed_msgs.extend(msg)
        else:
            raise TypeError(
                f"msg must be a Msg or a list of Msg, got {type(msg)}",
            )

    async def reply(
        self,
        msg: Msg | list[Msg] | None = None,
        **kwargs: Any,
    ) -> Msg:
        """Send message(s) to the remote A2A agent and receive a response.

        .. note:: This method merges any previously observed messages with
         the input messages, sends them to the remote agent, and clears the
         observed messages once the response stream has been consumed.

        .. note:: The A2A protocol does not support structured output, so the
         `structured_model` parameter is not supported in this method.

        Args:
            msg (`Msg | list[Msg] | None`, optional):
                The message(s) to send to the remote agent. Can be a single
                Msg, a list of Msgs, or None. If None, only observed messages
                will be sent. Defaults to None.

        Returns:
            `Msg`:
                The last response message produced from the remote agent's
                stream. For tasks, this may be either a status update message
                or the final artifacts message, depending on the task state.

        Raises:
            `ValueError`:
                If `structured_model` is passed in `kwargs`, or if the remote
                agent's stream yields no message and no task.
        """
        if "structured_model" in kwargs:
            raise ValueError(
                "structured_model is not supported in A2AAgent.reply() "
                "due to the lack of structured output support in A2A "
                "protocol.",
            )

        from a2a.types import Message as A2AMessage

        # Merge observed messages with input messages. NOTE: this appends to
        # `self._observed_msgs` in place, so if sending raises or is
        # interrupted before the list is cleared below, the input messages
        # remain stored and are sent again with the next reply.
        msgs_list = self._observed_msgs
        if msg is not None:
            if isinstance(msg, Msg):
                msgs_list.append(msg)
            else:
                msgs_list.extend(msg)

        # Create an A2A client for the remote agent
        client = self._a2a_client_factory.create(
            card=self.agent_card,
        )

        # Convert the Msg objects into an A2A Message object, skipping
        # falsy entries
        a2a_message = await self.formatter.format(
            [m for m in msgs_list if m],
        )

        response_msg = None
        async for item in client.send_message(a2a_message):
            if isinstance(item, A2AMessage):
                response_msg = await self.formatter.format_a2a_message(
                    self.name,
                    item,
                )
                await self.print(response_msg)

            elif isinstance(item, tuple):
                # Task events are 2-tuples whose first element is the task
                task, _ = item
                if task is not None:
                    for task_msg in await self.formatter.format_a2a_task(
                        self.name,
                        task,
                    ):
                        await self.print(task_msg)
                        response_msg = task_msg

        # Clear the observed messages after processing
        self._observed_msgs.clear()

        if response_msg is not None:
            return response_msg

        raise ValueError(
            "No response received from remote agent",
        )

    # pylint: disable=unused-argument
    async def handle_interrupt(
        self,
        msg: Msg | list[Msg] | None = None,
        structured_model: Type[BaseModel] | None = None,
    ) -> Msg:
        """The post-processing logic when the reply is interrupted by the
        user or something else.

        Args:
            msg (`Msg | list[Msg] | None`, optional):
                The input of the interrupted reply. Unused here.
            structured_model (`Type[BaseModel] | None`, optional):
                Unused here.

        Returns:
            `Msg`:
                An assistant message acknowledging the interruption, with
                `metadata["_is_interrupted"]` set to `True`. The message is
                also stored as an observed message, so it is sent as context
                with the next reply.
        """
        response_msg = Msg(
            self.name,
            "I noticed that you have interrupted me. What can I "
            "do for you?",
            "assistant",
            metadata={
                # Expose this field to indicate the interruption
                "_is_interrupted": True,
            },
        )
        await self.print(response_msg, True)

        # Add to observed messages for context in next reply
        self._observed_msgs.append(response_msg)

        return response_msg