Source code for agentscope.models.post_model

# -*- coding: utf-8 -*-
"""Model wrapper for post-based inference apis."""
import json
import time
from abc import ABC
from typing import Any, Union, Sequence, List, Optional

import requests
from loguru import logger

from .openai_model import OpenAIChatWrapper
from .model import ModelWrapperBase, ModelResponse
from ..constants import _DEFAULT_MAX_RETRIES
from ..constants import _DEFAULT_MESSAGES_KEY
from ..constants import _DEFAULT_RETRY_INTERVAL
from ..message import Msg


class PostAPIModelWrapperBase(ModelWrapperBase, ABC):
    """The base model wrapper for the model deployed on the POST API."""

    model_type: str

    def __init__(
        self,
        config_name: str,
        api_url: str,
        model_name: Optional[str] = None,
        headers: Optional[dict] = None,
        max_length: int = 2048,
        timeout: int = 30,
        json_args: Optional[dict] = None,
        post_args: Optional[dict] = None,
        max_retries: int = _DEFAULT_MAX_RETRIES,
        messages_key: str = _DEFAULT_MESSAGES_KEY,
        retry_interval: int = _DEFAULT_RETRY_INTERVAL,
        **kwargs: Any,
    ) -> None:
        """Initialize the model wrapper.

        Args:
            config_name (`str`):
                The id of the model.
            api_url (`str`):
                The url of the post request api.
            model_name (`str`):
                The name of the model. If `None`, the model name will be
                extracted from the `"model"` (or, failing that,
                `"model_name"`) entry of `json_args`.
            headers (`dict`, defaults to `None`):
                The headers of the api.
            max_length (`int`, defaults to `2048`):
                The maximum length of the model.
            timeout (`int`, defaults to `30`):
                The timeout (in seconds) handed to `requests.post`. It can
                be overridden per wrapper through `post_args["timeout"]`
                or per call through the `timeout` keyword argument.
            json_args (`dict`, defaults to `None`):
                The json arguments of the api.
            post_args (`dict`, defaults to `None`):
                The post arguments of the api.
            max_retries (`int`, defaults to `_DEFAULT_MAX_RETRIES`):
                The maximum number of attempts when the request does not
                return an OK status. At least one attempt is always made.
            messages_key (`str`, defaults to `_DEFAULT_MESSAGES_KEY`):
                The key of the input messages in the json argument.
            retry_interval (`int`, defaults to `_DEFAULT_RETRY_INTERVAL`):
                The base interval between retries when a request fails;
                the wait before retry ``i + 1`` is ``i * retry_interval``.

        Note:
            When an object of this wrapper is called, the arguments of the
            post request are assembled as follows:

            .. code-block:: python

                requests.post(
                    url=api_url,
                    headers=headers,
                    timeout=timeout,
                    json={
                        messages_key: messages,
                        **json_args
                    },
                    **post_args
                )
        """
        # Fall back to the model name carried in the json arguments.
        if model_name is None and json_args is not None:
            model_name = json_args.get(
                "model",
                json_args.get("model_name", None),
            )

        super().__init__(config_name=config_name, model_name=model_name)

        self.api_url = api_url
        self.headers = headers
        self.max_length = max_length
        self.timeout = timeout
        self.json_args = json_args or {}
        self.post_args = post_args or {}
        self.max_retries = max_retries
        self.messages_key = messages_key
        self.retry_interval = retry_interval

    def _parse_response(self, response: dict) -> ModelResponse:
        """Parse the response json data into ModelResponse"""
        return ModelResponse(raw=response)

    def __call__(self, input_: str, **kwargs: Any) -> ModelResponse:
        """Calling the model with requests.post.

        Args:
            input_ (`str`):
                The input string to the model.
            **kwargs (`Any`):
                Extra keyword arguments merged over `post_args` and passed
                to `requests.post`.

        Returns:
            `ModelResponse`: The parsed response of the model, as produced
            by `_parse_response`.

        Raises:
            `RuntimeError`: If the response body cannot be decoded as
            JSON, or if the final attempt does not return an OK status.

        Note:
            `parse_func`, `fault_handler` and `max_retries` are reserved
            for `_response_parse_decorator` to parse and check the response
            generated by model wrapper. Their usages are listed as follows:

            - `parse_func` is a callable function used to parse and check
              the response generated by the model, which takes the
              response as input.
            - `max_retries` is the maximum number of retries when the
              `parse_func` raise an exception.
            - `fault_handler` is a callable function which is called when
              the response generated by the model is invalid after
              `max_retries` retries.
        """
        # step1: prepare keyword arguments
        post_args = {**self.post_args, **kwargs}

        request_kwargs = {
            "url": self.api_url,
            "json": {self.messages_key: input_, **self.json_args},
            "headers": self.headers or {},
            # Without an explicit timeout `requests` waits indefinitely.
            # Listed before `post_args` so callers can still override it.
            "timeout": self.timeout,
            **post_args,
        }

        # step2: send the post request, retrying on non-OK status codes.
        # Always try at least once so `response` is bound even when
        # `max_retries` is configured as zero or a negative number.
        attempts = max(1, self.max_retries)
        for i in range(1, attempts + 1):
            response = requests.post(**request_kwargs)

            if response.status_code == requests.codes.ok:
                break

            if i < attempts:
                logger.warning(
                    f"Failed to call the model with "
                    f"requests.codes == {response.status_code}, retry "
                    f"{i + 1}/{attempts} times",
                )
                # Linearly growing back-off between attempts.
                time.sleep(i * self.retry_interval)

        # step3: record model invocation
        # record the model api invocation, which will be skipped if
        # `FileManager.save_api_invocation` is `False`
        try:
            response_json = response.json()
        except requests.exceptions.JSONDecodeError as e:
            raise RuntimeError(
                f"Fail to serialize the response to json: \n{str(response)}",
            ) from e

        self._save_model_invocation(
            arguments=request_kwargs,
            response=response_json,
        )

        # step4: parse the response
        if response.status_code == requests.codes.ok:
            return self._parse_response(response_json)

        logger.error(json.dumps(request_kwargs, indent=4))
        # Reuse the already decoded body instead of parsing it again.
        raise RuntimeError(
            f"Failed to call the model with {response_json}",
        )
