agentscope.agents._user_input 源代码
# -*- coding: utf-8 -*-
"""The user input module."""
import time
from abc import ABC, abstractmethod
from queue import Queue
from threading import Event
from typing import Optional, Any
import json5
import jsonschema
import requests
import shortuuid
import socketio
from loguru import logger
from pydantic import BaseModel
from ..message import ContentBlock, TextBlock
[文档]
class UserInputData(BaseModel):
    """Container for the data a user returns in response to an input request.

    Both fields are optional. Which of them is populated depends on the
    input source and on whether structured input was requested.
    """

    # Content blocks entered by the user (for example text blocks), or
    # ``None`` when no such input was provided.
    blocks_input: Optional[list[ContentBlock]] = None

    # Structured field values supplied by the user, keyed by field name, or
    # ``None`` when no structured input was requested.
    structured_input: Optional[dict[str, Any]] = None
[文档]
class UserInputBase(ABC):
    """Abstract interface shared by every user input source.

    Concrete subclasses implement ``__call__`` to collect input on behalf of
    an agent and hand it back as a ``UserInputData`` instance.
    """

    @abstractmethod
    def __call__(
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_schema: Optional[BaseModel] = None,
        **kwargs: dict,
    ) -> UserInputData:
        """Collect input from the user for the given agent.

        Args:
            agent_id (`str`):
                Identifier of the agent that requests the input.
            agent_name (`str`):
                Name of the agent that requests the input.
            structured_schema (`Optional[BaseModel]`, defaults to `None`):
                The base model class describing the structured fields to
                collect, or `None` when no structured input is needed.

        Returns:
            `UserInputData`:
                The input collected from the user.
        """
