# -*- coding: utf-8 -*-
"""Google Gemini model wrapper."""
import os
from abc import ABC
from collections.abc import Iterable
from typing import Sequence, Union, Any, Optional, Generator
from loguru import logger
from ._model_usage import ChatUsage
from ..formatters import GeminiFormatter
from ..message import Msg
from ..models import ModelWrapperBase, ModelResponse
try:
import google.generativeai as genai
    # This package is installed together with google-generativeai
import google.ai.generativelanguage as glm
except ImportError:
genai = None
glm = None
class GeminiWrapperBase(ModelWrapperBase, ABC):
"""The base class for Google Gemini model wrapper."""
_generation_method = None
"""The generation method used in `__call__` function, which is used to
filter models in `list_models` function."""
def __init__(
self,
config_name: str,
model_name: str,
        api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize the wrapper for Google Gemini model.
Args:
model_name (`str`):
The name of the model.
            api_key (`str`, defaults to `None`):
                The api_key for the model. If not provided, it will be
                loaded from the `GOOGLE_API_KEY` environment variable.
"""
super().__init__(config_name=config_name, model_name=model_name)
# Test if the required package is installed
if genai is None:
raise ImportError(
"The google-generativeai package is not installed, "
"please install it first.",
)
# Load the api_key from argument or environment variable
api_key = api_key or os.environ.get("GOOGLE_API_KEY")
if api_key is None:
raise ValueError(
"Google api_key must be provided or set as an "
"environment variable.",
)
genai.configure(api_key=api_key, **kwargs)
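        # Note: `genai.configure` sets the api_key globally for the
        # google.generativeai package, so it is shared by every wrapper
        # in the current process.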
self.model_name = model_name
def list_models(self) -> Sequence:
"""List all available models for this API calling."""
        supported_models = list(genai.list_models())
        if self._generation_method is None:
            return supported_models
        else:
            return [
                model
                for model in supported_models
                if self._generation_method
                in model.supported_generation_methods
            ]
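    # A minimal sketch of the filtering above (illustrative only; the
    # config name is a placeholder):
    #
    #   wrapper = GeminiChatWrapper(config_name="cfg", model_name="gemini-pro")
    #   for model in wrapper.list_models():
    #       # Only models supporting "generateContent" are listed here
    #       print(model.name)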
def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse:
"""Processing input with the model."""
raise NotImplementedError(
f"Model Wrapper [{type(self).__name__}]"
f" is missing the the required `__call__`"
f" method.",
)
class GeminiChatWrapper(GeminiWrapperBase):
"""The wrapper for Google Gemini chat model, e.g. gemini-pro"""
model_type: str = "gemini_chat"
"""The type of the model, which is used in model configuration."""
    _generation_method = "generateContent"
    """The generation method used in the `__call__` method."""
def __init__(
self,
config_name: str,
model_name: str,
        api_key: Optional[str] = None,
stream: bool = False,
**kwargs: Any,
) -> None:
"""Initialize the wrapper for Google Gemini model.
Args:
model_name (`str`):
The name of the model.
            api_key (`str`, defaults to `None`):
                The api_key for the model. If not provided, it will be
                loaded from the `GOOGLE_API_KEY` environment variable.
stream (`bool`, defaults to `False`):
Whether to use stream mode.
"""
super().__init__(
config_name=config_name,
model_name=model_name,
api_key=api_key,
**kwargs,
)
self.stream = stream
# Create the generative model
self.model = genai.GenerativeModel(model_name, **kwargs)
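        # Note: `kwargs` has already been passed to `genai.configure` in
        # the base class and is forwarded to `genai.GenerativeModel`
        # again here, so any extra option must be valid for both calls.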
def __call__(
self,
contents: Union[Sequence, str],
stream: Optional[bool] = None,
**kwargs: Any,
) -> ModelResponse:
"""Generate response for the given contents.
Args:
contents (`Union[Sequence, str]`):
The content to generate response.
            stream (`Optional[bool]`, defaults to `None`):
Whether to use stream mode.
**kwargs:
The additional arguments for generating response.
Returns:
`ModelResponse`:
The response text in text field, and the raw response in raw
field.
"""
        # step1: check the input contents
        # Note: `str` is itself `Iterable`, so this warns only for inputs
        # that are neither strings nor iterables of messages.
        if not isinstance(contents, (str, Iterable)):
            logger.warning(
                "The input content is not a string or a list of "
                "messages, which may cause unexpected behavior.",
            )
# Check if stream is provided
if stream is None:
stream = self.stream
# step2: forward to generate response
kwargs.update(
{
"contents": contents,
"stream": stream,
},
)
response = self.model.generate_content(**kwargs)
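        # With `stream=True`, `generate_content` returns an iterable of
        # incremental response chunks instead of one complete response.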
if stream:
def generator() -> Generator[str, None, None]:
text = ""
last_chunk = None
for chunk in response:
text += self._extract_text_content_from_response(
contents,
chunk,
)
yield text
last_chunk = chunk
                # Write the accumulated text back into the last chunk so
                # the saved invocation record contains the full response
                if last_chunk is not None:
                    last_chunk.candidates[0].content.parts[0].text = text
                    self._save_model_invocation_and_update_monitor(
                        contents,
                        kwargs,
                        last_chunk,
                    )
return ModelResponse(
stream=generator(),
)
else:
self._save_model_invocation_and_update_monitor(
contents,
kwargs,
response,
)
            # step3: return response
return ModelResponse(
text=response.text,
raw=response,
)
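    # A minimal usage sketch (the config name, api_key, and prompt are
    # illustrative placeholders, not values defined in this module):
    #
    #   model = GeminiChatWrapper(
    #       config_name="my_gemini_chat",
    #       model_name="gemini-pro",
    #       api_key="...",
    #   )
    #   res = model("What is the capital of France?")
    #   print(res.text)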
def _save_model_invocation_and_update_monitor(
self,
contents: Union[Sequence[Any], str],
kwargs: dict,
response: Any,
) -> None:
"""Save the model invocation and update the monitor accordingly."""
# Update monitor accordingly
if hasattr(response, "usage_metadata"):
prompt_tokens = response.usage_metadata.prompt_token_count
completion_tokens = response.usage_metadata.candidates_token_count
else:
prompt_tokens = self.model.count_tokens(contents).total_tokens
completion_tokens = self.model.count_tokens(
response.text,
).total_tokens
formatted_usage = ChatUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
# Record the api invocation if needed
self._save_model_invocation(
arguments=kwargs,
response=str(response),
usage=formatted_usage,
)
self.monitor.update_text_and_embedding_tokens(
model_name=self.model_name,
**formatted_usage.usage.model_dump(),
)
def _extract_text_content_from_response(
self,
contents: Union[Sequence[Any], str],
response: Any,
) -> str:
"""Extract the text content from the response of gemini API
Note: to avoid import error during type checking in python 3.11+, here
we use `typing.Any` to avoid raising error
Args:
contents (`Union[Sequence[Any], str]`):
The prompt contents
response (`Any`):
The response from gemini API
Returns:
`str`: The extracted string.
"""
# Check for candidates and handle accordingly
if (
not response.candidates[0].content
or not response.candidates[0].content.parts
or not response.candidates[0].content.parts[0].text
):
# If we cannot get the response text from the model
finish_reason = response.candidates[0].finish_reason
reasons = glm.Candidate.FinishReason
if finish_reason == reasons.STOP:
error_info = (
"Natural stop point of the model or provided stop "
"sequence."
)
elif finish_reason == reasons.MAX_TOKENS:
error_info = (
"The maximum number of tokens as specified in the request "
"was reached."
)
elif finish_reason == reasons.SAFETY:
error_info = (
"The candidate content was flagged for safety reasons."
)
elif finish_reason == reasons.RECITATION:
error_info = (
"The candidate content was flagged for recitation reasons."
)
elif finish_reason in [
reasons.FINISH_REASON_UNSPECIFIED,
reasons.OTHER,
]:
error_info = "Unknown error."
else:
error_info = "No information provided from Gemini API."
raise ValueError(
"The Google Gemini API failed to generate text response with "
f"the following finish reason: {error_info}\n"
f"YOUR INPUT: {contents}\n"
f"RAW RESPONSE FROM GEMINI API: {response}\n",
)
return response.text
class GeminiEmbeddingWrapper(GeminiWrapperBase):
"""The wrapper for Google Gemini embedding model,
e.g. models/embedding-001
Response:
- Refer to https://ai.google.dev/api/embeddings?hl=zh-cn#response-body
.. code-block:: json
{
"embeddings": [
{
object (ContentEmbedding)
}
]
}
"""
model_type: str = "gemini_embedding"
"""The type of the model, which is used in model configuration."""
_generation_method = "embedContent"
"""The generation method used in `__call__` function."""
def __call__(
self,
content: Union[Sequence[Msg], str],
        task_type: Optional[str] = None,
        title: Optional[str] = None,
**kwargs: Any,
) -> ModelResponse:
"""Generate embedding for the given content. More detailed information
please refer to
https://ai.google.dev/tutorials/python_quickstart#use_embeddings
Args:
content (`Union[Sequence[Msg], str]`):
The content to generate embedding.
task_type (`str`, defaults to `None`):
The type of the task.
title (`str`, defaults to `None`):
The title of the content.
**kwargs:
The additional arguments for generating embedding.
Returns:
`ModelResponse`:
The response embedding in embedding field, and the raw response
in raw field.
"""
# step1: forward to generate response
response = genai.embed_content(
model=self.model_name,
content=content,
task_type=task_type,
title=title,
**kwargs,
)
# step2: record the api invocation if needed
self._save_model_invocation(
arguments={
"content": content,
"task_type": task_type,
"title": title,
**kwargs,
},
response=response,
)
        # TODO: As of 2024/07/26, the embedding model doesn't support
        # token counting.
# step3: update monitor accordingly
self.monitor.update_text_and_embedding_tokens(
model_name=self.model_name,
)
return ModelResponse(
raw=response,
embedding=response["embedding"],
)
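    # A minimal usage sketch (the config name, api_key, and content are
    # illustrative placeholders):
    #
    #   emb_model = GeminiEmbeddingWrapper(
    #       config_name="my_gemini_embedding",
    #       model_name="models/embedding-001",
    #       api_key="...",
    #   )
    #   res = emb_model("Hello, world!")
    #   print(res.embedding)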