# -*- coding: utf-8 -*-
# mypy: disable-error-code="dict-item"
"""The Google Gemini model in agentscope."""
from datetime import datetime
from typing import AsyncGenerator, Any, TYPE_CHECKING, AsyncIterator, Literal
from ..message import ToolUseBlock, TextBlock, ThinkingBlock
from ._model_usage import ChatUsage
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ..tracing import trace_llm
from ..types import JSONSerializableObject
if TYPE_CHECKING:
from google.genai.types import GenerateContentResponse
else:
GenerateContentResponse = "google.genai.types.GenerateContentResponse"
class GeminiChatModel(ChatModelBase):
"""The Google Gemini chat model class in agentscope."""
def __init__(
self,
model_name: str,
api_key: str,
stream: bool = True,
thinking_config: dict | None = None,
        client_args: dict | None = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
) -> None:
"""Initialize the Gemini chat model.
Args:
model_name (`str`):
The name of the Gemini model to use, e.g. "gemini-2.5-flash".
api_key (`str`):
The API key for Google Gemini.
stream (`bool`, default `True`):
Whether to use streaming output or not.
thinking_config (`dict | None`, optional):
                The thinking config, supported by models such as Gemini
                2.5 Pro and 2.5 Flash.
Refer to https://ai.google.dev/gemini-api/docs/thinking for
more details.
.. code-block:: python
:caption: Example of thinking_config
{
"include_thoughts": True, # enable thoughts or not
"thinking_budget": 1024 # Max tokens for reasoning
}
            client_args (`dict | None`, default `None`):
                The extra keyword arguments used to initialize the Gemini
                client.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in Gemini API generation,
e.g. `temperature`, `seed`.
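
        Example (an illustrative sketch; the API key below is a
        placeholder, not a real credential):

        .. code-block:: python
            :caption: Creating a Gemini chat model

            model = GeminiChatModel(
                model_name="gemini-2.5-flash",
                api_key="YOUR_GOOGLE_API_KEY",  # placeholder
                stream=False,
                thinking_config={
                    "include_thoughts": True,
                    "thinking_budget": 1024,
                },
            )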
"""
try:
from google import genai
except ImportError as e:
raise ImportError(
"Please install gemini Python sdk with "
"`pip install -q -U google-genai`",
) from e
super().__init__(model_name, stream)
self.client = genai.Client(
api_key=api_key,
**(client_args or {}),
)
self.thinking_config = thinking_config
self.generate_kwargs = generate_kwargs or {}
@trace_llm
async def __call__(
self,
messages: list[dict],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "any", "required"]
| str
| None = None,
**config_kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Call the Gemini model with the provided arguments.
Args:
            messages (`list[dict]`):
                A list of dictionaries, where the `role` and `content`
                fields are required.
tools (`list[dict] | None`, default `None`):
The tools JSON schemas that the model can use.
tool_choice (`Literal["auto", "none", "any", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "any", "required", or specific tool
name. For more details, please refer to
https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
**config_kwargs (`Any`):
The keyword arguments for Gemini chat completions API.
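
        Example (a minimal sketch; `model` is an already-constructed
        `GeminiChatModel` instance):

        .. code-block:: python
            :caption: Calling the model

            messages = [{"role": "user", "content": "Hello!"}]

            # With stream=False, a single ChatResponse is returned
            response = await model(messages)

            # With stream=True (the default), an async generator of
            # cumulative ChatResponse chunks is returned instead
            async for chunk in await model(messages):
                print(chunk.content)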
"""
config: dict = {
"thinking_config": self.thinking_config,
**self.generate_kwargs,
**config_kwargs,
}
if tools:
config["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
self._validate_tool_choice(tool_choice, tools)
config["tool_config"] = self._format_tool_choice(tool_choice)
# Prepare the arguments for the Gemini API call
kwargs: dict[str, JSONSerializableObject] = {
"model": self.model_name,
"contents": messages,
"config": config,
}
start_datetime = datetime.now()
if self.stream:
response = await self.client.aio.models.generate_content_stream(
**kwargs,
)
return self._parse_gemini_stream_generation_response(
start_datetime,
response,
)
# non-streaming
response = await self.client.aio.models.generate_content(
**kwargs,
)
parsed_response = self._parse_gemini_generation_response(
used_time=(datetime.now() - start_datetime).total_seconds(),
response=response,
)
return parsed_response
async def _parse_gemini_stream_generation_response(
self,
start_datetime: datetime,
response: AsyncIterator[GenerateContentResponse],
) -> AsyncGenerator[ChatResponse, None]:
"""Parse the Gemini streaming generation response into ChatResponse"""
        text = ""
        thinking = ""
async for chunk in response:
content_block: list = []
# Thinking parts
if (
chunk.candidates
and chunk.candidates[0].content
and chunk.candidates[0].content.parts
):
for part in chunk.candidates[0].content.parts:
if part.thought and part.text:
thinking += part.text
# Text parts
if chunk.text:
text += chunk.text
# Function calls
tool_calls = []
if chunk.function_calls:
for function_call in chunk.function_calls:
tool_calls.append(
ToolUseBlock(
type="tool_use",
id=function_call.id,
name=function_call.name,
input=function_call.args or {},
),
)
usage = None
if chunk.usage_metadata:
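                # Derive output tokens (including any thinking tokens) as
                # the difference between total and prompt token counts.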
usage = ChatUsage(
input_tokens=chunk.usage_metadata.prompt_token_count,
output_tokens=chunk.usage_metadata.total_token_count
- chunk.usage_metadata.prompt_token_count,
time=(datetime.now() - start_datetime).total_seconds(),
)
if thinking:
content_block.append(
ThinkingBlock(
type="thinking",
thinking=thinking,
),
)
if text:
content_block.append(
TextBlock(
type="text",
text=text,
),
)
            content_block.extend(tool_calls)
            yield ChatResponse(
                content=content_block,
                usage=usage,
            )
def _parse_gemini_generation_response(
self,
used_time: float,
response: GenerateContentResponse,
) -> ChatResponse:
"""Parse the Gemini generation response into ChatResponse"""
content: list = []
if (
response.candidates
and response.candidates[0].content
and response.candidates[0].content.parts
):
for part in response.candidates[0].content.parts:
if part.thought and part.text:
content.append(
ThinkingBlock(
type="thinking",
thinking=part.text,
),
)
if response.text:
content.append(
TextBlock(
type="text",
text=response.text,
),
)
if response.function_calls:
for tool_call in response.function_calls:
content.append(
ToolUseBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.name,
input=tool_call.args or {},
),
)
if response.usage_metadata:
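            # Output tokens (including any thinking tokens) are derived as
            # total minus prompt token counts.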
usage = ChatUsage(
input_tokens=response.usage_metadata.prompt_token_count,
output_tokens=response.usage_metadata.total_token_count
- response.usage_metadata.prompt_token_count,
time=used_time,
)
else:
usage = None
return ChatResponse(
content=content,
usage=usage,
)
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the tools JSON schema into required format for Gemini API.
Args:
schemas (`dict[str, Any]`):
The tools JSON schemas.
Returns:
            `list[dict[str, Any]]`:
A list containing a dictionary with the
"function_declarations" key, which maps to a list of
function definitions.
Example:
.. code-block:: python
:caption: Example tool schemas of Gemini API
# Input JSON schema
schemas = [
{
'type': 'function',
'function': {
'name': 'execute_shell_command',
'description': 'xxx',
'parameters': {
'type': 'object',
'properties': {
'command': {
'type': 'string',
'description': 'xxx.'
},
'timeout': {
'type': 'integer',
'default': 300
}
},
'required': ['command']
}
}
}
]
# Output format (Gemini API expected):
[
{
'function_declarations': [
{
'name': 'execute_shell_command',
'description': 'xxx.',
'parameters': {
'type': 'object',
'properties': {
'command': {
'type': 'string',
'description': 'xxx.'
},
'timeout': {
'type': 'integer',
'default': 300
}
},
'required': ['command']
}
}
]
}
]
"""
        return [
            {
                "function_declarations": [
                    schema["function"]
                    for schema in schemas
                    if "function" in schema
                ],
            },
        ]
def _format_tool_choice(
self,
tool_choice: Literal["auto", "none", "any", "required"] | str | None,
) -> dict | None:
"""Format tool_choice parameter for API compatibility.
Args:
            tool_choice (`Literal["auto", "none", "any", "required"] | str \
                | None`, default `None`):
                Controls which (if any) tool is called by the model. Can be
                "auto", "none", "any", "required", or a specific tool name.
For more details, please refer to
https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
Returns:
`dict | None`:
The formatted tool choice configuration dict, or None if
tool_choice is None.
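
        Example (illustrative input/output mappings, mirroring the logic
        below):

        .. code-block:: python
            :caption: Example tool_choice formatting

            self._format_tool_choice("auto")
            # -> {"function_calling_config": {"mode": "AUTO"}}

            self._format_tool_choice("required")
            # -> {"function_calling_config": {"mode": "ANY"}}

            self._format_tool_choice("my_tool")  # a specific tool name
            # -> {
            #     "function_calling_config": {
            #         "mode": "ANY",
            #         "allowed_function_names": ["my_tool"],
            #     },
            # }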
"""
if tool_choice is None:
return None
mode_mapping = {
"auto": "AUTO",
"none": "NONE",
"any": "ANY",
"required": "ANY",
}
mode = mode_mapping.get(tool_choice)
if mode:
return {"function_calling_config": {"mode": mode}}
return {
"function_calling_config": {
"mode": "ANY",
"allowed_function_names": [tool_choice],
},
}