# -*- coding: utf-8 -*-
"""Model wrapper for post-based inference apis."""
import json
import time
from abc import ABC
from typing import Any, Union, Sequence, List
import requests
from loguru import logger
from .openai_model import OpenAIChatWrapper
from .model import ModelWrapperBase, ModelResponse
from ..constants import _DEFAULT_MAX_RETRIES
from ..constants import _DEFAULT_MESSAGES_KEY
from ..constants import _DEFAULT_RETRY_INTERVAL
from ..message import Msg
[docs]
class PostAPIModelWrapperBase(ModelWrapperBase, ABC):
    """The base model wrapper for the model deployed on the POST API."""

    model_type: str = "post_api"

    def __init__(
        self,
        config_name: str,
        api_url: str,
        headers: dict = None,
        max_length: int = 2048,
        timeout: int = 30,
        json_args: dict = None,
        post_args: dict = None,
        max_retries: int = _DEFAULT_MAX_RETRIES,
        messages_key: str = _DEFAULT_MESSAGES_KEY,
        retry_interval: int = _DEFAULT_RETRY_INTERVAL,
        **kwargs: Any,
    ) -> None:
        """Initialize the model wrapper.

        Args:
            config_name (`str`):
                The id of the model.
            api_url (`str`):
                The url of the post request api.
            headers (`dict`, defaults to `None`):
                The headers of the api. An empty dict is sent when `None`.
            max_length (`int`, defaults to `2048`):
                The maximum length of the model. Stored on the instance;
                not used by the request logic of this class.
            timeout (`int`, defaults to `30`):
                The timeout passed to `requests.post`. It can be
                overridden by a `timeout` entry in `post_args` or in the
                keyword arguments of the call.
            json_args (`dict`, defaults to `None`):
                Extra entries merged into the json body of the request.
                Its `model` (or, failing that, `model_name`) entry is
                forwarded to the parent class as the model name.
            post_args (`dict`, defaults to `None`):
                Extra keyword arguments passed to `requests.post`.
            max_retries (`int`, defaults to `_DEFAULT_MAX_RETRIES`):
                The maximum number of attempts made while the api answers
                with a non-OK status code. Values below 1 are treated as
                1, so the request is always sent at least once.
            messages_key (`str`, defaults to `_DEFAULT_MESSAGES_KEY`):
                The key of the input messages in the json argument.
            retry_interval (`int`, defaults to
            `_DEFAULT_RETRY_INTERVAL`):
                The base interval between attempts; the wait after the
                i-th failed attempt is `i * retry_interval` seconds.
            **kwargs (`Any`):
                Accepted for compatibility; not used by this constructor.

        Note:
            When an object of `PostAPIModelWrapperBase` is called, the
            arguments of the post request are assembled as follows:

            .. code-block:: python

                requests.post(
                    url=api_url,
                    headers=headers,
                    timeout=timeout,
                    json={
                        messages_key: messages,
                        **json_args
                    },
                    **post_args
                )
        """
        if json_args is not None:
            model_name = json_args.get(
                "model",
                json_args.get("model_name", None),
            )
        else:
            model_name = None

        super().__init__(config_name=config_name, model_name=model_name)

        self.api_url = api_url
        self.headers = headers
        self.max_length = max_length
        self.timeout = timeout
        self.json_args = json_args or {}
        self.post_args = post_args or {}
        self.max_retries = max_retries
        self.messages_key = messages_key
        self.retry_interval = retry_interval

    def _parse_response(self, response: dict) -> ModelResponse:
        """Parse the response json data into ModelResponse.

        The base implementation only wraps the decoded json as `raw`;
        subclasses override it to extract text, image urls or embeddings.
        """
        return ModelResponse(raw=response)

    def __call__(self, input_: str, **kwargs: Any) -> ModelResponse:
        """Calling the model with requests.post.

        Args:
            input_ (`str`):
                The input to the model, placed under `messages_key` in
                the json body of the request.
            **kwargs (`Any`):
                Extra keyword arguments merged over `post_args` and
                forwarded to `requests.post`.

        Returns:
            `ModelResponse`: The result of `_parse_response` applied to
            the decoded json body of the response.

        Raises:
            `RuntimeError`: If the response body is not valid json, or if
            the last attempt still returns a non-OK status code.

        Note:
            Only non-OK status codes are retried. Exceptions raised by
            `requests.post` itself (e.g. connection errors or timeouts)
            are not caught here and propagate to the caller.
        """
        # step1: prepare keyword arguments
        post_args = {**self.post_args, **kwargs}

        request_kwargs = {
            "url": self.api_url,
            "json": {self.messages_key: input_, **self.json_args},
            "headers": self.headers or {},
            # Honour the configured timeout; it is listed before
            # `post_args` so an explicit `timeout` there still wins.
            "timeout": self.timeout,
            **post_args,
        }

        # step2: send the post request, retrying on non-OK status codes.
        # At least one attempt is always made so that `response` is bound
        # even when `max_retries` is configured as 0 or a negative value.
        attempts = max(1, self.max_retries)
        for i in range(1, attempts + 1):
            response = requests.post(**request_kwargs)

            if response.status_code == requests.codes.ok:
                break

            if i < attempts:
                logger.warning(
                    f"Failed to call the model with "
                    f"requests.codes == {response.status_code}, retry "
                    f"{i + 1}/{attempts} times",
                )
                # Linear back-off: wait longer after each failed attempt.
                time.sleep(i * self.retry_interval)

        # step3: record model invocation
        # The body is decoded before the status check so that failed
        # calls are recorded as well.
        # NOTE(review): the original comment states the recording is
        # skipped when `FileManager.save_api_invocation` is `False`;
        # that logic lives outside this file.
        try:
            response_json = response.json()
        except requests.exceptions.JSONDecodeError as e:
            raise RuntimeError(
                f"Fail to serialize the response to json: \n{str(response)}",
            ) from e

        self._save_model_invocation(
            arguments=request_kwargs,
            response=response_json,
        )

        # step4: parse the response
        if response.status_code == requests.codes.ok:
            return self._parse_response(response_json)
        else:
            logger.error(json.dumps(request_kwargs, indent=4))
            # Reuse the already decoded body instead of parsing it again.
            raise RuntimeError(
                f"Failed to call the model with {response_json}",
            )
