# -*- coding: utf-8 -*-
# pylint: disable=C0301
""" Utils for mixing model's answers in agentscope """
from typing import Union, List, Sequence, Tuple
import concurrent.futures
from loguru import logger
from agentscope.manager import ModelManager
from agentscope.message import Msg
from agentscope.models import ModelWrapperBase
# Prompt adapted from the project [MoA](https://github.com/togethercomputer/MoA)
DEFAULT_AGGREGATOR_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
Responses from models:""" # noqa
class MixtureOfAgents:
"""
The MoA model that take multiple models and aggregate their responses,
leverages the collective strengths of multiple LLMs to enhance performance.
Reference from the project [MoA](https://github.com/togethercomputer/MoA).
"""
def __init__(
self,
main_model: Union[str, ModelWrapperBase],
reference_models: List[Union[str, ModelWrapperBase]],
rounds: int = 1,
aggregator_prompt: str = DEFAULT_AGGREGATOR_PROMPT,
show_internal: bool = False,
) -> None:
"""
Args:
main_model (`Union[str, ModelWrapperBase]`):
The main_model will make the final aggregation in the last
round, summarizing all the previous responses from models.
Can take both config name of model or model instance as input.
reference_models (`List[Union[str, ModelWrapperBase]]`):
The reference_models used for generating different responses
in each round.
Can take both config name of model or model instance as input.
We encourage using different models to get better diversity.
Empirically, responses generated by heterogeneous models
contribute more than those produced by the same model.
rounds (`int`):
The number of processing rounds to refine the responses.
Can range from 0 to inf.
aggregator_prompt (`str`):
The prompt used for aggregating responses.
Using the prompt from paper MoA by default.
show_internal (`bool`):
Whether to show the internal process of MoA.
"""
model_manager = ModelManager.get_instance()
# init main_model
if isinstance(main_model, str):
self.main_model = model_manager.get_model_by_config_name(
main_model,
)
elif isinstance(main_model, ModelWrapperBase):
self.main_model = main_model
else:
raise ValueError(
"main_model must be a string or a ModelWrapperBase instance",
)
# init reference_models
self.reference_models: List[ModelWrapperBase] = []
for ref_model in reference_models:
if isinstance(ref_model, str):
self.reference_models.append(
model_manager.get_model_by_config_name(ref_model),
)
elif isinstance(ref_model, ModelWrapperBase):
self.reference_models.append(ref_model)
else:
raise ValueError(
"reference_models must be a list of strings "
"or ModelWrapperBase instances",
)
        # Placeholder for the latest response from each reference model.
        self.references: List[str] = [""] * len(self.reference_models)
self.rounds = rounds
self.aggregator_prompt = aggregator_prompt
self.show_internal = show_internal

    def _get_res_with_aggregate_model(
        self,
        aggre_model: ModelWrapperBase,
    ) -> str:
        """Aggregate the current references into a single response with
        the given model."""
        messages = [
            Msg(role="system", content=self.aggregator_prompt, name="system"),
        ]
        for i, ref in enumerate(self.references):
            messages.append(
                Msg(
                    role="user",
                    content=ref,
                    name=f"Model_{i}",
                ),
            )
        aggre_format_msg = aggre_model.format(messages)
        aggre_res = aggre_model(aggre_format_msg)
        return aggre_res.text

    def __call__(
        self,
        *args: Union[Msg, Sequence[Msg]],
    ) -> str:
        """
        Get the MoA response for the given messages.
        Calling this object is equivalent to calling a single model with:
        ```
        format_msg = model.format(messages)
        return model(format_msg)
        ```
        except that the responses are refined over multiple rounds and
        aggregated by the main model.
        Args:
            *args (`Union[Msg, Sequence[Msg]]`):
                The messages to be sent to the models.
        """

        def _process_reference(
            i: int,
            ref_model: ModelWrapperBase,
            *args: Union[Msg, Sequence[Msg]],
        ) -> Tuple[int, str]:
            # Query one reference model with the original messages.
            format_msg = ref_model.format(*args)
            ref_model_res = ref_model(format_msg)
            return i, ref_model_res.text

        def _process_new_refs(
            i: int,
            ref_model: ModelWrapperBase,
        ) -> Tuple[int, str]:
            # Refine the current references with one reference model.
            return i, self._get_res_with_aggregate_model(ref_model)

        # Round 0: collect the initial references in parallel.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(_process_reference, i, ref_model, *args)
                for i, ref_model in enumerate(self.reference_models)
            ]
            for future in concurrent.futures.as_completed(futures):
                i, result = future.result()
                self.references[i] = result
                if self.show_internal:
                    logger.info(f"Round 0, Model_{i}: {result}")

        # Refinement rounds: each reference model aggregates the previous
        # round's references into a new response.
        for r in range(self.rounds):
            if self.show_internal:
                logger.info("=" * 20)
            new_refs = [""] * len(self.reference_models)
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(_process_new_refs, i, ref_model)
                    for i, ref_model in enumerate(self.reference_models)
                ]
                for future in concurrent.futures.as_completed(futures):
                    i, result = future.result()
                    new_refs[i] = result
                    if self.show_internal:
                        logger.info(f"Round {r + 1}, Model_{i}: {result}")
            self.references = new_refs

        # Final aggregation of all references by the main model.
        final_res = self._get_res_with_aggregate_model(self.main_model)
        return final_res
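

# A minimal end-to-end sketch, assuming agentscope has been initialized
# with model configs named "model_a", "model_b" and "model_c" (hypothetical
# placeholders, as is the config file path; substitute your own):
if __name__ == "__main__":
    import agentscope

    agentscope.init(model_configs="./model_configs.json")
    moa = MixtureOfAgents(
        main_model="model_a",
        reference_models=["model_a", "model_b", "model_c"],
        rounds=1,
        show_internal=True,
    )
    answer = moa(
        Msg(role="user", content="What is Mixture-of-Agents?", name="user"),
    )
    print(answer)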