# -*- coding: utf-8 -*-
"""The model manager for AgentScope."""
import importlib
import json
import os
from typing import Any, Union, Type
from loguru import logger
from ..models import ModelWrapperBase, _BUILD_IN_MODEL_WRAPPERS
[docs]
class ModelManager:
"""The model manager for AgentScope, which is responsible for loading and
managing model configurations and models."""
_instance = None
model_configs: dict[str, dict] = {}
"""The model configs"""
model_wrapper_mapping: dict[str, Type[ModelWrapperBase]] = {}
"""The registered model wrapper classes."""
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
"""Create a singleton instance."""
if cls._instance is None:
cls._instance = super(ModelManager, cls).__new__(cls)
else:
raise RuntimeError(
"The Model manager has been initialized. Try to use "
"ModelManager.get_instance() to get the instance.",
)
return cls._instance
[docs]
@classmethod
def get_instance(cls) -> "ModelManager":
"""Get the instance of the singleton class."""
if cls._instance is None:
raise ValueError(
"AgentScope hasn't been initialized. Please call "
"`agentscope.init` function first.",
)
return cls._instance
def __init__(
self,
) -> None:
"""Initialize the model manager with model configs"""
self.model_configs = {}
self.model_wrapper_mapping = {}
for cls_name in _BUILD_IN_MODEL_WRAPPERS:
models_module = importlib.import_module("agentscope.models")
cls = getattr(models_module, cls_name)
if getattr(cls, "model_type", None):
self.register_model_wrapper_class(cls, exist_ok=False)
[docs]
def initialize(
self,
model_configs: Union[dict, str, list, None] = None,
) -> None:
"""Initialize the model manager with model configs."""
if model_configs is not None:
self.load_model_configs(model_configs)
[docs]
def clear_model_configs(self) -> None:
"""Clear the loaded model configs."""
self.model_configs.clear()
[docs]
def load_model_configs(
self,
model_configs: Union[dict, str, list],
clear_existing: bool = False,
) -> None:
"""read model configs from a path or a list of dicts.
Args:
model_configs (`Union[str, list, dict]`):
The path of the model configs | a config dict | a list of model
configs.
clear_existing (`bool`, defaults to `False`):
Whether to clear the loaded model configs before reading.
Returns:
`dict`:
The model configs.
"""
if clear_existing:
self.clear_model_configs()
cfgs = model_configs
# Load model configs from a path
if isinstance(cfgs, str):
if not os.path.exists(cfgs):
raise FileNotFoundError(
f"Cannot find the model configs file in the given path "
f"`{model_configs}`.",
)
with open(cfgs, "r", encoding="utf-8") as f:
cfgs = json.load(f)
# Load model configs from a dict or a list of dicts
if isinstance(cfgs, dict):
cfgs = [cfgs]
if isinstance(cfgs, list):
if not all(isinstance(_, dict) for _ in cfgs):
raise ValueError(
"The model config unit should be a dict.",
)
else:
raise TypeError(
f"Invalid type of model_configs, it could be a dict, a list "
f"of dicts, or a path to a json file (containing a dict or a "
f"list of dicts), but got {type(model_configs)}",
)
# Check and register the model configs
for cfg in cfgs:
# Check the format of model configs
if "config_name" not in cfg or "model_type" not in cfg:
raise ValueError(
"The `config_name` and `model_type` fields are required "
f"for model config, but got: {cfg}",
)
config_name = cfg["config_name"]
if config_name in self.model_configs:
logger.warning(
f"Config name [{config_name}] already exists.",
)
continue
self.model_configs[config_name] = cfg
# print the loaded model configs
logger.info(
"Load configs for model wrapper: {}",
", ".join(self.model_configs.keys()),
)
[docs]
def get_model_by_config_name(self, config_name: str) -> ModelWrapperBase:
"""Load the model by config name, and return the model wrapper."""
if len(self.model_configs) == 0:
raise ValueError(
"No model configs loaded, please call "
"`read_model_configs` first.",
)
# Find model config by name
if config_name not in self.model_configs:
raise ValueError(
f"Cannot find [{config_name}] in loaded configurations.",
)
config = self.model_configs.get(config_name, None)
if config is None:
raise ValueError(
f"Cannot find [{config_name}] in loaded configurations.",
)
model_type = config["model_type"]
if model_type not in self.model_wrapper_mapping:
raise ValueError(
f"Unsupported model_type `{model_type}`, currently supported "
f"model types: "
f"{', '.join(list(self.model_wrapper_mapping.keys()))}. ",
)
kwargs = {k: v for k, v in config.items() if k != "model_type"}
return self.model_wrapper_mapping[model_type](**kwargs)
[docs]
def get_config_by_name(self, config_name: str) -> Union[dict, None]:
"""Load the model config by name, and return the config dict."""
return self.model_configs.get(config_name, None)
[docs]
def state_dict(self) -> dict:
"""Serialize the model manager into a dict."""
return {
"model_configs": self.model_configs,
}
[docs]
def load_dict(self, data: dict) -> None:
"""Load the model manager from a dict."""
self.clear_model_configs()
assert "model_configs" in data
self.model_configs = data["model_configs"]
[docs]
def register_model_wrapper_class(
self,
model_wrapper_class: Type[ModelWrapperBase],
exist_ok: bool,
) -> None:
"""Register the model wrapper class.
Args:
model_wrapper_class (`Type[ModelWrapperBase]`):
The model wrapper class to be registered, which must inherit
from `ModelWrapperBase`.
exist_ok (`bool`):
Whether to overwrite the existing model wrapper with the same
name.
"""
if not issubclass(model_wrapper_class, ModelWrapperBase):
raise TypeError(
"The model wrapper class should inherit from "
f"ModelWrapperBase, but got {model_wrapper_class}.",
)
if not hasattr(model_wrapper_class, "model_type"):
raise ValueError(
f"The model wrapper class `{model_wrapper_class}` should "
f"have a `model_type` attribute.",
)
model_type = model_wrapper_class.model_type
if model_type in self.model_wrapper_mapping:
if exist_ok:
logger.warning(
f'Model wrapper "{model_type}" '
"already exists, overwrite it.",
)
self.model_wrapper_mapping[model_type] = model_wrapper_class
else:
raise ValueError(
f'Model wrapper "{model_type}" already exists, '
"please set `exist_ok=True` to overwrite it.",
)
else:
self.model_wrapper_mapping[model_type] = model_wrapper_class
[docs]
def flush(self) -> None:
"""Flush the model manager."""
self.clear_model_configs()