Source code for agentscope.agent._user_input

# -*- coding: utf-8 -*-
"""The user input related classes."""
import asyncio
import json.decoder
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from queue import Queue
from threading import Event
from typing import Any, Type, List

import json5
import jsonschema
import requests
import shortuuid
import socketio
from pydantic import BaseModel

from .. import _config
from .._logging import logger
from ..message import (
    TextBlock,
    VideoBlock,
    AudioBlock,
    ImageBlock,
)


@dataclass
class UserInputData:
    """The user input data collected from a user-input source."""

    blocks_input: (
        List[TextBlock | ImageBlock | AudioBlock | VideoBlock] | None
    ) = None
    """The content blocks (text, image, audio or video) input by the user"""

    structured_input: dict[str, Any] | None = None
    """The structured input from the user"""
class UserInputBase(ABC):
    """The base class used to handle the user input from different sources.

    Subclasses must implement the asynchronous `__call__` method. The class
    derives from `ABC` so that `@abstractmethod` is actually enforced:
    without it, the decorator has no effect and an incomplete subclass (or
    the base class itself) could be instantiated.
    """

    @abstractmethod
    async def __call__(
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_model: Type[BaseModel] | None = None,
        **kwargs: Any,
    ) -> UserInputData:
        """The user input method, which returns the user input and the
        required structured data.

        Args:
            agent_id (`str`):
                The agent identifier.
            agent_name (`str`):
                The agent name.
            structured_model (`Type[BaseModel] | None`, optional):
                A base model class that defines the structured input format.

        Returns:
            `UserInputData`:
                The user input data.
        """
class TerminalUserInput(UserInputBase):
    """The terminal user input."""

    # JSON-schema types whose values cannot be represented by the raw input
    # string, so the typed text is parsed with json5 before validation.
    _PARSED_TYPES = frozenset(
        {"integer", "number", "boolean", "array", "object", "null"},
    )

    def __init__(self, input_hint: str = "User Input: ") -> None:
        """Initialize the terminal user input with a hint.

        Args:
            input_hint (`str`, defaults to `"User Input: "`):
                The prompt shown before reading the text input.
        """
        self.input_hint = input_hint

    async def __call__(
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_model: Type[BaseModel] | None = None,
        **kwargs: Any,
    ) -> UserInputData:
        """Handle the user input from the terminal.

        Args:
            agent_id (`str`):
                The agent identifier.
            agent_name (`str`):
                The agent name.
            structured_model (`Type[BaseModel] | None`, optional):
                A base model class that defines the structured input format.

        Returns:
            `UserInputData`:
                The user input data.
        """
        text_input = input(self.input_hint)

        structured_input = None
        if structured_model is not None:
            structured_input = self._read_structured_input(structured_model)

        return UserInputData(
            blocks_input=[TextBlock(type="text", text=text_input)],
            structured_input=structured_input,
        )

    def _read_structured_input(
        self,
        structured_model: Type[BaseModel],
    ) -> dict[str, Any]:
        """Prompt for every property of `structured_model` and return the
        validated values keyed by property name."""
        json_schema = structured_model.model_json_schema()
        required = json_schema.get("required", [])
        structured_input: dict[str, Any] = {}

        print("Structured input (press Enter to skip for optional):")
        for key, item in json_schema.get("properties", {}).items():
            # Show the property constraints (type, default, ...) but not the
            # title; not every property schema carries a title.
            requirements = {k: v for k, v in item.items() if k != "title"}
            while True:
                res = input(f"\t{key} ({requirements}): ")

                if res == "":
                    if key in required:
                        print(f"Key {key} is required.")
                        continue
                    if "default" not in item:
                        # No default in the schema (e.g. a `default_factory`
                        # field): leave the key out instead of looping on a
                        # `None` that may never validate.
                        break
                    # The default is already a typed value; never re-parse it.
                    value = item["default"]
                elif self._needs_parsing(item):
                    try:
                        value = json5.loads(res)
                    except ValueError as e:
                        # json5 reports parse errors as plain ValueError
                        # (json.JSONDecodeError is a subclass of it).
                        print(
                            "\033[31mInvalid input with error:\n"
                            "```\n"
                            f"{e}\n"
                            "```\033[0m",
                        )
                        continue
                else:
                    value = res

                try:
                    jsonschema.validate(value, item)
                except jsonschema.ValidationError as e:
                    print(
                        f"\033[31mValidation error:\n```\n{e}\n```\033[0m",
                    )
                    time.sleep(0.5)
                    continue

                structured_input[key] = value
                break

        return structured_input

    @classmethod
    def _declared_types(cls, schema: dict[str, Any]) -> set[str]:
        """Collect the lower-cased JSON-schema type names declared by a
        property schema, looking through `anyOf`/`oneOf` (which pydantic
        emits for `Optional`/union fields)."""
        declared = schema.get("type")
        if isinstance(declared, str):
            return {declared.lower()}
        if isinstance(declared, list):
            return {str(name).lower() for name in declared}

        types: set[str] = set()
        for sub_schema in [*schema.get("anyOf", []), *schema.get("oneOf", [])]:
            if isinstance(sub_schema, dict):
                types |= cls._declared_types(sub_schema)
        return types

    @classmethod
    def _needs_parsing(cls, schema: dict[str, Any]) -> bool:
        """Whether the raw input string must be parsed before validation:
        true when the schema admits a non-string JSON type and no string."""
        types = cls._declared_types(schema)
        return bool(types & cls._PARSED_TYPES) and "string" not in types
