# -*- coding: utf-8 -*-"""The main entry point for agent learning."""fromdataclassesimportdataclassfrom._workflowimport(WorkflowType,_validate_function_signature,)
def tune(workflow_func: WorkflowType, config_path: str) -> None:
    """Train the agent workflow with the specific configuration.

    Args:
        workflow_func (WorkflowType): The learning workflow function to
            execute.
        config_path (str): Path to the YAML configuration file for the
            learning process.

    Raises:
        ImportError: If the Trinity-RFT launcher/config modules or
            ``omegaconf`` cannot be imported.
        ValueError: If ``workflow_func`` does not pass
            ``_validate_function_signature``, or if the YAML content cannot
            be merged into / converted to the ``TuneConfig`` schema.
    """
    # Trinity-RFT is an optional dependency, so import it lazily and turn a
    # missing install into an actionable error message.
    try:
        from trinity.cli.launcher import run_stage
        from trinity.common.config import Config
        from omegaconf import OmegaConf
    except ImportError as e:
        raise ImportError(
            "Trinity-RFT is not installed. Please install it with "
            "`pip install trinity-rft`.",
        ) from e

    if not _validate_function_signature(workflow_func):
        raise ValueError(
            "Invalid workflow function signature, please "
            "check the types of your workflow input/output.",
        )

    # Defined inside the function because its base class ``Config`` is only
    # available after the lazy import above succeeds.
    @dataclass
    class TuneConfig(Config):
        """Configuration for learning process."""

        def to_trinity_config(self, workflow_func: WorkflowType) -> Config:
            """Convert to Trinity-RFT compatible configuration.

            Args:
                workflow_func (WorkflowType): The workflow function to hand
                    over to Trinity-RFT through the taskset's
                    ``workflow_args``.

            Returns:
                Config: The result of ``self.check_and_update()``.
            """
            workflow_name = "agentscope_workflow_adapter"
            # Point both the taskset-level and explorer-input-level default
            # workflow type at the same registered workflow name.
            self.buffer.explorer_input.taskset.default_workflow_type = (
                workflow_name
            )
            self.buffer.explorer_input.default_workflow_type = workflow_name
            # The user's function travels inside workflow_args; presumably
            # the "agentscope_workflow_adapter" workflow reads it from the
            # "workflow_func" key — confirm against the adapter.
            self.buffer.explorer_input.taskset.workflow_args[
                "workflow_func"
            ] = workflow_func
            return self.check_and_update()

        @classmethod
        def load_config(cls, config_path: str) -> "TuneConfig":
            """Load the learning configuration from a YAML file.

            Args:
                config_path (str): The path to the configuration file.

            Returns:
                TuneConfig: The loaded learning configuration.

            Raises:
                ValueError: If merging the YAML content into the schema, or
                    converting the merged result to an object, fails. Errors
                    raised while reading/parsing the file itself are not
                    wrapped and propagate unchanged.
            """
            schema = OmegaConf.structured(cls)
            yaml_config = OmegaConf.load(config_path)
            try:
                config = OmegaConf.merge(schema, yaml_config)
                return OmegaConf.to_object(config)
            except Exception as e:
                raise ValueError(f"Invalid configuration: {e}") from e

    trinity_config = TuneConfig.load_config(config_path).to_trinity_config(
        workflow_func,
    )
    # ``tune`` is declared to return None, so the launcher's result is not
    # forwarded to the caller.
    run_stage(config=trinity_config)