Source code for agentscope.formatter._truncated_formatter_base

# -*- coding: utf-8 -*-
"""The truncated formatter base class, which allows to truncate the input
messages."""
from abc import ABC
from copy import deepcopy
from typing import (
    Any,
    Tuple,
    Literal,
    AsyncGenerator,
)

from ._formatter_base import FormatterBase
from ..message import Msg
from ..token import TokenCounterBase
from ..tracing import trace_format


[docs] class TruncatedFormatterBase(FormatterBase, ABC): """Base class for truncated formatters, which formats input messages into required formats with tokens under a specified limit."""
[docs] def __init__( self, token_counter: TokenCounterBase | None = None, max_tokens: int | None = None, ) -> None: """Initialize the TruncatedFormatterBase. Args: token_counter (`TokenCounterBase | None`, optional): A token counter instance used to count tokens in the messages. If not provided, the formatter will format the messages without considering token limits. max_tokens (`int | None`, optional): The maximum number of tokens allowed in the formatted messages. If not provided, the formatter will not truncate the messages. """ self.token_counter = token_counter assert ( max_tokens is None or 0 < max_tokens ), "max_tokens must be greater than 0" self.max_tokens = max_tokens
[docs] @trace_format async def format( self, msgs: list[Msg], **kwargs: Any, ) -> list[dict[str, Any]]: """Format the input messages into the required format. If token counter and max token limit are provided, the messages will be truncated to fit the limit. Args: msgs (`list[Msg]`): The input messages to be formatted. Returns: `list[dict[str, Any]]`: The formatted messages in the required format. """ # Check if the input messages are valid self.assert_list_of_msgs(msgs) msgs = deepcopy(msgs) while True: formatted_msgs = await self._format(msgs) n_tokens = await self._count(formatted_msgs) if ( n_tokens is None or self.max_tokens is None or n_tokens <= self.max_tokens ): return formatted_msgs # truncate the input messages msgs = await self._truncate(msgs)
[docs] async def _format(self, msgs: list[Msg]) -> list[dict[str, Any]]: """Format the input messages into the required format. This method should be implemented by the subclasses.""" formatted_msgs = [] start_index = 0 if len(msgs) > 0 and msgs[0].role == "system": formatted_msgs.append( await self._format_system_message(msgs[0]), ) start_index = 1 is_first_agent_message = True async for typ, group in self._group_messages(msgs[start_index:]): match typ: case "tool_sequence": formatted_msgs.extend( await self._format_tool_sequence(group), ) case "agent_message": formatted_msgs.extend( await self._format_agent_message( group, is_first_agent_message, ), ) is_first_agent_message = False return formatted_msgs
async def _format_system_message( self, msg: Msg, ) -> dict[str, Any]: """Format system message for the LLM API. .. note:: This is the default implementation. For certain LLM APIs with specific requirements, you may need to implement a custom formatting function to accommodate those particular needs. """ return { "role": "system", "content": msg.get_content_blocks("text"), }
[docs] async def _format_tool_sequence( self, msgs: list[Msg], ) -> list[dict[str, Any]]: """Given a sequence of tool call/result messages, format them into the required format for the LLM API.""" raise NotImplementedError( "_format_tool_sequence is not implemented", )
[docs] async def _format_agent_message( self, msgs: list[Msg], is_first: bool = True, ) -> list[dict[str, Any]]: """Given a sequence of messages without tool calls/results, format them into the required format for the LLM API.""" raise NotImplementedError( "_format_agent_message is not implemented", )
async def _truncate(self, msgs: list[Msg]) -> list[Msg]: """Truncate the input messages, so that it can fit the token limit. This function is called only when - both `token_counter` and `max_tokens` are provided, - the formatted output of the input messages exceeds the token limit. .. tip:: This function only provides a simple strategy, and developers can override this method to implement more sophisticated truncation strategies. .. note:: The tool call message should be truncated together with its corresponding tool result message to satisfy the LLM API requirements. Args: msgs (`list[Msg]`): The input messages to be truncated. Raises: `ValueError`: If the system prompt message already exceeds the token limit, or if there are tool calls without corresponding tool results. Returns: `list[Msg]`: The truncated messages. """ start_index = 0 if len(msgs) > 0 and msgs[0].role == "system": if len(msgs) == 1: # If the system prompt already exceeds the token limit, we # raise an error. raise ValueError( f"The system prompt message already exceeds the token " f"limit ({self.max_tokens} tokens).", ) start_index = 1 # Create a tool call IDs queues to delete the corresponding tool # result message tool_call_ids = set() for i in range(start_index, len(msgs)): msg = msgs[i] for block in msg.get_content_blocks("tool_use"): tool_call_ids.add(block["id"]) for block in msg.get_content_blocks("tool_result"): try: tool_call_ids.remove(block["id"]) except KeyError: pass # We can stop truncating if the queue is empty if len(tool_call_ids) == 0: return msgs[:start_index] + msgs[i + 1 :] if len(tool_call_ids) > 0: raise ValueError( "The input messages contains tool call(s) that do not have " f"the corresponding tool result(s): {tool_call_ids}. ", ) return msgs[:start_index] async def _count(self, msgs: list[dict[str, Any]]) -> int | None: """Count the number of tokens in the input messages. If token counter is not provided, `None` will be returned. Args: msgs (`list[Msg]`): The input messages to count tokens for. """ if self.token_counter is None: return None return await self.token_counter.count(msgs) @staticmethod async def _group_messages( msgs: list[Msg], ) -> AsyncGenerator[ Tuple[Literal["tool_sequence", "agent_message"], list[Msg]], None, ]: """Group the input messages into two types and yield them as a generator. The two types are: - agent message that doesn't contain tool calls/results, and - tool sequence that consisted of a sequence of tool calls/results .. note:: The group operation is used in multi-agent scenario, where multiple entities are involved in the input messages. So that to be compatible with tools API, we have to group the messages and format them with different strategies. Args: msgs (`list[Msg]`): The input messages to be grouped, where the system prompt message shouldn't be included. Yields: `AsyncGenerator[Tuple[str, list[Msg]], None]`: A generator that yields tuples of group type and the list of messages in that group. The group type can be either "tool_sequence" or "agent_message". """ group_type: Literal["tool_sequence", "agent_message"] | None = None group = [] for msg in msgs: if group_type is None: if msg.has_content_blocks( "tool_use", ) or msg.has_content_blocks("tool_result"): group_type = "tool_sequence" else: group_type = "agent_message" group.append(msg) continue # determine if this msg has the same type as the current group if group_type == "tool_sequence": if msg.has_content_blocks( "tool_use", ) or msg.has_content_blocks("tool_result"): group.append(msg) else: yield group_type, group group = [msg] group_type = "agent_message" elif group_type == "agent_message": if msg.has_content_blocks( "tool_use", ) or msg.has_content_blocks("tool_result"): yield group_type, group group = [msg] group_type = "tool_sequence" else: group.append(msg) if group_type: yield group_type, group