Source code for agentscope.models.gemini_model

# -*- coding: utf-8 -*-
"""Google Gemini model wrapper."""
import os
from abc import ABC
from collections.abc import Iterable
from typing import Sequence, Union, Any, List, Optional, Generator

from loguru import logger

from ..message import Msg
from ..models import ModelWrapperBase, ModelResponse
from ..utils.common import _convert_to_str

try:
    import google.generativeai as genai

    # This package will be installed when the google-generativeai is installed
    import google.ai.generativelanguage as glm
except ImportError:
    genai = None
    glm = None


[docs] class GeminiWrapperBase(ModelWrapperBase, ABC): """The base class for Google Gemini model wrapper.""" _generation_method = None """The generation method used in `__call__` function, which is used to filter models in `list_models` function.""" def __init__( self, config_name: str, model_name: str, api_key: str = None, **kwargs: Any, ) -> None: """Initialize the wrapper for Google Gemini model. Args: model_name (`str`): The name of the model. api_key (`str`, defaults to `None`): The api_key for the model. If it is not provided, it will be loaded from environment variable. """ super().__init__(config_name=config_name, model_name=model_name) # Test if the required package is installed if genai is None: raise ImportError( "The google-generativeai package is not installed, " "please install it first.", ) # Load the api_key from argument or environment variable api_key = api_key or os.environ.get("GOOGLE_API_KEY") if api_key is None: raise ValueError( "Google api_key must be provided or set as an " "environment variable.", ) genai.configure(api_key=api_key, **kwargs) self.model_name = model_name
[docs] def list_models(self) -> Sequence: """List all available models for this API calling.""" support_models = list(genai.list_models()) if self.generation_method is None: return support_models else: return [ _ for _ in support_models if self._generation_method in _.supported_generation_methods ]
def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse: """Processing input with the model.""" raise NotImplementedError( f"Model Wrapper [{type(self).__name__}]" f" is missing the the required `__call__`" f" method.", )
[docs] class GeminiChatWrapper(GeminiWrapperBase): """The wrapper for Google Gemini chat model, e.g. gemini-pro""" model_type: str = "gemini_chat" """The type of the model, which is used in model configuration.""" generation_method = "generateContent" """The generation method used in `__call__` function.""" def __init__( self, config_name: str, model_name: str, api_key: str = None, stream: bool = False, **kwargs: Any, ) -> None: """Initialize the wrapper for Google Gemini model. Args: model_name (`str`): The name of the model. api_key (`str`, defaults to `None`): The api_key for the model. If it is not provided, it will be loaded from environment variable. stream (`bool`, defaults to `False`): Whether to use stream mode. """ super().__init__( config_name=config_name, model_name=model_name, api_key=api_key, **kwargs, ) self.stream = stream # Create the generative model self.model = genai.GenerativeModel(model_name, **kwargs) def __call__( self, contents: Union[Sequence, str], stream: Optional[bool] = None, **kwargs: Any, ) -> ModelResponse: """Generate response for the given contents. Args: contents (`Union[Sequence, str]`): The content to generate response. stream (`Optional[bool]`, defaults to `None`) Whether to use stream mode. **kwargs: The additional arguments for generating response. Returns: `ModelResponse`: The response text in text field, and the raw response in raw field. """ # step1: checking messages if isinstance(contents, Iterable): pass elif not isinstance(contents, str): logger.warning( "The input content is not a string or a list of " "messages, which may cause unexpected behavior.", ) # Check if stream is provided if stream is None: stream = self.stream # step2: forward to generate response kwargs.update( { "contents": contents, "stream": stream, }, ) response = self.model.generate_content(**kwargs) if stream: def generator() -> Generator[str, None, None]: text = "" last_chunk = None for chunk in response: text += self._extract_text_content_from_response( contents, chunk, ) yield text last_chunk = chunk # Update the last chunk last_chunk.candidates[0].content.parts[0].text = text self._save_model_invocation_and_update_monitor( contents, kwargs, last_chunk, ) return ModelResponse( stream=generator(), ) else: self._save_model_invocation_and_update_monitor( contents, kwargs, response, ) # step6: return response return ModelResponse( text=response.text, raw=response, ) def _save_model_invocation_and_update_monitor( self, contents: Union[Sequence[Any], str], kwargs: dict, response: Any, ) -> None: """Save the model invocation and update the monitor accordingly.""" # Record the api invocation if needed self._save_model_invocation( arguments=kwargs, response=str(response), ) # Update monitor accordingly if hasattr(response, "usage_metadata"): token_prompt = response.usage_metadata.prompt_token_count token_response = response.usage_metadata.candidates_token_count else: token_prompt = self.model.count_tokens(contents).total_tokens token_response = self.model.count_tokens( response.text, ).total_tokens self.monitor.update_text_and_embedding_tokens( model_name=self.model_name, prompt_tokens=token_prompt, completion_tokens=token_response, ) def _extract_text_content_from_response( self, contents: Union[Sequence[Any], str], response: Any, ) -> str: """Extract the text content from the response of gemini API Note: to avoid import error during type checking in python 3.11+, here we use `typing.Any` to avoid raising error Args: contents (`Union[Sequence[Any], str]`): The prompt contents response (`Any`): The response from gemini API Returns: `str`: The extracted string. """ # Check for candidates and handle accordingly if ( not response.candidates[0].content or not response.candidates[0].content.parts or not response.candidates[0].content.parts[0].text ): # If we cannot get the response text from the model finish_reason = response.candidates[0].finish_reason reasons = glm.Candidate.FinishReason if finish_reason == reasons.STOP: error_info = ( "Natural stop point of the model or provided stop " "sequence." ) elif finish_reason == reasons.MAX_TOKENS: error_info = ( "The maximum number of tokens as specified in the request " "was reached." ) elif finish_reason == reasons.SAFETY: error_info = ( "The candidate content was flagged for safety reasons." ) elif finish_reason == reasons.RECITATION: error_info = ( "The candidate content was flagged for recitation reasons." ) elif finish_reason in [ reasons.FINISH_REASON_UNSPECIFIED, reasons.OTHER, ]: error_info = "Unknown error." else: error_info = "No information provided from Gemini API." raise ValueError( "The Google Gemini API failed to generate text response with " f"the following finish reason: {error_info}\n" f"YOUR INPUT: {contents}\n" f"RAW RESPONSE FROM GEMINI API: {response}\n", ) return response.text
[docs] @staticmethod def format( *args: Union[Msg, Sequence[Msg]], ) -> List[dict]: """This function provide a basic prompting strategy for Gemini Chat API in multi-party conversation, which combines all input into a single string, and wrap it into a user message. We make the above decision based on the following constraints of the Gemini generate API: 1. In Gemini `generate_content` API, the `role` field must be either `user` or `model`. 2. If we pass a list of messages to the `generate_content` API, the `user` role must speak in the beginning and end of the messages, and `user` and `model` must alternative. This prevents us to build a multi-party conversations, where `model` may keep speaking in different names. The above information is updated to 2024/03/21. More information about the Gemini `generate_content` API can be found in https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini Based on the above considerations, we decide to combine all messages into a single user message. This is a simple and straightforward strategy, if you have any better ideas, pull request and discussion are welcome in our GitHub repository https://github.com/agentscope/agentscope! Args: args (`Union[Msg, Sequence[Msg]]`): The input arguments to be formatted, where each argument should be a `Msg` object, or a list of `Msg` objects. In distribution, placeholder is also allowed. Returns: `List[dict]`: A list with one user message. """ if len(args) == 0: raise ValueError( "At least one message should be provided. An empty message " "list is not allowed.", ) input_msgs = [] for _ in args: if _ is None: continue if isinstance(_, Msg): input_msgs.append(_) elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _): input_msgs.extend(_) else: raise TypeError( f"The input should be a Msg object or a list " f"of Msg objects, got {type(_)}.", ) # record dialog history as a list of strings sys_prompt = None dialogue = [] for i, unit in enumerate(input_msgs): if i == 0 and unit.role == "system": # system prompt sys_prompt = _convert_to_str(unit.content) else: # Merge all messages into a conversation history prompt dialogue.append( f"{unit.name}: {_convert_to_str(unit.content)}", ) prompt_components = [] if sys_prompt is not None: if not sys_prompt.endswith("\n"): sys_prompt += "\n" prompt_components.append(sys_prompt) if len(dialogue) > 0: prompt_components.extend(["## Conversation History"] + dialogue) user_prompt = "\n".join(prompt_components) messages = [ { "role": "user", "parts": [ user_prompt, ], }, ] return messages
[docs] class GeminiEmbeddingWrapper(GeminiWrapperBase): """The wrapper for Google Gemini embedding model, e.g. models/embedding-001 Response: - Refer to https://ai.google.dev/api/embeddings?hl=zh-cn#response-body .. code-block:: json { "embeddings": [ { object (ContentEmbedding) } ] } """ model_type: str = "gemini_embedding" """The type of the model, which is used in model configuration.""" _generation_method = "embedContent" """The generation method used in `__call__` function.""" def __call__( self, content: Union[Sequence[Msg], str], task_type: str = None, title: str = None, **kwargs: Any, ) -> ModelResponse: """Generate embedding for the given content. More detailed information please refer to https://ai.google.dev/tutorials/python_quickstart#use_embeddings Args: content (`Union[Sequence[Msg], str]`): The content to generate embedding. task_type (`str`, defaults to `None`): The type of the task. title (`str`, defaults to `None`): The title of the content. **kwargs: The additional arguments for generating embedding. Returns: `ModelResponse`: The response embedding in embedding field, and the raw response in raw field. """ # step1: forward to generate response response = genai.embed_content( model=self.model_name, content=content, task_type=task_type, title=title, **kwargs, ) # step2: record the api invocation if needed self._save_model_invocation( arguments={ "content": content, "task_type": task_type, "title": title, **kwargs, }, response=response, ) # TODO: Up to 2024/07/26, the embedding model doesn't support to # count tokens. # step3: update monitor accordingly self.monitor.update_text_and_embedding_tokens( model_name=self.model_name, ) return ModelResponse( raw=response, embedding=response["embedding"], )