class PostAPIChatWrapper(PostAPIModelWrapperBase):
    """A post api model wrapper compatible with openai chat, e.g., vLLM,
    FastChat."""

    model_type: str = "post_api_chat"

    def _parse_response(self, response: dict) -> ModelResponse:
        """Build a `ModelResponse` from the first choice of the response.

        Args:
            response (`dict`):
                The decoded JSON body. The text is read from
                `response["data"]["response"]["choices"][0]["message"]
                ["content"]`.

        Returns:
            `ModelResponse`: A response whose `text` is that content.
        """
        first_choice = response["data"]["response"]["choices"][0]
        content = first_choice["message"]["content"]
        return ModelResponse(text=content)

    def format(
        self,
        *args: Union[Msg, Sequence[Msg]],
    ) -> Union[List[dict]]:
        """Format the input messages into a list of dict according to the
        model name found in `json_args`.

        A model name starting with "gpt-" is formatted for OpenAI models,
        one starting with "gemini" is formatted by the Gemini wrapper, and
        everything else (including a missing model name) goes through the
        common chat-model formatter.

        Args:
            args (`Union[Msg, Sequence[Msg]]`):
                The input arguments to be formatted, where each argument
                should be a `Msg` object, or a list of `Msg` objects.
                In distribution, placeholder is also allowed.

        Returns:
            `Union[List[dict]]`:
                The formatted messages.
        """
        # Prefer the "model" field; fall back to "model_name" if absent.
        fallback_name = self.json_args.get("model_name", None)
        model_name = self.json_args.get("model", fallback_name)

        # Unknown model: use the common chat format.
        if not model_name:
            return ModelWrapperBase.format_for_common_chat_models(*args)

        # OpenAI
        if model_name.startswith("gpt-"):
            return OpenAIChatWrapper.static_format(
                *args,
                model_name=model_name,
            )

        # Gemini
        if model_name.startswith("gemini"):
            from .gemini_model import GeminiChatWrapper

            return GeminiChatWrapper.format(*args)

        # Include DashScope, ZhipuAI, Ollama, the other models supported by
        # litellm and unknown models
        return ModelWrapperBase.format_for_common_chat_models(*args)
