# -*- coding: utf-8 -*-
"""Model wrapper for Yi models"""
import json
from typing import (
List,
Union,
Sequence,
Optional,
Generator,
)
import requests
from ._model_utils import (
_verify_text_content_in_openai_message_response,
_verify_text_content_in_openai_delta_response,
)
from .model import ModelWrapperBase, ModelResponse
from ..message import Msg
[docs]
class YiChatWrapper(ModelWrapperBase):
"""The model wrapper for Yi Chat API.
Response:
- From https://platform.lingyiwanwu.com/docs
```json
{
"id": "cmpl-ea89ae83",
"object": "chat.completion",
"created": 5785971,
"model": "yi-large-rag",
"usage": {
"completion_tokens": 113,
"prompt_tokens": 896,
"total_tokens": 1009
},
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Today in Los Angeles, the weather ...",
},
"finish_reason": "stop"
}
]
}
```
"""
model_type: str = "yi_chat"
[docs]
def __init__(
self,
config_name: str,
model_name: str,
api_key: str,
max_tokens: Optional[int] = None,
top_p: float = 0.9,
temperature: float = 0.3,
stream: bool = False,
) -> None:
"""Initialize the Yi chat model wrapper.
Args:
config_name (`str`):
The name of the configuration to use.
model_name (`str`):
The name of the model to use, e.g. yi-large, yi-medium, etc.
api_key (`str`):
The API key for the Yi API.
max_tokens (`Optional[int]`, defaults to `None`):
The maximum number of tokens to generate, defaults to `None`.
top_p (`float`, defaults to `0.9`):
The randomness parameters in the range [0, 1].
temperature (`float`, defaults to `0.3`):
The temperature parameter in the range [0, 2].
stream (`bool`, defaults to `False`):
Whether to stream the response or not.
"""
super().__init__(config_name, model_name)
if top_p > 1 or top_p < 0:
raise ValueError(
f"The `top_p` parameter must be in the range [0, 1], but got "
f"{top_p} instead.",
)
if temperature < 0 or temperature > 2:
raise ValueError(
f"The `temperature` parameter must be in the range [0, 2], "
f"but got {temperature} instead.",
)
self.api_key = api_key
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.stream = stream
def __call__(
self,
messages: list[dict],
stream: Optional[bool] = None,
) -> ModelResponse:
"""Invoke the Yi Chat API by sending a list of messages."""
# Checking messages
if not isinstance(messages, list):
raise ValueError(
f"Yi `messages` field expected type `list`, "
f"got `{type(messages)}` instead.",
)
if not all("role" in msg and "content" in msg for msg in messages):
raise ValueError(
"Each message in the 'messages' list must contain a 'role' "
"and 'content' key for Yi API.",
)
if stream is None:
stream = self.stream
# Forward to generate response
kwargs = {
"url": "https://api.lingyiwanwu.com/v1/chat/completions",
"json": {
"model": self.model_name,
"messages": messages,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"stream": stream,
},
"headers": {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
}
response = requests.post(**kwargs)
response.raise_for_status()
if stream:
def generator() -> Generator[str, None, None]:
text = ""
last_chunk = {}
for line in response.iter_lines():
if line:
line_str = line.decode("utf-8").strip()
# Remove prefix "data: " if exists
json_str = line_str.removeprefix("data: ")
# The last response is "data: [DONE]"
if json_str == "[DONE]":
continue
try:
chunk = json.loads(json_str)
if _verify_text_content_in_openai_delta_response(
chunk,
):
text += chunk["choices"][0]["delta"]["content"]
yield text
last_chunk = chunk
except json.decoder.JSONDecodeError as e:
raise json.decoder.JSONDecodeError(
f"Invalid JSON: {json_str}",
e.doc,
e.pos,
) from e
# In Yi Chat API, the last valid chunk will save all the text
# in this message
self._save_model_invocation_and_update_monitor(
kwargs,
last_chunk,
)
return ModelResponse(
stream=generator(),
)
else:
response = response.json()
self._save_model_invocation_and_update_monitor(
kwargs,
response,
)
# Re-use the openai response checking function
if _verify_text_content_in_openai_message_response(response):
return ModelResponse(
text=response["choices"][0]["message"]["content"],
raw=response,
)
else:
raise RuntimeError(
f"Invalid response from Yi Chat API: {response}",
)
def _save_model_invocation_and_update_monitor(
self,
kwargs: dict,
response: dict,
) -> None:
"""Save model invocation and update the monitor accordingly.
Args:
kwargs (`dict`):
The keyword arguments used in model invocation
response (`dict`):
The response from model API
"""
self._save_model_invocation(
arguments=kwargs,
response=response,
)
usage = response.get("usage", None)
if usage is not None:
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
self.monitor.update_text_and_embedding_tokens(
model_name=self.model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)