# Source code for agentscope.formatter._gemini_formatter
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""Google gemini API formatter in agentscope."""
import base64
import os
from typing import Any
from urllib.parse import urlparse
from ._truncated_formatter_base import TruncatedFormatterBase
from .._utils._common import _get_bytes_from_web_url
from ..message import (
Msg,
TextBlock,
ImageBlock,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
VideoBlock,
)
from .._logging import logger
from ..token import TokenCounterBase
def _to_gemini_inline_data(url: str) -> dict:
    """Convert a media URL or local file path into Gemini ``inline_data``.

    The media category (``image`` / ``video`` / ``audio``) is looked up from
    the file extension in ``GeminiChatFormatter.supported_extensions`` and
    joined with the extension to build the MIME type, e.g. ``video/mp4``.
    The ``jpg`` extension is treated as ``jpeg``.

    Args:
        url (`str`):
            Either the path of an existing local file, or a string with a
            URL scheme that does not exist as a local path (treated as a
            web URL).

    Returns:
        `dict`:
            A dict with the keys ``data`` and ``mime_type``. For a local
            file, ``data`` is the base64-encoded file content as a ``str``;
            for a web URL it is the value returned by
            ``_get_bytes_from_web_url``.

    Raises:
        TypeError:
            If a web URL has no extension listed in
            ``GeminiChatFormatter.supported_extensions``.
        ValueError:
            If ``url`` is neither an existing local file nor carries a URL
            scheme.
    """
    supported = GeminiChatFormatter.supported_extensions

    def _classify(name: str) -> tuple[str, str | None]:
        """Return the normalised extension of ``name`` and its category."""
        ext = name.split(".")[-1].lower()
        # ``.jpg`` files use the ``jpeg`` MIME subtype.
        ext = {"jpg": "jpeg"}.get(ext, ext)
        for category, extensions in supported.items():
            if ext in extensions:
                return ext, category
        return ext, None

    if os.path.exists(url):
        # Local file
        extension, typ = _classify(url)
        with open(url, "rb") as f:
            data = base64.b64encode(f.read()).decode("utf-8")
        # Use the real media category so that local audio/video files are
        # not mislabelled as images. Extensions that are not recognised keep
        # the previous ``image/<extension>`` fallback.
        return {
            "data": data,
            "mime_type": f"{typ or 'image'}/{extension}",
        }

    parsed_url = urlparse(url)
    if parsed_url.scheme == "":
        raise ValueError(
            f"The URL `{url}` is not a valid image URL or local file.",
        )

    # Web url: read the extension from the path component first, so that a
    # query string or fragment (``a.png?token=1``) does not pollute it.
    extension, typ = _classify(parsed_url.path)
    if typ is None:
        # Fall back to the raw URL so that an extension that only shows up
        # at the very end of the URL still resolves.
        extension, typ = _classify(url)
    if typ is None:
        raise TypeError(
            f"Unsupported file extension: {extension}, expected "
            f"{GeminiChatFormatter.supported_extensions}",
        )

    data = _get_bytes_from_web_url(url)
    return {
        "data": data,
        "mime_type": f"{typ}/{extension}",
    }
class GeminiChatFormatter(TruncatedFormatterBase):
    """The formatter for Google Gemini API.

    Each input ``Msg`` is converted into one dict with a ``role`` (``model``
    for assistant messages, ``user`` otherwise) and a list of ``parts``.
    Tool results are emitted as separate ``user`` messages carrying a
    ``function_response`` part.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        VideoBlock,
        AudioBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    supported_extensions: dict[str, list[str]] = {
        "image": ["png", "jpeg", "webp", "heic", "heif"],
        "video": [
            "mp4",
            "mpeg",
            "mov",
            "avi",
            "x-flv",
            "mpg",
            "webm",
            "wmv",
            "3gpp",
        ],
        "audio": ["mp3", "wav", "aiff", "aac", "ogg", "flac"],
    }
    """The supported file extensions, keyed by media category"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict]:
        """Format message objects into Gemini API required format.

        Args:
            msgs (`list[Msg]`):
                The message objects to format.

        Returns:
            `list[dict]`:
                A list of dicts, each with the keys ``role`` and ``parts``.
                Messages that end up with no parts are left out.
        """
        self.assert_list_of_msgs(msgs)

        messages: list = []
        for msg in msgs:
            parts = []

            for block in msg.get_content_blocks():
                typ = block.get("type")

                if typ == "text":
                    parts.append(
                        {
                            "text": block.get("text"),
                        },
                    )

                elif typ == "tool_use":
                    parts.append(
                        {
                            "function_call": {
                                "id": block["id"],
                                "name": block["name"],
                                "args": block["input"],
                            },
                        },
                    )

                elif typ == "tool_result":
                    text_output = self.convert_tool_result_to_string(
                        block["output"],  # type: ignore[arg-type]
                    )
                    # A tool result becomes its own user-role message. It is
                    # appended right away, i.e. ahead of the other parts
                    # collected from the same ``Msg``.
                    messages.append(
                        {
                            "role": "user",
                            "parts": [
                                {
                                    "function_response": {
                                        "id": block["id"],
                                        "name": block["name"],
                                        "response": {
                                            "output": text_output,
                                        },
                                    },
                                },
                            ],
                        },
                    )

                elif typ in ["image", "audio", "video"]:
                    source = block["source"]
                    source_type = source["type"]

                    if source_type == "base64":
                        media_type = source["media_type"]
                        base64_data = source["data"]
                        parts.append(
                            {
                                "inline_data": {
                                    "data": base64_data,
                                    "mime_type": media_type,
                                },
                            },
                        )

                    elif source_type == "url":
                        parts.append(
                            {
                                "inline_data": _to_gemini_inline_data(
                                    source["url"],
                                ),
                            },
                        )

                    else:
                        # Make the drop visible instead of silently losing
                        # the media block.
                        logger.warning(
                            "Unsupported source type: %s in the %s block, "
                            "skipped. ",
                            source_type,
                            typ,
                        )

                else:
                    logger.warning(
                        "Unsupported block type: %s in the message, skipped. ",
                        typ,
                    )

            role = "model" if msg.role == "assistant" else "user"

            if parts:
                messages.append(
                    {
                        "role": role,
                        "parts": parts,
                    },
                )

        return messages
class GeminiMultiAgentFormatter(TruncatedFormatterBase):
    """The multi-agent formatter for Google Gemini API, for conversations
    that involve more than one user and one agent.

    .. note:: Messages without tool calls/results are merged into a single
     user-role message whose content is wrapped in ``<history></history>``
     tags and, for the first such group, prefixed with the conversation
     history prompt.

    .. note:: Tool calls/results are presented as separate messages as
     required by the Gemini API. Therefore, the tool calls/results messages
     are expected to be placed at the end of the input messages.

    .. tip:: Telling the assistant's name in the system prompt is very
     important in multi-agent conversations, so that the LLM knows who it
     is playing as.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        VideoBlock,
        AudioBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Gemini multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt placed in front of the conversation history
                section.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter, forwarded to the base class.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages, forwarded to the base class.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Format the system message as a user-role Gemini message that
        holds a single text part."""
        text_part = {"text": msg.get_text_content()}
        return {
            "role": "user",
            "parts": [text_part],
        }

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a sequence of tool call/result messages for the Gemini
        API by delegating to a fresh ``GeminiChatFormatter``.

        Args:
            msgs (`list[Msg]`):
                The messages containing tool calls/results.

        Returns:
            `list[dict[str, Any]]`:
                The messages formatted for the Gemini API.
        """
        chat_formatter = GeminiChatFormatter()
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge messages without tool calls/results into one user-role
        message wrapped in ``<history></history>`` tags.

        Args:
            msgs (`list[Msg]`):
                The ``Msg`` objects to merge.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message group. Only then is
                the conversation history prompt put in front of the
                opening ``<history>`` tag.

        Returns:
            `list[dict[str, Any]]`:
                A list holding the single merged message, or an empty list
                when no text or media part was collected.
        """
        history_prompt = self.conversation_history_prompt if is_first else ""

        # Parts of the merged message, in order of appearance
        parts: list = []
        # Consecutive "name: text" lines that are not yet written to `parts`
        pending_lines: list = []

        def _flush_text() -> None:
            """Move the pending text lines into one text part."""
            if pending_lines:
                parts.append({"text": "\n".join(pending_lines)})
                pending_lines.clear()

        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    pending_lines.append(f"{msg.name}: {block['text']}")
                    continue

                if block["type"] not in ["image", "video", "audio"]:
                    continue

                # Close the running text part so the media keeps its position
                _flush_text()

                source = block["source"]
                if source["type"] == "url":
                    inline_data = _to_gemini_inline_data(source["url"])
                    parts.append({"inline_data": inline_data})
                elif source["type"] == "base64":
                    mime_type = source["media_type"]
                    encoded = source["data"]
                    parts.append(
                        {
                            "inline_data": {
                                "data": encoded,
                                "mime_type": mime_type,
                            },
                        },
                    )

        _flush_text()

        if not parts:
            return []

        # Opening tag (and prompt): merged into a leading text part, or
        # inserted as a new one
        opening = history_prompt + "<history>"
        if parts[0].get("text"):
            parts[0]["text"] = opening + parts[0]["text"]
        else:
            parts.insert(0, {"text": opening})

        # Closing tag: merged into a trailing text part, or appended as a
        # new one
        if parts[-1].get("text"):
            parts[-1]["text"] += "\n</history>"
        else:
            parts.append({"text": "</history>"})

        return [
            {
                "role": "user",
                "parts": parts,
            },
        ]