agentscope.models.yi_model 源代码

# -*- coding: utf-8 -*-
"""Model wrapper for Yi models"""
import json
from typing import (
    List,
    Union,
    Sequence,
    Optional,
    Generator,
)

import requests

from ._model_utils import (
    _verify_text_content_in_openai_message_response,
    _verify_text_content_in_openai_delta_response,
)
from .model import ModelWrapperBase, ModelResponse
from ..message import Msg


[文档] class YiChatWrapper(ModelWrapperBase): """The model wrapper for Yi Chat API. Response: - From https://platform.lingyiwanwu.com/docs ```json { "id": "cmpl-ea89ae83", "object": "chat.completion", "created": 5785971, "model": "yi-large-rag", "usage": { "completion_tokens": 113, "prompt_tokens": 896, "total_tokens": 1009 }, "choices": [ { "index": 0, "message": { "role": "assistant", "content": "Today in Los Angeles, the weather ...", }, "finish_reason": "stop" } ] } ``` """ model_type: str = "yi_chat"
[文档] def __init__( self, config_name: str, model_name: str, api_key: str, max_tokens: Optional[int] = None, top_p: float = 0.9, temperature: float = 0.3, stream: bool = False, ) -> None: """Initialize the Yi chat model wrapper. Args: config_name (`str`): The name of the configuration to use. model_name (`str`): The name of the model to use, e.g. yi-large, yi-medium, etc. api_key (`str`): The API key for the Yi API. max_tokens (`Optional[int]`, defaults to `None`): The maximum number of tokens to generate, defaults to `None`. top_p (`float`, defaults to `0.9`): The randomness parameters in the range [0, 1]. temperature (`float`, defaults to `0.3`): The temperature parameter in the range [0, 2]. stream (`bool`, defaults to `False`): Whether to stream the response or not. """ super().__init__(config_name, model_name) if top_p > 1 or top_p < 0: raise ValueError( f"The `top_p` parameter must be in the range [0, 1], but got " f"{top_p} instead.", ) if temperature < 0 or temperature > 2: raise ValueError( f"The `temperature` parameter must be in the range [0, 2], " f"but got {temperature} instead.", ) self.api_key = api_key self.max_tokens = max_tokens self.top_p = top_p self.temperature = temperature self.stream = stream
def __call__( self, messages: list[dict], stream: Optional[bool] = None, ) -> ModelResponse: """Invoke the Yi Chat API by sending a list of messages.""" # Checking messages if not isinstance(messages, list): raise ValueError( f"Yi `messages` field expected type `list`, " f"got `{type(messages)}` instead.", ) if not all("role" in msg and "content" in msg for msg in messages): raise ValueError( "Each message in the 'messages' list must contain a 'role' " "and 'content' key for Yi API.", ) if stream is None: stream = self.stream # Forward to generate response kwargs = { "url": "https://api.lingyiwanwu.com/v1/chat/completions", "json": { "model": self.model_name, "messages": messages, "temperature": self.temperature, "max_tokens": self.max_tokens, "top_p": self.top_p, "stream": stream, }, "headers": { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", }, } response = requests.post(**kwargs) response.raise_for_status() if stream: def generator() -> Generator[str, None, None]: text = "" last_chunk = {} for line in response.iter_lines(): if line: line_str = line.decode("utf-8").strip() # Remove prefix "data: " if exists json_str = line_str.removeprefix("data: ") # The last response is "data: [DONE]" if json_str == "[DONE]": continue try: chunk = json.loads(json_str) if _verify_text_content_in_openai_delta_response( chunk, ): text += chunk["choices"][0]["delta"]["content"] yield text last_chunk = chunk except json.decoder.JSONDecodeError as e: raise json.decoder.JSONDecodeError( f"Invalid JSON: {json_str}", e.doc, e.pos, ) from e # In Yi Chat API, the last valid chunk will save all the text # in this message self._save_model_invocation_and_update_monitor( kwargs, last_chunk, ) return ModelResponse( stream=generator(), ) else: response = response.json() self._save_model_invocation_and_update_monitor( kwargs, response, ) # Re-use the openai response checking function if _verify_text_content_in_openai_message_response(response): return ModelResponse( text=response["choices"][0]["message"]["content"], raw=response, ) else: raise RuntimeError( f"Invalid response from Yi Chat API: {response}", )
[文档] def format( self, *args: Union[Msg, Sequence[Msg]], ) -> List[dict]: """Format the messages into the required format of Yi Chat API. Note this strategy maybe not suitable for all scenarios, and developers are encouraged to implement their own prompt engineering strategies. The following is an example: .. code-block:: python prompt1 = model.format( Msg("system", "You're a helpful assistant", role="system"), Msg("Bob", "Hi, how can I help you?", role="assistant"), Msg("user", "What's the date today?", role="user") ) The prompt will be as follows: .. code-block:: python # prompt1 [ { "role": "user", "content": ( "You're a helpful assistant\\n" "\\n" "## Conversation History\\n" "Bob: Hi, how can I help you?\\n" "user: What's the date today?" ) } ] Args: args (`Union[Msg, Sequence[Msg]]`): The input arguments to be formatted, where each argument should be a `Msg` object, or a list of `Msg` objects. In distribution, placeholder is also allowed. Returns: `List[dict]`: The formatted messages. """ # TODO: Support Vision model if self.model_name == "yi-vision": raise NotImplementedError( "Yi Vision model is not supported in the current version, " "please format the messages manually.", ) return ModelWrapperBase.format_for_common_chat_models(*args)
def _save_model_invocation_and_update_monitor( self, kwargs: dict, response: dict, ) -> None: """Save model invocation and update the monitor accordingly. Args: kwargs (`dict`): The keyword arguments used in model invocation response (`dict`): The response from model API """ self._save_model_invocation( arguments=kwargs, response=response, ) usage = response.get("usage", None) if usage is not None: prompt_tokens = usage.get("prompt_tokens", 0) completion_tokens = usage.get("completion_tokens", 0) self.monitor.update_text_and_embedding_tokens( model_name=self.model_name, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, )