# -*- coding: utf-8 -*-
"""Google Gemini model wrapper."""
import os
from abc import ABC
from collections.abc import Iterable
from typing import Sequence, Union, Any, List, Optional, Generator
from loguru import logger
from ..message import Msg
from ..models import ModelWrapperBase, ModelResponse
from ..utils.common import _convert_to_str
try:
import google.generativeai as genai
# This package will be installed when the google-generativeai is installed
import google.ai.generativelanguage as glm
except ImportError:
genai = None
glm = None
class GeminiWrapperBase(ModelWrapperBase, ABC):
    """The base class for Google Gemini model wrapper."""

    _generation_method = None
    """The generation method used in `__call__` function, which is used to
    filter models in `list_models` function."""

    def __init__(
        self,
        config_name: str,
        model_name: str,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the wrapper for Google Gemini model.

        Args:
            config_name (`str`):
                The configuration name, passed on to `ModelWrapperBase`.
            model_name (`str`):
                The name of the model.
            api_key (`Optional[str]`, defaults to `None`):
                The api_key for the model. If it is not provided, it will be
                loaded from the `GOOGLE_API_KEY` environment variable.
            **kwargs (`Any`):
                Extra keyword arguments forwarded to `genai.configure`.

        Raises:
            `ImportError`: If the google-generativeai package is missing.
            `ValueError`: If no (non-empty) api_key can be found.
        """
        super().__init__(config_name=config_name, model_name=model_name)

        # Test if the required package is installed
        if genai is None:
            raise ImportError(
                "The google-generativeai package is not installed, "
                "please install it first.",
            )

        # Load the api_key from argument or environment variable. An empty
        # string is as unusable as a missing key, so reject both here rather
        # than failing later inside the SDK.
        api_key = api_key or os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError(
                "Google api_key must be provided or set as an "
                "environment variable.",
            )

        genai.configure(api_key=api_key, **kwargs)

        self.model_name = model_name

    def list_models(self) -> Sequence:
        """List all available models for this API calling.

        Returns:
            `Sequence`:
                All models returned by `genai.list_models()`, filtered to
                those supporting this wrapper's generation method when one
                is declared.
        """
        support_models = list(genai.list_models())

        # Subclasses declare their method as `_generation_method`, but some
        # only set the public `generation_method`; honour either so the
        # filter below never tests against `None` or a missing attribute.
        method = self._generation_method or getattr(
            self,
            "generation_method",
            None,
        )

        if method is None:
            return support_models

        return [
            model
            for model in support_models
            if method in model.supported_generation_methods
        ]

    def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse:
        """Processing input with the model.

        Raises:
            `NotImplementedError`: Always; concrete subclasses must override.
        """
        raise NotImplementedError(
            f"Model Wrapper [{type(self).__name__}]"
            " is missing the required `__call__`"
            " method.",
        )
class GeminiChatWrapper(GeminiWrapperBase):
    """The wrapper for Google Gemini chat model, e.g. gemini-pro"""

    model_type: str = "gemini_chat"
    """The type of the model, which is used in model configuration."""

    _generation_method = "generateContent"
    """The generation method used in `__call__` function, which is used to
    filter models in `list_models` function."""

    # Backward-compatible public alias. Previously only `generation_method`
    # was set here, leaving the inherited `_generation_method` as `None`, so
    # the filter in `list_models` matched no model at all.
    generation_method = _generation_method

    def __init__(
        self,
        config_name: str,
        model_name: str,
        api_key: Optional[str] = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize the wrapper for Google Gemini chat model.

        Args:
            config_name (`str`):
                The configuration name, passed on to the base wrapper.
            model_name (`str`):
                The name of the model.
            api_key (`Optional[str]`, defaults to `None`):
                The api_key for the model. If it is not provided, it will be
                loaded from the `GOOGLE_API_KEY` environment variable.
            stream (`bool`, defaults to `False`):
                Whether to use stream mode when `__call__` is not told
                otherwise.
            **kwargs (`Any`):
                Extra keyword arguments, forwarded both to the base class
                (and from there to `genai.configure`) and to
                `genai.GenerativeModel`.
        """
        super().__init__(
            config_name=config_name,
            model_name=model_name,
            api_key=api_key,
            **kwargs,
        )

        self.stream = stream

        # Create the generative model.
        # NOTE(review): the same kwargs were already given to
        # `genai.configure` by the base class; confirm that both accept
        # every key callers pass here.
        self.model = genai.GenerativeModel(model_name, **kwargs)

    def __call__(
        self,
        contents: Union[Sequence, str],
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ModelResponse:
        """Generate response for the given contents.

        Args:
            contents (`Union[Sequence, str]`):
                The content to generate response.
            stream (`Optional[bool]`, defaults to `None`):
                Whether to use stream mode. `None` falls back to the mode
                given at construction time.
            **kwargs:
                The additional arguments for generating response.

        Returns:
            `ModelResponse`:
                In non-stream mode, the response text in text field and the
                raw response in raw field. In stream mode, a generator that
                yields the accumulated text in stream field.

        Raises:
            `ValueError`: If the API returns no usable text (raised lazily
            while iterating in stream mode).
        """
        # step1: checking messages. `str` is itself Iterable, so one check
        # covers both accepted input kinds.
        if not isinstance(contents, (str, Iterable)):
            logger.warning(
                "The input content is not a string or a list of "
                "messages, which may cause unexpected behavior.",
            )

        # Fall back to the stream mode chosen at construction time
        if stream is None:
            stream = self.stream

        # step2: forward to generate response
        kwargs.update(
            {
                "contents": contents,
                "stream": stream,
            },
        )

        response = self.model.generate_content(**kwargs)

        if stream:

            def generator() -> Generator[str, None, None]:
                text = ""
                last_chunk = None
                for chunk in response:
                    text += self._extract_text_content_from_response(
                        contents,
                        chunk,
                    )
                    yield text
                    last_chunk = chunk

                if last_chunk is None:
                    # Nothing was streamed: there is no chunk to record and
                    # `last_chunk.candidates` below would raise.
                    logger.warning(
                        "The Google Gemini API returned an empty stream, "
                        "skip recording the model invocation.",
                    )
                    return

                # Store the accumulated text in the last chunk so the saved
                # invocation record contains the full response text.
                last_chunk.candidates[0].content.parts[0].text = text

                self._save_model_invocation_and_update_monitor(
                    contents,
                    kwargs,
                    last_chunk,
                )

            return ModelResponse(
                stream=generator(),
            )

        self._save_model_invocation_and_update_monitor(
            contents,
            kwargs,
            response,
        )

        # step3: return response. Use the same extraction helper as the
        # stream path, which reports the finish reason when no text exists.
        return ModelResponse(
            text=self._extract_text_content_from_response(
                contents,
                response,
            ),
            raw=response,
        )

    def _save_model_invocation_and_update_monitor(
        self,
        contents: Union[Sequence[Any], str],
        kwargs: dict,
        response: Any,
    ) -> None:
        """Save the model invocation and update the monitor accordingly.

        Args:
            contents (`Union[Sequence[Any], str]`):
                The prompt contents; used to count prompt tokens when the
                response carries no `usage_metadata`.
            kwargs (`dict`):
                The arguments passed to `generate_content`, recorded as the
                invocation arguments.
            response (`Any`):
                The response (or the last streamed chunk) from the Gemini
                API.
        """
        # Record the api invocation if needed
        self._save_model_invocation(
            arguments=kwargs,
            response=str(response),
        )

        # Prefer the token usage reported by the API; otherwise count the
        # tokens of the prompt and of the response text with the model.
        if hasattr(response, "usage_metadata"):
            token_prompt = response.usage_metadata.prompt_token_count
            token_response = response.usage_metadata.candidates_token_count
        else:
            token_prompt = self.model.count_tokens(contents).total_tokens
            token_response = self.model.count_tokens(
                response.text,
            ).total_tokens

        self.monitor.update_text_and_embedding_tokens(
            model_name=self.model_name,
            prompt_tokens=token_prompt,
            completion_tokens=token_response,
        )

    def _extract_text_content_from_response(
        self,
        contents: Union[Sequence[Any], str],
        response: Any,
    ) -> str:
        """Extract the text content from the response of gemini API

        Note: to avoid import error during type checking in python 3.11+, here
        we use `typing.Any` to avoid raising error

        Args:
            contents (`Union[Sequence[Any], str]`):
                The prompt contents, included in the error message.
            response (`Any`):
                The response (or a streamed chunk) from gemini API

        Returns:
            `str`: The extracted string.

        Raises:
            `ValueError`: If the response has no candidate, or its first
            candidate carries no text.
        """
        if not response.candidates:
            # Without a candidate there is no finish reason to report, and
            # indexing `candidates[0]` would raise an opaque IndexError.
            raise ValueError(
                "The Google Gemini API returned no candidates.\n"
                f"YOUR INPUT: {contents}\n"
                f"RAW RESPONSE FROM GEMINI API: {response}\n",
            )

        candidate = response.candidates[0]

        # Check for candidates and handle accordingly
        if (
            not candidate.content
            or not candidate.content.parts
            or not candidate.content.parts[0].text
        ):
            # If we cannot get the response text from the model, translate
            # the finish reason into a readable explanation.
            finish_reason = candidate.finish_reason
            reasons = glm.Candidate.FinishReason

            if finish_reason == reasons.STOP:
                error_info = (
                    "Natural stop point of the model or provided stop "
                    "sequence."
                )
            elif finish_reason == reasons.MAX_TOKENS:
                error_info = (
                    "The maximum number of tokens as specified in the request "
                    "was reached."
                )
            elif finish_reason == reasons.SAFETY:
                error_info = (
                    "The candidate content was flagged for safety reasons."
                )
            elif finish_reason == reasons.RECITATION:
                error_info = (
                    "The candidate content was flagged for recitation reasons."
                )
            elif finish_reason in [
                reasons.FINISH_REASON_UNSPECIFIED,
                reasons.OTHER,
            ]:
                error_info = "Unknown error."
            else:
                error_info = "No information provided from Gemini API."

            raise ValueError(
                "The Google Gemini API failed to generate text response with "
                f"the following finish reason: {error_info}\n"
                f"YOUR INPUT: {contents}\n"
                f"RAW RESPONSE FROM GEMINI API: {response}\n",
            )

        return response.text
class GeminiEmbeddingWrapper(GeminiWrapperBase):
    """The wrapper for Google Gemini embedding model,
    e.g. models/embedding-001

    Response:
        - Refer to https://ai.google.dev/api/embeddings#response-body
        - `genai.embed_content` returns a dict. This wrapper exposes its
          ``"embedding"`` entry as `ModelResponse.embedding` and the whole
          dict as `ModelResponse.raw`.
    """

    model_type: str = "gemini_embedding"
    """The type of the model, which is used in model configuration."""

    _generation_method = "embedContent"
    """The generation method used in `__call__` function."""

    # Public alias so `list_models` finds the method under either attribute
    # name (it previously raised AttributeError for this class).
    generation_method = _generation_method

    def __call__(
        self,
        content: Union[Sequence[Msg], str],
        task_type: Optional[str] = None,
        title: Optional[str] = None,
        **kwargs: Any,
    ) -> ModelResponse:
        """Generate embedding for the given content. More detailed information
        please refer to
        https://ai.google.dev/tutorials/python_quickstart#use_embeddings

        Args:
            content (`Union[Sequence[Msg], str]`):
                The content to generate embedding.
            task_type (`Optional[str]`, defaults to `None`):
                The type of the task.
            title (`Optional[str]`, defaults to `None`):
                The title of the content.
            **kwargs:
                The additional arguments for generating embedding.

        Returns:
            `ModelResponse`:
                The response embedding in embedding field, and the raw response
                in raw field.
        """
        # step1: forward to generate response
        response = genai.embed_content(
            model=self.model_name,
            content=content,
            task_type=task_type,
            title=title,
            **kwargs,
        )

        # step2: record the api invocation if needed
        self._save_model_invocation(
            arguments={
                "content": content,
                "task_type": task_type,
                "title": title,
                **kwargs,
            },
            response=response,
        )

        # TODO: Up to 2024/07/26, the embedding model doesn't support to
        #  count tokens, so only the call itself is reported to the monitor.
        # step3: update monitor accordingly
        self.monitor.update_text_and_embedding_tokens(
            model_name=self.model_name,
        )

        return ModelResponse(
            raw=response,
            embedding=response["embedding"],
        )