Source code for agentscope.embedding._ollama_embedding
# -*- coding: utf-8 -*-
"""The ollama text embedding model class."""
import asyncio
from datetime import datetime
from typing import List, Any

from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ..embedding import EmbeddingModelBase
def __init__(
    self,
    model_name: str,
    host: str | None = None,
    embedding_cache: EmbeddingCacheBase | None = None,
    **kwargs: Any,
) -> None:
    """Initialize the Ollama text embedding model class.

    Args:
        model_name (`str`):
            The name of the embedding model.
        host (`str | None`, defaults to `None`):
            The host URL for the Ollama API.
        embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
            The embedding cache class instance, used to cache the
            embedding results to avoid repeated API calls.
    """
    # Lazy import so that ollama is only required when this class is used.
    import ollama

    super().__init__(model_name)

    # Extra keyword arguments are forwarded to the async client constructor.
    self.client = ollama.AsyncClient(host=host, **kwargs)
    self.embedding_cache = embedding_cache
async def __call__(
    self,
    text: List[str],
    **kwargs: Any,
) -> EmbeddingResponse:
    """Call the Ollama embedding API.

    Args:
        text (`List[str]`):
            The input text to be embedded. It can be a list of strings.

    Returns:
        `EmbeddingResponse`:
            The embedding response object. Its ``source`` field is
            ``"cache"`` when the result was served from the embedding
            cache, in which case usage tokens/time are reported as 0.
    """
    # BUG FIX: the cache identifier previously consisted of ``kwargs``
    # alone, ignoring the input text and the model name. Any two calls
    # with identical kwargs would therefore collide on the same cache
    # entry and return the wrong embeddings. Include both in the key.
    identifier = {
        "model_name": self.model_name,
        "text": text,
        **kwargs,
    }

    if self.embedding_cache:
        cached_embeddings = await self.embedding_cache.retrieve(
            identifier=identifier,
        )
        if cached_embeddings:
            return EmbeddingResponse(
                embeddings=cached_embeddings,
                usage=EmbeddingUsage(
                    tokens=0,
                    time=0,
                ),
                source="cache",
            )

    start_time = datetime.now()
    # Embed each input string concurrently; gather preserves input order.
    response = await asyncio.gather(
        *[self.client.embeddings(self.model_name, _, **kwargs) for _ in text],
    )
    time = (datetime.now() - start_time).total_seconds()

    # Compute the embeddings list once and reuse it for both the cache
    # store and the returned response (previously built twice).
    embeddings = [_.embedding for _ in response]

    if self.embedding_cache:
        await self.embedding_cache.store(
            identifier=identifier,
            embeddings=embeddings,
        )

    return EmbeddingResponse(
        embeddings=embeddings,
        usage=EmbeddingUsage(
            time=time,
        ),
    )