Source code for agentscope.formatter._ollama_formatter
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The Ollama formatter module."""
import base64
import os
from typing import Any
from urllib.parse import urlparse

from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from .._utils._common import _get_bytes_from_web_url
from ..message import (
    Msg,
    TextBlock,
    ImageBlock,
    ToolUseBlock,
    ToolResultBlock,
)
from ..token import TokenCounterBase


def _convert_ollama_image_url_to_base64_data(url: str) -> str:
    """Convert an image URL or a local file path to base64-encoded data."""
    parsed_url = urlparse(url)
    if not os.path.exists(url) and parsed_url.scheme != "":
        # Web URL
        data = _get_bytes_from_web_url(url)
        return data

    if os.path.exists(url):
        # Local file
        with open(url, "rb") as f:
            data = base64.b64encode(f.read()).decode("utf-8")
        return data

    raise ValueError(
        f"The URL `{url}` is not a valid image URL or local file.",
    )
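

# Illustrative behaviour of the helper above (a sketch; the paths and return
# values are hypothetical):
#
#   _convert_ollama_image_url_to_base64_data("/tmp/cat.png")
#   -> base64-encoded file contents, e.g. "iVBORw0KGgo..."
#   _convert_ollama_image_url_to_base64_data("https://example.com/cat.png")
#   -> base64 data fetched with _get_bytes_from_web_url
#   _convert_ollama_image_url_to_base64_data("no-such-file")
#   -> raises ValueError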


class OllamaChatFormatter(TruncatedFormatterBase):
    """Formatter for Ollama messages."""

    support_tools_api: bool = True
    """Whether the tools API is supported"""

    support_multiagent: bool = False
    """Whether multi-agent conversations are supported"""

    support_vision: bool = True
    """Whether vision data is supported"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into the Ollama API format.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        messages: list[dict] = []
        for msg in msgs:
            content_blocks: list = []
            tool_calls = []
            images = []
            for block in msg.get_content_blocks():
                typ = block.get("type")
                if typ == "text":
                    content_blocks.append({**block})

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": block.get("input", {}),
                            },
                        },
                    )

                elif typ == "tool_result":
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": self.convert_tool_result_to_string(
                                block.get("output"),  # type: ignore[arg-type]
                            ),
                            "name": block.get("name"),
                        },
                    )

                elif typ == "image":
                    source_type = block["source"]["type"]
                    if source_type == "url":
                        images.append(
                            _convert_ollama_image_url_to_base64_data(
                                block["source"]["url"],
                            ),
                        )
                    elif source_type == "base64":
                        images.append(block["source"]["data"])

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            content_msg = "\n".join(
                content.get("text", "") for content in content_blocks
            )

            msg_ollama = {
                "role": msg.role,
                "content": content_msg or None,
            }

            if tool_calls:
                msg_ollama["tool_calls"] = tool_calls

            if images:
                msg_ollama["images"] = images

            # Keep the message if it carries any payload; without the
            # "images" check, image-only messages would be dropped.
            if (
                msg_ollama["content"]
                or msg_ollama.get("tool_calls")
                or msg_ollama.get("images")
            ):
                messages.append(msg_ollama)

        return messages
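

# Example of the mapping performed by OllamaChatFormatter (a sketch; the
# `Msg(name, content, role)` constructor is assumed from its usage above):
#
#   msgs = [Msg("user", "Hi!", "user")]
#   await OllamaChatFormatter().format(msgs)
#   -> [{"role": "user", "content": "Hi!"}]
#
# Tool results become standalone {"role": "tool", ...} messages, and image
# blocks are collected into the message-level "images" list as base64 data.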


class OllamaMultiAgentFormatter(TruncatedFormatterBase):
    """Ollama formatter for multi-agent conversations, i.e. conversations
    that involve more participants than a single user and a single agent.
    """

    support_tools_api: bool = True
    """Whether the tools API is supported"""

    support_multiagent: bool = True
    """Whether multi-agent conversations are supported"""

    support_vision: bool = True
    """Whether vision data is supported"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Ollama multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to use for the conversation history section.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter used for truncation.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If `None`, no truncation will be applied.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Format system message for the Ollama API."""
        return {
            "role": "system",
            "content": msg.get_text_content(),
        }

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the Ollama API.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the Ollama API.
        """
        return await OllamaChatFormatter().format(msgs)
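
    # For example, a tool_use block and its tool_result are rendered by the
    # delegate formatter as an assistant message carrying "tool_calls" plus
    # a separate {"role": "tool"} message (illustrative values):
    #
    #   [
    #       {"role": "assistant", "content": None,
    #        "tool_calls": [{"id": "1", "type": "function",
    #                        "function": {"name": "get_time",
    #                                     "arguments": {}}}]},
    #       {"role": "tool", "tool_call_id": "1", "content": "12:00",
    #        "name": "get_time"},
    #   ]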

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, format
        them into the required format for the Ollama API.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt will be included.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the Ollama API.
        """
        if is_first:
            conversation_history_prompt = self.conversation_history_prompt
        else:
            conversation_history_prompt = ""

        # Format into the required Ollama format
        formatted_msgs: list[dict] = []

        # Collect the multimodal files
        conversation_blocks: list = []
        accumulated_text = []
        images = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    accumulated_text.append(f"{msg.name}: {block['text']}")

                elif block["type"] == "image":
                    # Flush the accumulated text as a single block
                    source = block["source"]
                    if accumulated_text:
                        conversation_blocks.append(
                            {"text": "\n".join(accumulated_text)},
                        )
                        accumulated_text.clear()

                    if source["type"] == "url":
                        images.append(
                            _convert_ollama_image_url_to_base64_data(
                                source["url"],
                            ),
                        )
                    elif source["type"] == "base64":
                        images.append(source["data"])

                    conversation_blocks.append({**block})

        if accumulated_text:
            conversation_blocks.append(
                {"text": "\n".join(accumulated_text)},
            )

        if conversation_blocks:
            if conversation_blocks[0].get("text"):
                conversation_blocks[0]["text"] = (
                    conversation_history_prompt
                    + "<history>\n"
                    + conversation_blocks[0]["text"]
                )
            else:
                conversation_blocks.insert(
                    0,
                    {
                        "text": conversation_history_prompt + "<history>\n",
                    },
                )

            if conversation_blocks[-1].get("text"):
                conversation_blocks[-1]["text"] += "\n</history>"
            else:
                conversation_blocks.append({"text": "</history>"})

        conversation_blocks_text = "\n".join(
            conversation_block.get("text", "")
            for conversation_block in conversation_blocks
        )

        user_message = {
            "role": "user",
            "content": conversation_blocks_text,
        }
        if images:
            user_message["images"] = images

        if conversation_blocks:
            formatted_msgs.append(user_message)

        return formatted_msgs
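

# A minimal end-to-end sketch (not part of the original module). It assumes
# the `Msg(name, content, role)` constructor and the public async `format()`
# entry point inherited from TruncatedFormatterBase. Because of the relative
# imports above, run it as a module, e.g.
# `python -m agentscope.formatter._ollama_formatter`.
if __name__ == "__main__":
    import asyncio
    import json

    async def _demo() -> None:
        # A three-party conversation: a system prompt plus two named
        # speakers, which is the multi-agent case this formatter targets.
        msgs = [
            Msg("system", "You are a helpful assistant.", "system"),
            Msg("Alice", "What is the capital of France?", "user"),
            Msg("Bob", "The capital of France is Paris.", "assistant"),
        ]
        formatted = await OllamaMultiAgentFormatter().format(msgs)
        print(json.dumps(formatted, indent=2))

    asyncio.run(_demo())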