# -*- coding: utf-8 -*-
"""The agent base class in agentscope."""
import asyncio
import json
from asyncio import Task
from collections import OrderedDict
from typing import Callable, Any, Sequence
import shortuuid
from ._agent_meta import _AgentMeta
from ..module import StateModule
from ..message import Msg
from ..types import AgentHookTypes
[docs]
class AgentBase(StateModule, metaclass=_AgentMeta):
"""Base class for asynchronous agents."""
id: str
"""The agent's unique identifier, generated using shortuuid."""
supported_hook_types: list[str] = [
"pre_reply",
"post_reply",
"pre_print",
"post_print",
"pre_observe",
"post_observe",
]
"""Supported hook types for the agent base class."""
_class_pre_reply_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
],
dict[str, Any] | None, # The modified kwargs or None
],
] = OrderedDict()
"""The class-level hook functions that will be called before the reply
function, taking `self` object, the input arguments as input, and
generating the modified arguments (if needed). Then input arguments of the
reply function will be re-organized into a keyword arguments dictionary.
If the one hook returns a new dictionary, the modified arguments will be
passed to the next hook or the original reply function."""
_class_post_reply_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
Msg, # output, the output message
],
Msg | None,
],
] = OrderedDict()
"""The class-level hook functions that will be called after the reply
function, which takes the `self` object and deep copied
positional and keyword arguments (args and kwargs), and the output message
as input. If the hook returns a message, the new message will be passed
to the next hook or the original reply function. Otherwise, the original
output will be passed instead."""
_class_pre_print_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
],
dict[str, Any] | None, # The modified kwargs or None
],
] = OrderedDict()
"""The class-level hook functions that will be called before printing,
which takes the `self` object, a deep copied arguments dictionary as input,
and output the modified arguments (if needed). """
_class_post_print_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
Any, # output, `None` if no output
],
Any,
],
] = OrderedDict()
"""The class-level hook functions that will be called after the speak
function, which takes the `self` object as input."""
_class_pre_observe_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
],
dict[str, Any] | None, # The modified kwargs or None
],
] = OrderedDict()
"""The class-level hook functions that will be called before the observe
function, which takes the `self` object and a deep copied input
arguments dictionary as input. To change the input arguments, the hook
function needs to output the modified arguments dictionary, which will be
used as the input of the next hook function or the original observe
function."""
_class_post_observe_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
None, # The output, `None` if no output
],
None,
],
] = OrderedDict()
"""The class-level hook functions that will be called after the observe
function, which takes the `self` object as input."""
[docs]
def __init__(self) -> None:
"""Initialize the agent."""
super().__init__()
self.id = shortuuid.uuid()
# The replying task and identify of the current replying
self._reply_task: Task | None = None
self._reply_id: str | None = None
# Initialize the instance-level hooks
self._instance_pre_print_hooks = OrderedDict()
self._instance_post_print_hooks = OrderedDict()
self._instance_pre_reply_hooks = OrderedDict()
self._instance_post_reply_hooks = OrderedDict()
self._instance_pre_observe_hooks = OrderedDict()
self._instance_post_observe_hooks = OrderedDict()
# The prefix used in streaming printing
self._stream_prefix = {}
# The subscribers that will receive the reply message by their
# `observe` method.
self._subscribers: list[AgentBase] = []
# We add this variable in case developers want to disable the console
# output of the agent, e.g., in a production environment.
self._disable_console_output: bool = False
[docs]
async def observe(self, msg: Msg | list[Msg] | None) -> None:
"""Receive the given message(s) without generating a reply.
Args:
msg (`Msg | list[Msg] | None`):
The message(s) to be observed.
"""
raise NotImplementedError(
f"The observe function is not implemented in"
f" {self.__class__.__name__} class.",
)
[docs]
async def reply(self, *args: Any, **kwargs: Any) -> Msg:
"""The main logic of the agent, which generates a reply based on the
current state and input arguments."""
raise NotImplementedError(
"The reply function is not implemented in "
f"{self.__class__.__name__} class.",
)
[docs]
async def print(self, msg: Msg, last: bool = True) -> None:
"""The function to display the message.
Args:
msg (`Msg`):
The message object to be printed.
last (`bool`, defaults to `True`):
Whether this is the last one in streaming messages. For
non-streaming message, this should always be `True`.
"""
if self._disable_console_output:
return
thinking_and_text_to_print = []
for block in msg.get_content_blocks():
prefix = self._stream_prefix.get(msg.id, "")
if block["type"] in ["text", "thinking"]:
block_type = block["type"]
format_prefix = "" if block_type == "text" else "(thinking)"
thinking_and_text_to_print.append(
f"{msg.name}{format_prefix}: {block[block_type]}",
)
to_print = "\n".join(thinking_and_text_to_print)
if len(to_print) > len(prefix):
print(to_print[len(prefix) :], end="")
self._stream_prefix[msg.id] = to_print
elif last:
if prefix:
if not prefix.endswith("\n"):
print(
"\n"
+ json.dumps(block, indent=4, ensure_ascii=False),
)
else:
print(json.dumps(block, indent=4, ensure_ascii=False))
else:
print(
f"{msg.name}: "
f"{json.dumps(block, indent=4, ensure_ascii=False)}",
)
if last and msg.id in self._stream_prefix:
last_prefix = self._stream_prefix.pop(msg.id)
if not last_prefix.endswith("\n"):
print()
[docs]
async def __call__(self, *args: Any, **kwargs: Any) -> Msg:
"""Call the reply function with the given arguments."""
self._reply_id = shortuuid.uuid()
reply_msg: Msg | None = None
try:
self._reply_task = asyncio.current_task()
reply_msg = await self.reply(*args, **kwargs)
except asyncio.CancelledError:
reply_msg = await self.handle_interrupt(*args, **kwargs)
finally:
# Broadcast the reply message to all subscribers
if reply_msg:
await self._broadcast_to_subscribers(reply_msg)
self._reply_task = None
return reply_msg
async def _broadcast_to_subscribers(
self,
msg: Msg | list[Msg] | None,
) -> None:
"""Broadcast the message to all subscribers."""
for subscriber in self._subscribers:
await subscriber.observe(msg)
[docs]
async def handle_interrupt(
self,
*args: Any,
**kwargs: Any,
) -> Msg:
"""The post-processing logic when the reply is interrupted by the
user or something else."""
raise NotImplementedError(
f"The handle_interrupt function is not implemented in "
f"{self.__class__.__name__}",
)
[docs]
async def interrupt(self, msg: Msg | list[Msg] | None = None) -> None:
"""Interrupt the current reply process."""
if self._reply_task and not self._reply_task.done():
self._reply_task.cancel(msg)
[docs]
def register_instance_hook(
self,
hook_type: AgentHookTypes,
hook_name: str,
hook: Callable,
) -> None:
"""Register a hook to the agent instance, which only takes effect
for the current instance.
Args:
hook_type (`str`):
The type of the hook, indicating where the hook is to be
triggered.
hook_name (`str`):
The name of the hook. If the name is already registered, the
hook will be overwritten.
hook (`Callable`):
The hook function.
"""
if not isinstance(self, AgentBase):
raise TypeError(
"The register_instance_hook method should be called on an "
f"instance of AsyncAgentBase, but got {self} of "
f"type {type(self)}.",
)
hooks = getattr(self, f"_instance_{hook_type}_hooks")
hooks[hook_name] = hook
[docs]
def remove_instance_hook(
self,
hook_type: AgentHookTypes,
hook_name: str,
) -> None:
"""Remove an instance-level hook from the agent instance.
Args:
hook_type (`AgentHookTypes`):
The type of the hook, indicating where the hook is to be
triggered.
hook_name (`str`):
The name of the hook to remove.
"""
if not isinstance(self, AgentBase):
raise TypeError(
"The remove_instance_hook method should be called on an "
f"instance of AsyncAgentBase, but got {self} of "
f"type {type(self)}.",
)
hooks = getattr(self, f"_instance_{hook_type}_hooks")
if hook_name in hooks:
del hooks[hook_name]
else:
raise ValueError(
f"Hook '{hook_name}' not found in '{hook_type}' hooks of "
f"{self.__class__.__name__} instance.",
)
[docs]
@classmethod
def register_class_hook(
cls,
hook_type: AgentHookTypes,
hook_name: str,
hook: Callable,
) -> None:
"""The universal function to register a hook to the agent class, which
will take effect for all instances of the class.
Args:
hook_type (`AgentHookTypes`):
The type of the hook, indicating where the hook is to be
triggered.
hook_name (`str`):
The name of the hook. If the name is already registered, the
hook will be overwritten.
hook (`Callable`):
The hook function.
"""
assert (
hook_type in cls.supported_hook_types
), f"Invalid hook type: {hook_type}"
hooks = getattr(cls, f"_class_{hook_type}_hooks")
hooks[hook_name] = hook
[docs]
@classmethod
def remove_class_hook(
cls,
hook_type: AgentHookTypes,
hook_name: str,
) -> None:
"""Remove a class-level hook from the agent class.
Args:
hook_type (`AgentHookTypes`):
The type of the hook, indicating where the hook is to be
triggered.
hook_name (`str`):
The name of the hook to remove.
"""
assert (
hook_type in cls.supported_hook_types
), f"Invalid hook type: {hook_type}"
hooks = getattr(cls, f"_class_{hook_type}_hooks")
if hook_name in hooks:
del hooks[hook_name]
else:
raise ValueError(
f"Hook '{hook_name}' not found in '{hook_type}' hooks of "
f"{cls.__name__} class.",
)
[docs]
@classmethod
def clear_class_hooks(
cls,
hook_type: AgentHookTypes | None = None,
) -> None:
"""Clear all class-level hooks.
Args:
hook_type (`AgentHookTypes`, optional):
The type of the hook to clear. If not specified, all
class-level hooks will be cleared.
"""
if hook_type is None:
for typ in cls.supported_hook_types:
hooks = getattr(cls, f"_class_{typ}_hooks")
hooks.clear()
else:
assert (
hook_type in cls.supported_hook_types
), f"Invalid hook type: {hook_type}"
hooks = getattr(cls, f"_class_{hook_type}_hooks")
hooks.clear()
[docs]
def clear_instance_hooks(
self,
hook_type: AgentHookTypes | None = None,
) -> None:
"""If `hook_type` is not specified, clear all instance-level hooks.
Otherwise, clear the specified type of instance-level hooks."""
if hook_type is None:
for typ in self.supported_hook_types:
if not hasattr(self, f"_instance_{typ}_hooks"):
raise ValueError(
f"Call super().__init__() in the constructor "
f"to initialize the instance-level hooks for "
f"{self.__class__.__name__}.",
)
hooks = getattr(self, f"_instance_{typ}_hooks")
hooks.clear()
else:
assert (
hook_type in self.supported_hook_types
), f"Invalid hook type: {hook_type}"
if not hasattr(self, f"_instance_{hook_type}_hooks"):
raise ValueError(
f"Call super().__init__() in the constructor "
f"to initialize the instance-level hooks for "
f"{self.__class__.__name__}.",
)
hooks = getattr(self, f"_instance_{hook_type}_hooks")
hooks.clear()
[docs]
def reset_subscribers(self, subscribers: Sequence["AgentBase"]) -> None:
"""Reset the subscribers of the agent.
Args:
subscribers (`list[AgentBase]`):
A list of agents that will receive the reply message from
this agent via their `observe` method.
"""
self._subscribers = [_ for _ in subscribers if _ != self]
[docs]
def disable_console_output(self) -> None:
"""This function will disable the console output of the agent, e.g.
in a production environment to avoid messy logs."""
self._disable_console_output = True