# -*- coding: utf-8 -*-
"""The Gemini text embedding model class."""
from datetime import datetime
from typing import Any, List

from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ._embedding_base import EmbeddingModelBase
from ..message import TextBlock
class GeminiTextEmbedding(EmbeddingModelBase):
    """The Gemini text embedding model."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""
    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 3072,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Gemini text embedding model class.

        Args:
            api_key (`str`):
                The Gemini API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 3072):
                The dimension of the embedding vector; refer to the
                `official documentation
                <https://ai.google.dev/gemini-api/docs/embeddings?hl=zh-cn#control-embedding-size>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
        """
        from google import genai

        super().__init__(model_name, dimensions)
        self.client = genai.Client(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache
    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """The Gemini embedding API call.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of
                strings or TextBlock dicts.
        """
        # TODO: handle the batch size limit

        # Collect the plain strings to be embedded.
        gather_text = []
        for _ in text:
            if isinstance(_, dict) and "text" in _:
                gather_text.append(_["text"])
            elif isinstance(_, str):
                gather_text.append(_)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock "
                    "dicts.",
                )

        kwargs = {
            "model": self.model_name,
            "contents": gather_text,
            "config": kwargs,
        }

        # Serve from the cache first (if configured) to avoid a repeated
        # API call for the same request.
        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        start_time = datetime.now()
        response = self.client.models.embed_content(**kwargs)
        time = (datetime.now() - start_time).total_seconds()

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=[_.values for _ in response.embeddings],
            )

        return EmbeddingResponse(
            embeddings=[_.values for _ in response.embeddings],
            usage=EmbeddingUsage(
                time=time,
            ),
        )
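

# A minimal usage sketch, not part of the class above: it assumes a valid key
# in the GEMINI_API_KEY environment variable, and "gemini-embedding-001" is an
# illustrative model name. Since ``__call__`` is a coroutine, it must be
# awaited, e.g. via ``asyncio.run``.
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo() -> None:
        model = GeminiTextEmbedding(
            api_key=os.environ["GEMINI_API_KEY"],  # assumed env variable
            model_name="gemini-embedding-001",  # assumed model name
        )
        res = await model(["Hello, world!"])
        # Each item in ``res.embeddings`` is one embedding vector.
        print(len(res.embeddings), len(res.embeddings[0]))
        # Extra keyword arguments are forwarded as the request config,
        # e.g. (assuming the SDK accepts a plain dict as config):
        # res = await model(["Hello"], output_dimensionality=256)

    asyncio.run(_demo())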