[文档]
class StudioUserInput(UserInputBase):
    """User input source backed by AgentScope Studio.

    An input request is posted to the Studio over HTTP. The answer arrives
    later over a Socket.IO connection as a ``forwardUserInput`` event and is
    matched to the waiting request by its request id.
    """

    _websocket_namespace: str = "/python"

    def __init__(
        self,
        studio_url: str,
        run_id: str,
        max_retries: int = 3,
        reconnect_attempts: int = 3,
        reconnection_delay: int = 1,
        reconnection_delay_max: int = 5,
    ) -> None:
        """Initialize the StudioUserInput object and connect to the Studio.

        Args:
            studio_url (`str`):
                The URL of the AgentScope Studio.
            run_id (`str`):
                The current run identity.
            max_retries (`int`, defaults to `3`):
                The maximum number of retries when posting an input request
                to the Studio.
            reconnect_attempts (`int`, defaults to `3`):
                Passed to the Socket.IO client as `reconnection_attempts`.
            reconnection_delay (`int`, defaults to `1`):
                Passed to the Socket.IO client as `reconnection_delay`.
            reconnection_delay_max (`int`, defaults to `5`):
                Passed to the Socket.IO client as `reconnection_delay_max`.

        Raises:
            `RuntimeError`:
                If the initial connection to the Studio fails.
        """
        self._is_connected = False
        self._is_reconnecting = False
        self.studio_url = studio_url
        self.run_id = run_id
        self.max_retries = max_retries

        # Init Websocket (Socket.IO client with reconnection enabled)
        self.sio = socketio.Client(
            reconnection=True,
            reconnection_attempts=reconnect_attempts,
            reconnection_delay=reconnection_delay,
            reconnection_delay_max=reconnection_delay_max,
        )

        # Per-request mailboxes keyed by request id. The socket handler puts
        # the received data into the queue and then sets the event to wake
        # up the caller blocked in ``__call__``.
        self.input_queues: dict[str, Queue] = {}
        self.input_events: dict[str, Event] = {}

        @self.sio.on("connect", namespace=self._websocket_namespace)
        def on_connect() -> None:
            self._is_connected = True
            # A successful (re)connection ends any reconnection phase.
            self._is_reconnecting = False
            logger.info(
                f"Connected to AgentScope Studio at {self.studio_url} with "
                f"run_id {run_id}",
            )

        @self.sio.on("disconnect", namespace=self._websocket_namespace)
        def on_disconnect() -> None:
            self._is_connected = False
            logger.info(
                f"Disconnected from AgentScope Studio at {self.studio_url}",
            )

        @self.sio.on("reconnect", namespace=self._websocket_namespace)
        def on_reconnect(attempt_number: int) -> None:
            self._is_connected = True
            self._is_reconnecting = False
            logger.info(
                f"Reconnected to AgentScope Studio at {self.studio_url} "
                f"with run_id {self.run_id} after {attempt_number} attempts",
            )

        @self.sio.on("reconnect_attempt", namespace=self._websocket_namespace)
        def on_reconnect_attempt(attempt_number: int) -> None:
            self._is_reconnecting = True
            logger.info(
                f"Attempting to reconnect to AgentScope Studio at "
                f"{self.studio_url} (attempt {attempt_number})",
            )

        @self.sio.on("reconnect_failed", namespace=self._websocket_namespace)
        def on_reconnect_failed() -> None:
            self._is_reconnecting = False
            logger.error(
                f"Failed to reconnect to AgentScope "
                f"Studio at {self.studio_url}",
            )

        @self.sio.on("reconnect_error", namespace=self._websocket_namespace)
        def on_reconnect_error(error: Any) -> None:
            logger.error(
                "Error while reconnecting to AgentScope Studio at "
                f"{self.studio_url}: {str(error)}",
            )

        # The AgentScope Studio backend sends the "forwardUserInput" event to
        # this python run, delivering the user input for a request id.
        @self.sio.on("forwardUserInput", namespace=self._websocket_namespace)
        def receive_user_input(
            request_id: str,
            blocks_input: list[ContentBlock],
            structured_input: dict[str, Any],
        ) -> None:
            # Look both entries up once, so a concurrent cleanup in
            # ``__call__`` cannot cause a KeyError between check and use.
            queue = self.input_queues.get(request_id)
            event = self.input_events.get(request_id)
            if queue is None or event is None:
                # Unknown or already-served request: nothing to deliver.
                return
            queue.put(
                UserInputData(
                    blocks_input=blocks_input,
                    structured_input=structured_input,
                ),
            )
            event.set()

        try:
            self.sio.connect(
                f"{self.studio_url}",
                namespaces=[self._websocket_namespace],
                auth={"run_id": self.run_id},
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to connect to AgentScope Studio at {self.studio_url}",
            ) from e

    def _ensure_connected(
        self,
        timeout: float = 30.0,
        check_interval: float = 5.0,
    ) -> None:
        """Ensure the connection is up, waiting for an in-progress
        reconnection if necessary.

        Args:
            timeout (`float`, defaults to `30.0`):
                Maximum time to wait for an in-progress reconnection, in
                seconds.
            check_interval (`float`, defaults to `5.0`):
                Interval between connection checks, in seconds.

        Raises:
            `RuntimeError`:
                If the reconnection does not finish within ``timeout``, or
                if the client ends up neither connected nor reconnecting.
        """
        if self._is_connected:
            return

        if self._is_reconnecting:
            # Monotonic clock: elapsed time is immune to wall-clock jumps.
            start_time = time.monotonic()
            # Stop waiting as soon as the reconnection phase ends or a
            # connection is (re)established.
            while self._is_reconnecting and not self._is_connected:
                elapsed_time = time.monotonic() - start_time
                if elapsed_time >= timeout:
                    raise RuntimeError(
                        "Reconnection timeout after "
                        f"{elapsed_time:.1f} seconds",
                    )

                logger.info(
                    "Waiting for reconnection... "
                    f"({elapsed_time:.1f}s / {timeout}s)",
                )
                # Never sleep past the deadline.
                time.sleep(min(check_interval, timeout - elapsed_time))

            # After the reconnection phase, check the final status
            if self._is_connected:
                return

        # Not connected and not reconnecting
        raise RuntimeError(
            f"Not connected to AgentScope Studio at {self.studio_url}",
        )

    def _post_input_request(
        self,
        request_id: str,
        agent_id: str,
        agent_name: str,
        structured_input: Optional[dict[str, Any]],
    ) -> None:
        """Post an input request to the Studio, retrying on failure.

        The request is attempted once, plus up to ``self.max_retries``
        retries.

        Args:
            request_id (`str`):
                The id under which the answer will be delivered.
            agent_id (`str`):
                The identity of the agent.
            agent_name (`str`):
                The name of the agent.
            structured_input (`Optional[dict[str, Any]]`):
                The JSON schema of the requested structured input, or
                `None`.

        Raises:
            `RuntimeError`:
                If every attempt fails.
        """
        n_retry = 0
        while True:
            try:
                response = requests.post(
                    f"{self.studio_url}/trpc/requestUserInput",
                    json={
                        "requestId": request_id,
                        "runId": self.run_id,
                        "agentId": agent_id,
                        "agentName": agent_name,
                        "structuredInput": structured_input,
                    },
                )
                response.raise_for_status()
                return
            except Exception as e:
                if n_retry >= self.max_retries:
                    raise RuntimeError(
                        "Failed to get user input from AgentScope Studio",
                    ) from e
                n_retry += 1
                logger.warning(
                    "Failed to request user input from AgentScope Studio "
                    f"(retry {n_retry}/{self.max_retries}): {e}",
                )

    def __call__(  # type: ignore[override]
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_schema: Optional[BaseModel] = None,
    ) -> UserInputData:
        """Get the user input from AgentScope Studio.

        Blocks until the Studio delivers the answer for this request.

        Args:
            agent_id (`str`):
                The identity of the agent.
            agent_name (`str`):
                The name of the agent.
            structured_schema (`Optional[BaseModel]`, defaults to `None`):
                The base model class of the structured input. Its JSON
                schema is sent to the Studio.

        Raises:
            `RuntimeError`:
                Not connected to AgentScope Studio, or failed to request
                user input from it.

        Returns:
            `UserInputData`:
                The user input.
        """
        self._ensure_connected()

        # Send the model's JSON schema. The model class itself is not JSON
        # serializable and would make every request fail.
        structured_input = (
            structured_schema.model_json_schema()
            if structured_schema is not None
            else None
        )

        request_id = shortuuid.uuid()
        queue: Queue = Queue()
        event = Event()
        # Register the mailbox before posting, so a fast answer is not lost.
        self.input_queues[request_id] = queue
        self.input_events[request_id] = event

        # The mailbox is always removed, including when posting fails.
        try:
            self._post_input_request(
                request_id,
                agent_id,
                agent_name,
                structured_input,
            )
            event.wait()
            return queue.get()
        finally:
            self.input_queues.pop(request_id, None)
            self.input_events.pop(request_id, None)

    def __del__(self) -> None:
        """Clean up the socket connection when the object is destroyed."""
        # ``__init__`` may have failed before the client was created.
        sio = getattr(self, "sio", None)
        if sio is None:
            return
        try:
            sio.disconnect()
        except Exception as e:
            logger.error(
                f"Failed to disconnect from AgentScope Studio at "
                f"{self.studio_url}: {str(e)}",
            )
