# -*- coding: utf-8 -*-
"""A general dialog agent."""
from typing import Optional, Union, Sequence, Any
from loguru import logger
from ..message import Msg
from .agent import AgentBase
class DialogAgent(AgentBase):
    """A simple agent used to perform a dialogue. You can set its role via
    `sys_prompt`."""

    def __init__(
        self,
        name: str,
        sys_prompt: str,
        model_config_name: str,
        use_memory: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize the dialog agent.

        Args:
            name (`str`):
                The name of the agent.
            sys_prompt (`str`):
                The system prompt of the agent, which can be passed by args
                or hard-coded in the agent.
            model_config_name (`str`):
                The name of the model config, which is used to load model from
                configuration.
            use_memory (`bool`, defaults to `True`):
                Whether the agent has memory.
            **kwargs (`Any`):
                Extra keyword arguments. They are not used by this agent; a
                warning is logged when any are given.
        """
        super().__init__(
            name=name,
            sys_prompt=sys_prompt,
            model_config_name=model_config_name,
            use_memory=use_memory,
        )

        # Extra keyword arguments are accepted but not forwarded to
        # `AgentBase`; warn so misspelled or unsupported options are visible.
        if kwargs:
            logger.warning(
                f"Unused keyword arguments are provided: {kwargs}",
            )

    def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg:
        """Reply function of the agent.

        Processes the input data, generates a prompt using the current
        dialogue memory and system prompt, and invokes the language model to
        produce a response. The response is then spoken and added to the
        dialogue memory.

        Args:
            x (`Optional[Union[Msg, Sequence[Msg]]]`, defaults to `None`):
                The input message(s) to the agent, which also can be omitted if
                the agent doesn't need any input.

        Returns:
            `Msg`: The output message generated by the agent, named after
            this agent and with `role="assistant"`.
        """
        # Record the input if this agent keeps a memory.
        if self.memory:
            self.memory.add(x)

        # Prefer the full dialogue history; fall back to the raw input when
        # there is no memory or the stored history is empty.
        history = (self.memory.get_memory() if self.memory else None) or x

        # Prepare the prompt: system prompt first, then the history/input.
        prompt = self.model.format(
            Msg("system", self.sys_prompt, role="system"),
            history,  # type: ignore[arg-type]
        )

        # Call the LLM and generate a response.
        response = self.model(prompt)

        # Print/speak the message in this agent's voice. `response.stream` is
        # used when truthy (streaming response); otherwise the complete text
        # is spoken.
        self.speak(response.stream or response.text)

        msg = Msg(self.name, response.text, role="assistant")

        # Record the generated message in memory as well.
        if self.memory:
            self.memory.add(msg)

        return msg