# -*- coding: utf-8 -*-"""A model class for RL Training with Trinity-RFT."""fromtypingimport(Optional,TYPE_CHECKING,)from._openai_modelimportOpenAIChatModelfrom..typesimportJSONSerializableObjectifTYPE_CHECKING:fromopenaiimportAsyncOpenAIelse:AsyncOpenAI="openai.AsyncOpenAI"
class TrinityChatModel(OpenAIChatModel):
    """A model class for RL Training with Trinity-RFT.

    Wraps the ``AsyncOpenAI``-compatible client that Trinity-RFT provides
    during RL training, reusing the :class:`OpenAIChatModel` request logic
    while forcing non-streaming generation.
    """

    def __init__(
        self,
        openai_async_client: AsyncOpenAI,
        generate_kwargs: dict[str, JSONSerializableObject] | None = None,
        enable_thinking: bool | None = None,
    ) -> None:
        """Initialize the Trinity model class.

        Args:
            openai_async_client (`AsyncOpenAI`):
                The OpenAI async client instance provided by Trinity-RFT.
                Must expose a ``model_path`` attribute naming the model.
            generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
             optional):
                Additional keyword arguments to pass to the model's generate
                method. Defaults to None.
            enable_thinking (`bool | None`, optional):
                Whether to enable the model's thinking capability. Only
                applicable for Qwen3 series models. Defaults to None.

        Raises:
            ValueError: If ``openai_async_client`` lacks a ``model_path``
                attribute (i.e. it was not created by Trinity-RFT).
            TypeError: If ``generate_kwargs["chat_template_kwargs"]`` is
                present but is not a dictionary.
        """
        # Trinity-RFT attaches the served model's path to its client; we use
        # it as the model name for the OpenAI-compatible endpoint.
        model_name = getattr(openai_async_client, "model_path", None)
        if model_name is None:
            raise ValueError(
                "The provided openai_async_client does not have a "
                "`model_path` attribute. Please ensure you are using "
                "the instance provided by Trinity-RFT.",
            )

        super().__init__(
            model_name=model_name,
            api_key="EMPTY",
            generate_kwargs=generate_kwargs,
            stream=False,  # RL training does not support streaming
        )

        if enable_thinking is not None:
            # `setdefault` both inserts the sub-dict when missing and returns
            # the existing value otherwise, so a single lookup suffices.
            chat_template_kwargs = self.generate_kwargs.setdefault(
                "chat_template_kwargs",
                {},
            )
            # Explicit raise instead of `assert`: asserts are stripped under
            # `python -O`, which would silently skip validating user input.
            if not isinstance(chat_template_kwargs, dict):
                raise TypeError(
                    "chat_template_kwargs must be a dictionary.",
                )
            chat_template_kwargs["enable_thinking"] = enable_thinking

        # change the client instance to the provided one
        self.client = openai_async_client