[文档]
class TerminalUserInput(UserInputBase):
    """User input source that reads from the terminal with ``input()``."""

    def __init__(self, input_prefix: str = "User Input: ") -> None:
        """Initialize the terminal user input.

        Args:
            input_prefix (`str`, defaults to `"User Input: "`):
                The prompt printed before reading the free-form text.
        """
        self.input_prefix = input_prefix

    def __call__(  # type: ignore[override]
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_schema: Optional[BaseModel] = None,
    ) -> UserInputData:
        """Read free-form text and, optionally, structured fields from the
        terminal.

        Args:
            agent_id (`str`):
                The identity of the agent (not used by this input source).
            agent_name (`str`):
                The name of the agent (not used by this input source).
            structured_schema (`Optional[BaseModel]`, defaults to `None`):
                The base model class whose JSON-schema properties are
                prompted for one by one after the text input.

        Returns:
            `UserInputData`:
                One text block holding the typed text, plus the validated
                structured values (`None` when no schema was given).
        """
        text_input = input(self.input_prefix)

        structured_input = None
        if structured_schema is not None:
            structured_input = self._read_structured_input(structured_schema)

        return UserInputData(
            blocks_input=[TextBlock(type="text", text=text_input)],
            structured_input=structured_input,
        )

    def _read_structured_input(
        self,
        structured_schema: type[BaseModel],
    ) -> dict[str, Any]:
        """Prompt for every property of the model's JSON schema in order."""
        json_schema = structured_schema.model_json_schema()
        required = json_schema.get("required", [])
        # Shared definitions (e.g. nested models, enums) that properties may
        # reference through "$ref". Validation needs them to resolve refs.
        definitions = json_schema.get("$defs", {})

        print("Structured input (press Enter to skip for optional):")
        structured_input: dict[str, Any] = {}
        for key, item in json_schema.get("properties", {}).items():
            structured_input[key] = self._read_field(
                key,
                item,
                key in required,
                definitions,
            )
        return structured_input

    def _read_field(
        self,
        key: str,
        item: dict[str, Any],
        is_required: bool,
        definitions: dict[str, Any],
    ) -> Any:
        """Prompt until the user supplies a valid value for one property.

        Pressing Enter on an optional property selects its schema default,
        or `None` when it has none. The chosen value must still validate.
        """
        # Show the property's constraints without its title. Not every
        # property has a title (e.g. bare "$ref" entries), so no pop().
        requirements = {k: v for k, v in item.items() if k != "title"}
        # Attach the shared definitions so that "$ref" entries resolve.
        field_schema = (
            {**item, "$defs": definitions} if definitions else item
        )

        while True:
            raw = input(f"\t{key} ({requirements}):")
            if raw == "":
                if is_required:
                    print(f"Key {key} is required.")
                    continue
                candidates = [item.get("default", None)]
            else:
                candidates = self._parse_candidates(raw)

            first_error = None
            for value in candidates:
                try:
                    jsonschema.validate(value, field_schema)
                except jsonschema.ValidationError as e:
                    if first_error is None:
                        first_error = e
                    continue
                return value
            logger.error(f"Validation error: {first_error}")

    @staticmethod
    def _parse_candidates(raw: str) -> list[Any]:
        """Return the values to try for typed text, in order.

        The text itself comes first, so string properties keep exactly
        what was typed. Its JSON5 interpretation (e.g. ``5``, ``true``,
        ``[1, 2]``) follows when the text parses, which serves numeric,
        boolean, array and object properties.
        """
        candidates: list[Any] = [raw]
        try:
            candidates.append(json5.loads(raw))
        except ValueError:
            # Not valid JSON5: only the raw text is a candidate.
            pass
        return candidates