class PostAPIDALLEWrapper(PostAPIModelWrapperBase):
    """A post api model wrapper compatible with openai dall_e"""

    model_type: str = "post_api_dall_e"

    def _parse_response(self, response: dict) -> ModelResponse:
        """Collect the image urls from the response.

        Args:
            response (`dict`):
                The decoded JSON body; the payload is read from
                `response["data"]["response"]`.

        Returns:
            `ModelResponse`: A response whose `image_urls` holds the `url`
            of every entry in the payload's `data` list.

        Raises:
            `ValueError`: If the payload has no `data` entry. The message
            carries the payload's error message when one is present,
            otherwise the payload itself.
        """
        payload = response["data"]["response"]

        # Success path: every item in "data" carries an image url.
        if "data" in payload:
            image_urls = [item["url"] for item in payload["data"]]
            return ModelResponse(image_urls=image_urls)

        error_msg = payload["error"]["message"] if "error" in payload else payload
        logger.error(f"Error in API call:\n{error_msg}")
        raise ValueError(f"Error in API call:\n{error_msg}")

    def format(
        self,
        *args: Union[Msg, Sequence[Msg]],
    ) -> Union[List[dict], str]:
        """Always raise: this wrapper does not format its input.

        Raises:
            `RuntimeError`: On every call.
        """
        wrapper_name = type(self).__name__
        raise RuntimeError(
            f"Model Wrapper [{wrapper_name}] doesn't "
            f"need to format the input. Please try to use the "
            f"model wrapper directly.",
        )
class PostAPIEmbeddingWrapper(PostAPIModelWrapperBase):
    """
    A post api model wrapper for embedding model
    """

    model_type: str = "post_api_embedding"

    def _parse_response(self, response: dict) -> ModelResponse:
        """
        Parse the response json data into ModelResponse with embedding.

        Args:
            response (`dict`):
                The response obtained from the API. This parsing assume the
                structure of the response is the same as OpenAI's as
                following:

                .. code-block:: python

                    {
                        "object": "list",
                        "data": [
                            {
                                "object": "embedding",
                                "embedding": [
                                    0.0023064255,
                                    -0.009327292,
                                    # .... (1536 floats total for ada-002)
                                    -0.0028842222,
                                ],
                                "index": 0
                            }
                        ],
                        "model": "text-embedding-ada-002",
                        "usage": {
                            "prompt_tokens": 8,
                            "total_tokens": 8
                        }
                    }

        Returns:
            `ModelResponse`: A response whose `embedding` is the list of
            `embedding` values of every entry in `response["data"]`, and
            whose `raw` is the response itself.

        Raises:
            `ValueError`: If `response` has no `data`, `data` is empty, or
            its first entry has no `embedding`.
        """
        # Only the first entry is checked for an "embedding" field.
        looks_valid = (
            "data" in response
            and len(response["data"]) >= 1
            and "embedding" in response["data"][0]
        )
        if not looks_valid:
            error_msg = json.dumps(response, ensure_ascii=False, indent=2)
            logger.error(f"Error in embedding API call:\n{error_msg}")
            raise ValueError(f"Error in embedding API call:\n{error_msg}")

        embeddings = []
        for item in response["data"]:
            embeddings.append(item["embedding"])

        return ModelResponse(
            embedding=embeddings,
            raw=response,
        )

    def format(
        self,
        *args: Union[Msg, Sequence[Msg]],
    ) -> Union[List[dict], str]:
        """Always raise: this wrapper does not format its input.

        Raises:
            `RuntimeError`: On every call.
        """
        wrapper_name = type(self).__name__
        raise RuntimeError(
            f"Model Wrapper [{wrapper_name}] doesn't "
            f"need to format the input. Please try to use the "
            f"model wrapper directly.",
        )