Source code for agentscope.agent._react_agent

# -*- coding: utf-8 -*-
# pylint: disable=not-an-iterable
# mypy: disable-error-code="list-item"
"""ReAct agent class in agentscope."""
import asyncio
from typing import Type, Any, AsyncGenerator, Literal

import shortuuid
from pydantic import BaseModel, ValidationError

from ._react_agent_base import ReActAgentBase
from ..formatter import FormatterBase
from ..memory import MemoryBase, LongTermMemoryBase, InMemoryMemory
from ..message import Msg, ToolUseBlock, ToolResultBlock, TextBlock
from ..model import ChatModelBase
from ..tool import Toolkit, ToolResponse
from ..tracing import trace_reply


def finish_function_pre_print_hook(
    self: "ReActAgent",
    kwargs: dict[str, Any],
) -> dict[str, Any] | None:
    """A pre-speak hook function that check if finish_function is called. If
    so, it will wrap the response argument into a message and return it to
    replace the original message. By this way, the calling of the finish
    function will be displayed as a text reply instead of a tool call."""

    msg = kwargs["msg"]

    if isinstance(msg.content, str):
        return None

    if isinstance(msg.content, list):
        for i, block in enumerate(msg.content):
            if (
                block["type"] == "tool_use"
                and block["name"] == self.finish_function_name
            ):
                # Convert the response argument into a text block for
                # displaying
                try:
                    msg.content[i] = TextBlock(
                        type="text",
                        text=block["input"].get("response", ""),
                    )
                    return kwargs
                except Exception:
                    print("Error in block input", block["input"])

    return None

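# Illustrative sketch (not part of the module): the effect of the hook above
# on a printed message. The concrete message and the `agent` instance are
# assumptions for demonstration only.
#
#     msg = Msg(
#         "Friday",
#         [
#             ToolUseBlock(
#                 id="1",
#                 type="tool_use",
#                 name="generate_response",
#                 input={"response": "Hello!"},
#             ),
#         ],
#         "assistant",
#     )
#     finish_function_pre_print_hook(agent, {"msg": msg})
#     # msg.content[0] is now TextBlock(type="text", text="Hello!")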

class ReActAgent(ReActAgentBase):
    """A ReAct agent implementation in AgentScope, which supports

    - Realtime steering
    - API-based (parallel) tool calling
    - Hooks around reasoning, acting, reply, observe and print functions
    - Structured output generation
    """

    finish_function_name: str = "generate_response"
    """The function name used to finish replying and return a response to
    the user."""
    def __init__(
        self,
        name: str,
        sys_prompt: str,
        model: ChatModelBase,
        formatter: FormatterBase,
        toolkit: Toolkit | None = None,
        memory: MemoryBase | None = None,
        long_term_memory: LongTermMemoryBase | None = None,
        long_term_memory_mode: Literal[
            "agent_control",
            "static_control",
            "both",
        ] = "both",
        enable_meta_tool: bool = False,
        parallel_tool_calls: bool = False,
        max_iters: int = 10,
    ) -> None:
        """Initialize the ReAct agent.

        Args:
            name (`str`):
                The name of the agent.
            sys_prompt (`str`):
                The system prompt of the agent.
            model (`ChatModelBase`):
                The chat model used by the agent.
            formatter (`FormatterBase`):
                The formatter used to format the messages into the required
                format of the model API provider.
            toolkit (`Toolkit | None`, optional):
                A `Toolkit` object that contains the tool functions. If not
                provided, a default empty `Toolkit` will be created.
            memory (`MemoryBase | None`, optional):
                The memory used to store the dialogue history. If not
                provided, a default `InMemoryMemory` will be created, which
                stores messages in a list in memory.
            long_term_memory (`LongTermMemoryBase | None`, optional):
                The optional long-term memory, which will provide two tool
                functions: `retrieve_from_memory` and `record_to_memory`,
                and will attach the retrieved information to the system
                prompt before each reply.
            long_term_memory_mode (`Literal["agent_control", \
"static_control", "both"]`, defaults to `"both"`):
                The mode of the long-term memory. If `"agent_control"`, two
                tool functions `retrieve_from_memory` and `record_to_memory`
                will be registered in the toolkit to allow the agent to
                manage the long-term memory itself. If `"static_control"`,
                retrieving and recording will happen at the beginning and
                end of each reply, respectively. If `"both"`, both behaviors
                are enabled.
            enable_meta_tool (`bool`, defaults to `False`):
                If `True`, a meta tool function `reset_equipped_tools` will
                be added to the toolkit, which allows the agent to manage
                its equipped tools dynamically.
            parallel_tool_calls (`bool`, defaults to `False`):
                When the LLM generates multiple tool calls, whether to
                execute them in parallel.
            max_iters (`int`, defaults to `10`):
                The maximum number of iterations of the reasoning-acting
                loop.
        """
        super().__init__()

        assert long_term_memory_mode in [
            "agent_control",
            "static_control",
            "both",
        ]

        # Static variables in the agent
        self.name = name
        self._sys_prompt = sys_prompt
        self.model = model
        self.formatter = formatter

        # Record the dialogue history in the memory
        self.memory = memory or InMemoryMemory()

        # If the long-term memory is provided, it will be used to retrieve
        # information at the beginning of each reply, and the result will
        # be added to the system prompt
        self.long_term_memory = long_term_memory

        # The long-term memory mode
        self._static_control = long_term_memory and long_term_memory_mode in [
            "static_control",
            "both",
        ]
        self._agent_control = long_term_memory and long_term_memory_mode in [
            "agent_control",
            "both",
        ]

        # If None, a default Toolkit will be created
        self.toolkit = toolkit or Toolkit()
        self.toolkit.register_tool_function(
            getattr(self, self.finish_function_name),
        )

        if self._agent_control:
            # Add two tool functions to the toolkit to allow the agent to
            # control the long-term memory itself
            self.toolkit.register_tool_function(
                long_term_memory.record_to_memory,
            )
            self.toolkit.register_tool_function(
                long_term_memory.retrieve_from_memory,
            )

        # Add a meta tool function to allow agent-controlled tool management
        if enable_meta_tool:
            self.toolkit.register_tool_function(
                self.toolkit.reset_equipped_tools,
            )

        self.parallel_tool_calls = parallel_tool_calls
        self.max_iters = max_iters

        # Variables to record the intermediate state:
        # the required structured output model, if provided
        self._required_structured_model: Type[BaseModel] | None = None

        # Register the state variables
        self.register_state("name")
        self.register_state("_sys_prompt")

        self.register_instance_hook(
            "pre_print",
            "finish_function_pre_print_hook",
            finish_function_pre_print_hook,
        )
    @property
    def sys_prompt(self) -> str:
        """The dynamic system prompt of the agent."""
        return self._sys_prompt
    @trace_reply
    async def reply(
        self,
        msg: Msg | list[Msg] | None = None,
        structured_model: Type[BaseModel] | None = None,
    ) -> Msg:
        """Generate a reply based on the current state and input arguments.

        Args:
            msg (`Msg | list[Msg] | None`, optional):
                The input message(s) to the agent.
            structured_model (`Type[BaseModel] | None`, optional):
                The required structured output model. If provided, the agent
                is expected to generate structured output in the `metadata`
                field of the output message.

        Returns:
            `Msg`:
                The output message generated by the agent.
        """
        await self.memory.add(msg)

        # Long-term memory retrieval
        if self._static_control:
            # Retrieve information from the long-term memory if available
            retrieved_info = await self.long_term_memory.retrieve(msg)
            if retrieved_info:
                await self.memory.add(
                    Msg(
                        name="long_term_memory",
                        content="<long_term_memory>The content below is "
                        "retrieved from long-term memory, which may be "
                        f"useful:\n{retrieved_info}</long_term_memory>",
                        role="user",
                    ),
                )

        # Record the structured output model if provided
        self._required_structured_model = structured_model
        if structured_model:
            self.toolkit.set_extended_model(
                self.finish_function_name,
                structured_model,
            )

        # The reasoning-acting loop
        reply_msg = None
        for _ in range(self.max_iters):
            msg_reasoning = await self._reasoning()

            futures = [
                self._acting(tool_call)
                for tool_call in msg_reasoning.get_content_blocks(
                    "tool_use",
                )
            ]

            # Execute the tool calls in parallel or sequentially
            if self.parallel_tool_calls:
                acting_responses = await asyncio.gather(*futures)
            else:
                acting_responses = [await _ for _ in futures]

            # Find the first non-None reply message from the acting
            for acting_msg in acting_responses:
                reply_msg = reply_msg or acting_msg

            if reply_msg:
                break

        # When the maximum number of iterations is reached
        if reply_msg is None:
            reply_msg = await self._summarizing()

        # Post-process the long-term memory
        if self._static_control:
            await self.long_term_memory.record(
                [
                    *([*msg] if isinstance(msg, list) else [msg]),
                    *await self.memory.get_memory(),
                    reply_msg,
                ],
            )

        await self.memory.add(reply_msg)
        return reply_msg
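    # Usage sketch (illustrative): requesting structured output through
    # `reply`. The Pydantic model is a hypothetical example.
    #
    #     class ChoiceModel(BaseModel):
    #         choice: Literal["apple", "banana"]
    #
    #     msg = await agent.reply(
    #         Msg("user", "Pick a fruit", "user"),
    #         structured_model=ChoiceModel,
    #     )
    #     print(msg.metadata)  # e.g. {"choice": "apple"}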
    async def _reasoning(
        self,
    ) -> Msg:
        """Perform the reasoning process."""
        prompt = await self.formatter.format(
            msgs=[
                Msg("system", self.sys_prompt, "system"),
                *await self.memory.get_memory(),
            ],
        )

        res = await self.model(
            prompt,
            tools=self.toolkit.get_json_schemas(),
        )

        # Handle the output from the model
        interrupted_by_user = False
        msg = None
        try:
            if self.model.stream:
                msg = Msg(self.name, [], "assistant")
                async for content_chunk in res:
                    msg.content = content_chunk.content
                    await self.print(msg, False)
                await self.print(msg, True)
            else:
                msg = Msg(self.name, list(res.content), "assistant")
                await self.print(msg, True)
            return msg

        except asyncio.CancelledError as e:
            interrupted_by_user = True
            raise e from None

        finally:
            if msg and not msg.has_content_blocks("tool_use"):
                # Turn a plain text response into a tool call of the finish
                # function
                msg.content = [
                    ToolUseBlock(
                        id=shortuuid.uuid(),
                        type="tool_use",
                        name=self.finish_function_name,
                        input={"response": msg.get_text_content()},
                    ),
                ]

            # None will be ignored by the memory
            await self.memory.add(msg)

            # Post-process for user interruption
            if interrupted_by_user and msg:
                # Fake tool results
                tool_use_blocks: list = msg.get_content_blocks(
                    "tool_use",
                )
                for tool_call in tool_use_blocks:
                    msg_res = Msg(
                        "system",
                        [
                            ToolResultBlock(
                                type="tool_result",
                                id=tool_call["id"],
                                name=tool_call["name"],
                                output="The tool call has been interrupted "
                                "by the user.",
                            ),
                        ],
                        "system",
                    )
                    await self.memory.add(msg_res)
                    await self.print(msg_res, True)

    async def _acting(self, tool_call: ToolUseBlock) -> Msg | None:
        """Perform the acting process.

        Args:
            tool_call (`ToolUseBlock`):
                The tool use block to be executed.

        Returns:
            `Msg | None`:
                Return a message to the user if the finish function is
                called, otherwise return `None`.
        """
        tool_res_msg = Msg(
            "system",
            [
                ToolResultBlock(
                    type="tool_result",
                    id=tool_call["id"],
                    name=tool_call["name"],
                    output=[],
                ),
            ],
            "system",
        )

        try:
            # Execute the tool call
            tool_res = await self.toolkit.call_tool_function(tool_call)

            response_msg = None
            # Async generator handling
            async for chunk in tool_res:
                # Turn the chunk into a tool result block
                tool_res_msg.content[0][  # type: ignore[index]
                    "output"
                ] = chunk.content

                # Skip the printing of a successful finish function call
                if tool_call["name"] != self.finish_function_name or (
                    tool_call["name"] == self.finish_function_name
                    and not chunk.metadata.get("success")
                ):
                    await self.print(tool_res_msg, chunk.is_last)

                # Record the reply message if the finish function is called
                # successfully
                if tool_call[
                    "name"
                ] == self.finish_function_name and chunk.metadata.get(
                    "success",
                    True,
                ):
                    response_msg = chunk.metadata.get("response_msg")

            return response_msg

        finally:
            # Record the tool result message in the memory
            await self.memory.add(tool_res_msg)
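    # Sketch (illustrative): a tool function whose `ToolResponse` is
    # consumed by `_acting` above. The function itself is hypothetical;
    # only the `ToolResponse` contract is taken from this module.
    #
    #     def get_weather(city: str) -> ToolResponse:
    #         """Get the current weather of a city.
    #
    #         Args:
    #             city (`str`):
    #                 The city of interest.
    #         """
    #         return ToolResponse(
    #             content=[TextBlock(type="text", text=f"{city}: sunny")],
    #         )
    #
    #     agent.toolkit.register_tool_function(get_weather)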
    async def observe(self, msg: Msg | list[Msg] | None) -> None:
        """Receive observing message(s) without generating a reply.

        Args:
            msg (`Msg | list[Msg] | None`):
                The message or messages to be observed.
        """
        await self.memory.add(msg)
    async def _summarizing(self) -> Msg:
        """Generate a response when the agent fails to solve the problem
        within the maximum number of iterations."""
        hint_msg = Msg(
            "user",
            "You have failed to generate a response within the maximum "
            "iterations. Now respond directly by summarizing the current "
            "situation.",
            role="user",
        )

        # Generate a reply by summarizing the current situation
        prompt = await self.formatter.format(
            [
                Msg("system", self.sys_prompt, "system"),
                *await self.memory.get_memory(),
                hint_msg,
            ],
        )

        # TODO: handle the structured output here, maybe by forcing a call
        #  to the finish function
        res = await self.model(prompt)

        res_msg = Msg(self.name, [], "assistant")
        if isinstance(res, AsyncGenerator):
            async for chunk in res:
                res_msg.content = chunk.content
                await self.print(res_msg, False)
            await self.print(res_msg, True)
        else:
            res_msg.content = res.content
            await self.print(res_msg, True)
        return res_msg
    async def handle_interrupt(
        self,
        _msg: Msg | list[Msg] | None = None,
    ) -> Msg:
        """The post-processing logic when the reply is interrupted by the
        user or something else."""
        response_msg = Msg(
            self.name,
            "I noticed that you have interrupted me. What can I "
            "do for you?",
            "assistant",
            metadata={},
        )
        await self.print(response_msg, True)
        await self.memory.add(response_msg)
        await self.memory.clear()
        return response_msg
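    # Sketch (illustrative): how an interrupt typically reaches this
    # handler. That the surrounding agent machinery routes a cancelled call
    # through `handle_interrupt` is an assumption here, inferred from the
    # `asyncio.CancelledError` handling in `_reasoning`.
    #
    #     task = asyncio.create_task(agent(Msg("user", "...", "user")))
    #     ...
    #     task.cancel()  # the interrupted call resolves via handle_interrupt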
    def generate_response(
        self,
        response: str,
        **kwargs: Any,
    ) -> ToolResponse:
        """Generate a response. Note that only the `response` argument is
        visible to others, so you should include all the necessary
        information there.

        Args:
            response (`str`):
                Your response to the user.
        """
        response_msg = Msg(
            self.name,
            response,
            "assistant",
        )

        # Prepare the structured output
        if self._required_structured_model:
            try:
                # Use the metadata field of the message to store the
                # structured output
                response_msg.metadata = (
                    self._required_structured_model.model_validate(
                        kwargs,
                    ).model_dump()
                )

            except ValidationError as e:
                return ToolResponse(
                    content=[
                        TextBlock(
                            type="text",
                            text=f"Arguments Validation Error: {e}",
                        ),
                    ],
                    metadata={
                        "success": False,
                        "response_msg": None,
                    },
                )

        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text="Successfully generated response.",
                ),
            ],
            metadata={
                "success": True,
                "response_msg": response_msg,
            },
            is_last=True,
        )
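# End-to-end sketch (illustrative): a minimal conversation turn, assuming
# `agent` was constructed as in the earlier sketch.
#
#     async def main() -> None:
#         await agent.observe(Msg("user", "My name is Alice.", "user"))
#         reply_msg = await agent.reply(
#             Msg("user", "What's my name?", "user"),
#         )
#         print(reply_msg.get_text_content())
#
#     asyncio.run(main())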