class StudioUserInput(UserInputBase):
    """The class that host the user input on the AgentScope Studio."""

    _websocket_namespace: str = "/python"

    def __init__(
        self,
        studio_url: str,
        run_id: str,
        max_retries: int = 3,
        reconnect_attempts: int = 3,
        reconnection_delay: int = 1,
        reconnection_delay_max: int = 5,
    ) -> None:
        """Initialize the StudioUserInput object and connect to the Studio.

        Args:
            studio_url (`str`):
                The URL of the AgentScope Studio.
            run_id (`str`):
                The current run identity.
            max_retries (`int`, defaults to `3`):
                The maximum number of retries for sending the user-input
                request to AgentScope Studio.
            reconnect_attempts (`int`, defaults to `3`):
                Passed to `socketio.Client` as `reconnection_attempts`.
            reconnection_delay (`int`, defaults to `1`):
                Passed to `socketio.Client` as `reconnection_delay`.
            reconnection_delay_max (`int`, defaults to `5`):
                Passed to `socketio.Client` as `reconnection_delay_max`.

        Raises:
            `RuntimeError`:
                If the websocket connection to AgentScope Studio fails.
        """
        self._is_connected = False
        self._is_reconnecting = False
        self.studio_url = studio_url
        self.run_id = run_id
        self.max_retries = max_retries

        # Init Websocket
        self.sio = socketio.Client(
            reconnection=True,
            reconnection_attempts=reconnect_attempts,
            reconnection_delay=reconnection_delay,
            reconnection_delay_max=reconnection_delay_max,
        )

        # Pending requests, keyed by request id. The websocket handler puts
        # the reply into the queue and sets the event to wake `__call__`.
        self.input_queues: dict[str, Queue] = {}
        self.input_events: dict[str, Event] = {}

        namespace = self._websocket_namespace

        @self.sio.on("connect", namespace=namespace)
        def on_connect() -> None:
            self._is_connected = True
            logger.info(
                'Connected to AgentScope Studio at %s with project name "%s" '
                'and run id "%s".',
                self.studio_url,
                _config.project,
                self.run_id,
            )
            logger.info(
                "View the run at: %s/dashboard/projects/%s",
                self.studio_url,
                _config.project,
            )

        @self.sio.on("disconnect", namespace=namespace)
        def on_disconnect() -> None:
            self._is_connected = False
            logger.info(
                "Disconnected from AgentScope Studio at %s",
                self.studio_url,
            )

        # NOTE(review): the reconnect* handlers below follow the Socket.IO
        # JS client event names; confirm the python-socketio client actually
        # emits them, otherwise `_is_reconnecting` is never set.
        @self.sio.on("reconnect", namespace=namespace)
        def on_reconnect(attempt_number: int) -> None:
            self._is_connected = True
            self._is_reconnecting = False
            logger.info(
                "Reconnected to AgentScope Studio at %s with run_id %s after "
                "%d attempts",
                self.studio_url,
                self.run_id,
                attempt_number,
            )

        @self.sio.on("reconnect_attempt", namespace=namespace)
        def on_reconnect_attempt(attempt_number: int) -> None:
            self._is_reconnecting = True
            logger.info(
                "Attempting to reconnect to AgentScope Studio at %s "
                "(attempt %d)",
                self.studio_url,
                attempt_number,
            )

        @self.sio.on("reconnect_failed", namespace=namespace)
        def on_reconnect_failed() -> None:
            self._is_reconnecting = False
            logger.error(
                "Failed to reconnect to AgentScope Studio at %s",
                self.studio_url,
            )

        @self.sio.on("reconnect_error", namespace=namespace)
        def on_reconnect_error(error: Any) -> None:
            logger.error(
                "Error while reconnecting to AgentScope Studio at %s: %s",
                self.studio_url,
                str(error),
            )

        # The AgentScope Studio backend sends the "forwardUserInput" event
        # to the current python run
        @self.sio.on("forwardUserInput", namespace=namespace)
        def receive_user_input(
            request_id: str,
            blocks_input: List[
                TextBlock | ImageBlock | AudioBlock | VideoBlock
            ],
            structured_input: dict[str, Any],
        ) -> None:
            # Look both up once: `__call__` may drop the entries
            # concurrently (e.g. on cancellation), which must not raise here.
            queue = self.input_queues.get(request_id)
            event = self.input_events.get(request_id)
            if queue is None or event is None:
                return
            queue.put(
                UserInputData(
                    blocks_input=blocks_input,
                    structured_input=structured_input,
                ),
            )
            event.set()

        try:
            self.sio.connect(
                f"{self.studio_url}",
                namespaces=[self._websocket_namespace],
                auth={"run_id": self.run_id},
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to connect to AgentScope Studio at {self.studio_url}",
            ) from e

    def _ensure_connected(
        self,
        timeout: float = 30.0,
        check_interval: float = 5.0,
    ) -> None:
        """Ensure the connection is established or wait for reconnection.

        Args:
            timeout (`float`):
                Maximum time to wait for reconnection in seconds. Defaults to
                30.0.
            check_interval (`float`):
                Interval between connection checks in seconds. Defaults to
                5.0.

        Raises:
            `RuntimeError`:
                If connection cannot be established within timeout.
        """
        if self._is_connected:
            return

        if self._is_reconnecting:
            # Monotonic clock: elapsed time is immune to wall-clock changes.
            start_time = time.monotonic()
            while self._is_reconnecting:
                elapsed_time = time.monotonic() - start_time
                if elapsed_time > timeout:
                    raise RuntimeError(
                        f"Reconnection timeout after {elapsed_time} seconds",
                    )

                logger.info(
                    "Waiting for reconnection... (%.1fs / %.1fs)",
                    elapsed_time,
                    timeout,
                )
                # Do not sleep past the deadline.
                time.sleep(min(check_interval, timeout - elapsed_time))

            # After reconnection attempt completed, check final status
            if self._is_connected:
                return

        # Not connected and not reconnecting
        raise RuntimeError(
            f"Not connected to AgentScope Studio at {self.studio_url}.",
        )

    def _request_user_input(
        self,
        request_id: str,
        agent_id: str,
        agent_name: str,
        structured_input: dict[str, Any] | None,
    ) -> None:
        """Ask AgentScope Studio to collect a user input, retrying up to
        `max_retries` times after the first attempt.

        Raises:
            `RuntimeError`:
                If every attempt fails.
        """
        n_retry = 0
        while True:
            try:
                response = requests.post(
                    f"{self.studio_url}/trpc/requestUserInput",
                    json={
                        "requestId": request_id,
                        "runId": self.run_id,
                        "agentId": agent_id,
                        "agentName": agent_name,
                        "structuredInput": structured_input,
                    },
                )
                response.raise_for_status()
                return
            except Exception as e:
                if n_retry >= self.max_retries:
                    raise RuntimeError(
                        "Failed to get user input from AgentScope Studio",
                    ) from e
                n_retry += 1

    async def __call__(  # type: ignore[override]
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_model: Type[BaseModel] | None = None,
    ) -> UserInputData:
        """Get the user input from AgentScope Studio.

        Args:
            agent_id (`str`):
                The identity of the agent.
            agent_name (`str`):
                The name of the agent.
            structured_model (`Type[BaseModel] | None`, optional):
                The base model class of the structured input.

        Raises:
            `RuntimeError`:
                Failed to get user input from AgentScope Studio.

        Returns:
            `UserInputData`:
                The user input.
        """
        self._ensure_connected()

        request_id = shortuuid.uuid()
        queue: Queue = Queue()
        event = Event()
        self.input_queues[request_id] = queue
        self.input_events[request_id] = event

        if structured_model is None:
            structured_input = None
        else:
            structured_input = structured_model.model_json_schema()

        try:
            # Blocking HTTP and waiting run in worker threads so the event
            # loop stays responsive while the user types in the Studio.
            await asyncio.to_thread(
                self._request_user_input,
                request_id,
                agent_id,
                agent_name,
                structured_input,
            )
            await asyncio.to_thread(event.wait)
            # The handler puts the reply before setting the event.
            return queue.get()
        finally:
            # Always drop the pending entries, including when the request
            # failed, so they do not accumulate.
            self.input_queues.pop(request_id, None)
            self.input_events.pop(request_id, None)
            # Release the waiting worker thread if we leave without a reply
            # (e.g. request failure or task cancellation).
            event.set()

    def __del__(self) -> None:
        """Cleanup socket connection when the object is destroyed."""
        # `sio` is missing if `__init__` failed before creating the client.
        sio = getattr(self, "sio", None)
        if sio is None:
            return
        try:
            sio.disconnect()
        except Exception as e:
            logger.error(
                "Failed to disconnect from AgentScope Studio at %s: %s",
                self.studio_url,
                str(e),
            )