# -*- coding: utf-8 -*-
"""User Agent class"""
from typing import Union, Type
from typing import Optional
from pydantic import BaseModel
from ._agent import AgentBase
from ._user_input import TerminalUserInput, UserInputBase
from ..message import Msg
class UserAgent(AgentBase):
    """An agent that stands in for a human user.

    Instead of generating replies itself, the agent collects them through a
    pluggable user-input method (``TerminalUserInput`` by default) and wraps
    the result in a ``Msg`` with role ``"user"``.
    """

    _input_method: UserInputBase = TerminalUserInput()
    """The input method used to collect user replies. It is defined at class
    level, so a single instance is shared by every ``UserAgent`` unless it is
    overridden; replace it with your own ``UserInputBase`` implementation to
    read user input from a different source."""

    def __init__(
        self,
        name: str = "User",
        input_hint: str = "User Input: ",
        require_url: bool = False,
    ) -> None:
        """Initialize a UserAgent object.

        Args:
            name (`str`, defaults to `"User"`):
                The name of the agent.
            input_hint (`str`, defaults to `"User Input: "`):
                The hint shown when asking the user for input. Stored as
                ``self.input_hint``.
            require_url (`bool`, defaults to `False`):
                Whether the agent requires the user to input a URL, which
                can lead to a website, a file, or a directory. Stored as
                ``self.require_url``.
        """
        super().__init__(name=name)
        # Assign explicitly so the attribute is set regardless of how the
        # base class handles `name`.
        self.name = name
        # NOTE(review): `input_hint` and `require_url` are only stored here.
        # `reply` does not forward them to `_input_method`, and it does not
        # add a `url` field to the message it builds. Confirm whether they
        # are consumed elsewhere.
        self.input_hint = input_hint
        self.require_url = require_url

    def reply(
        self,
        x: Optional[Union[Msg, list[Msg]]] = None,
        structured_output: Optional[Type[BaseModel]] = None,
    ) -> Msg:
        """Receive input message(s) and return the user's reply.

        The reply is collected from the class-level ``_input_method``. To
        receive user input from your own source, implement a subclass of
        ``UserInputBase`` and install an instance of it as the agent's input
        method.

        Args:
            x (`Optional[Union[Msg, list[Msg]]]`, defaults to `None`):
                The input message(s) to the agent. Can be omitted if the
                agent does not need any input. When the agent has a memory,
                ``x`` is added to it before the user is asked for input.
            structured_output (`Optional[Type[BaseModel]]`, defaults to \
`None`):
                A subclass of ``pydantic.BaseModel`` that defines a
                structured output format. It is passed to the input method
                as ``structured_schema`` so the user can be prompted for the
                required fields. The structured data returned by the input
                method is placed in the ``metadata`` field of the reply.

        Returns:
            `Msg`: The reply message, created with role ``"user"``. It is
            spoken via ``self.speak`` and, when the agent has a memory,
            added to it before being returned.
        """
        if self.memory:
            self.memory.add(x)

        # Get the input from the configured input method.
        input_data = self._input_method(
            agent_id=self.agent_id,
            agent_name=self.name,
            structured_schema=structured_output,
        )

        blocks_input = input_data.blocks_input
        # Collapse a lone text block into a plain string, so a simple typed
        # reply yields string content rather than a one-element block list.
        if (
            blocks_input
            and len(blocks_input) == 1
            and blocks_input[0].get("type") == "text"
        ):
            blocks_input = blocks_input[0].get("text")

        msg = Msg(
            self.name,
            content=blocks_input,
            role="user",
            metadata=input_data.structured_input,
        )
        self.speak(msg)

        if self.memory:
            self.memory.add(msg)

        return msg