[docs]
class PostAPIChatWrapper(PostAPIModelWrapperBase):
    """A post api model wrapper compatible with openai chat, e.g., vLLM,
    FastChat."""

    model_type: str = "post_api_chat"

    def _parse_response(self, response: dict) -> ModelResponse:
        """Extract the content of the first choice as the response text.

        Args:
            response (`dict`):
                The decoded json body; the text is read from
                `response["data"]["response"]["choices"][0]["message"]
                ["content"]`.

        Returns:
            `ModelResponse`: A response whose `text` is that content.
        """
        first_choice = response["data"]["response"]["choices"][0]
        content = first_choice["message"]["content"]
        return ModelResponse(text=content)

    def format(
        self,
        *args: Union[Msg, Sequence[Msg]],
    ) -> Union[List[dict]]:
        """Format the input messages according to the configured model name.

        The model name is taken from the `model` entry of `json_args`,
        falling back to its `model_name` entry. Names starting with
        "gpt-" are formatted by `OpenAIChatWrapper.static_format`, names
        starting with "gemini" by `GeminiChatWrapper.format`, and every
        other (or missing) name by
        `ModelWrapperBase.format_for_common_chat_models`.

        Args:
            args (`Union[Msg, Sequence[Msg]]`):
                The input arguments to be formatted, where each argument
                should be a `Msg` object, or a list of `Msg` objects.

        Returns:
            `Union[List[dict]]`:
                The formatted messages.
        """
        # Resolve the model name; "model" takes precedence whenever the
        # key is present, even if its value is falsy.
        fallback_name = self.json_args.get("model_name", None)
        model_name = self.json_args.get("model", fallback_name)

        # OpenAI
        if model_name and model_name.startswith("gpt-"):
            return OpenAIChatWrapper.static_format(
                *args,
                model_name=model_name,
            )

        # Gemini
        if model_name and model_name.startswith("gemini"):
            # Imported lazily, only when a Gemini model is configured.
            from .gemini_model import GeminiChatWrapper

            # NOTE(review): `format` is invoked on the class rather than
            # on an instance - confirm it is callable this way.
            return GeminiChatWrapper.format(*args)

        # Include DashScope, ZhipuAI, Ollama, the other models supported
        # by litellm and unknown models
        return ModelWrapperBase.format_for_common_chat_models(*args)
[docs]
class PostAPIDALLEWrapper(PostAPIModelWrapperBase):
    """A post api model wrapper compatible with openai dall_e"""

    model_type: str = "post_api_dall_e"
    deprecated_model_type: str = "post_api_dalle"

    def _parse_response(self, response: dict) -> ModelResponse:
        """Collect the image urls carried by the response.

        Args:
            response (`dict`):
                The decoded json body; the payload is read from
                `response["data"]["response"]`.

        Returns:
            `ModelResponse`: A response whose `image_urls` lists the
            `url` of every item under the payload's `data` entry.

        Raises:
            `ValueError`: If the payload has no `data` entry. The message
            carries the payload's `error.message` when an `error` entry
            exists, otherwise the payload itself.
        """
        payload = response["data"]["response"]

        # Success path: the payload carries the generated images.
        if "data" in payload:
            return ModelResponse(
                image_urls=[img["url"] for img in payload["data"]],
            )

        # Failure path: prefer the api's own error message if present.
        error_msg = (
            payload["error"]["message"] if "error" in payload else payload
        )
        failure = f"Error in API call:\n{error_msg}"
        logger.error(failure)
        raise ValueError(failure)

    def format(
        self,
        *args: Union[Msg, Sequence[Msg]],
    ) -> Union[List[dict], str]:
        """Always raise: this wrapper takes its input without formatting.

        Raises:
            `RuntimeError`: Unconditionally.
        """
        wrapper_name = type(self).__name__
        raise RuntimeError(
            f"Model Wrapper [{wrapper_name}] doesn't "
            f"need to format the input. Please try to use the "
            f"model wrapper directly.",
        )
[docs]
class PostAPIEmbeddingWrapper(PostAPIModelWrapperBase):
    """
    A post api model wrapper for embedding model
    """

    model_type: str = "post_api_embedding"

    def _parse_response(self, response: dict) -> ModelResponse:
        """
        Parse the response json data into ModelResponse with embedding.

        Args:
            response (`dict`):
                The response obtained from the API. The parsing expects
                the same layout as OpenAI's embedding response, e.g.:

                {
                    "object": "list",
                    "data": [
                        {
                            "object": "embedding",
                            "embedding": [
                                0.0023064255,
                                -0.009327292,
                                .... (1536 floats total for ada-002)
                                -0.0028842222,
                            ],
                            "index": 0
                        }
                    ],
                    "model": "text-embedding-ada-002",
                    "usage": {
                        "prompt_tokens": 8,
                        "total_tokens": 8
                    }
                }

        Returns:
            `ModelResponse`: A response whose `embedding` is the list of
            the `embedding` entries of every item in `response["data"]`
            and whose `raw` is the response itself.

        Raises:
            `ValueError`: If `data` is missing or empty, or if its first
            item has no `embedding` entry.
        """
        # Only the first item is inspected; the checks short-circuit in
        # this order so later ones never touch a missing key.
        looks_valid = (
            "data" in response
            and len(response["data"]) >= 1
            and "embedding" in response["data"][0]
        )

        if not looks_valid:
            error_msg = json.dumps(response, ensure_ascii=False, indent=2)
            failure = f"Error in embedding API call:\n{error_msg}"
            logger.error(failure)
            raise ValueError(failure)

        vectors = []
        for item in response["data"]:
            vectors.append(item["embedding"])

        return ModelResponse(
            embedding=vectors,
            raw=response,
        )

    def format(
        self,
        *args: Union[Msg, Sequence[Msg]],
    ) -> Union[List[dict], str]:
        """Always raise: this wrapper takes its input without formatting.

        Raises:
            `RuntimeError`: Unconditionally.
        """
        wrapper_name = type(self).__name__
        raise RuntimeError(
            f"Model Wrapper [{wrapper_name}] doesn't "
            f"need to format the input. Please try to use the "
            f"model wrapper directly.",
        )