# -*- coding: utf-8 -*-
"""A proxy object which represent a object located in a rpc server."""
from __future__ import annotations
from typing import Any, Callable, Union
from abc import ABC
from inspect import getmembers, isfunction
from types import FunctionType
from concurrent.futures import ThreadPoolExecutor, Future
import threading
try:
import cloudpickle as pickle
except ImportError as e:
from agentscope.utils.common import ImportErrorReporter
pickle = ImportErrorReporter(e, "distribute")
from .rpc_client import RpcClient
from .rpc_async import AsyncResult
from .retry_strategy import RetryBase, _DEAFULT_RETRY_STRATEGY
from ..exception import AgentCreationError, AgentServerNotAliveError
def get_public_methods(cls: type) -> list[str]:
    """Return the names of the public plain functions defined on `cls`.

    Args:
        cls (`type`): The class to inspect.

    Returns:
        `list[str]`: Names of the members that are Python functions and do
        not start with an underscore, in the (name-sorted) order produced
        by `inspect.getmembers`.
    """
    public_names = []
    # `isfunction` only accepts `types.FunctionType` objects, so every member
    # yielded here is already a plain Python function.
    for member_name, _ in getmembers(cls, predicate=isfunction):
        if member_name.startswith("_"):
            continue
        public_names.append(member_name)
    return public_names
def _call_func_in_thread(func: Callable, *args: Any, **kwargs: Any) -> Future:
    """Call a function in a sub-thread.

    Args:
        func (`Callable`): The function to run.
        *args (`Any`): Positional arguments forwarded to `func`.
        **kwargs (`Any`): Keyword arguments forwarded to `func`.

    Returns:
        `Future`: A future that is resolved with the return value of `func`,
        or with the exception it raised.
    """
    future: Future = Future()

    def wrapper(*args: Any, **kwargs: Any) -> None:
        try:
            result = func(*args, **kwargs)
        except BaseException as ex:  # pylint: disable=broad-except
            # Catch BaseException (not just Exception): otherwise e.g. a
            # SystemExit raised by `func` would end the thread silently and
            # leave the future pending, blocking every `result()` caller
            # forever.
            future.set_exception(ex)
        else:
            future.set_result(result)

    thread = threading.Thread(target=wrapper, args=args, kwargs=kwargs)
    thread.start()
    return future
class _ClassInfo:
    """Record of the method names of a class, split into async and sync sets.

    A name may appear in both sets; nothing in this class removes a name
    from one set when it is added to the other.
    """

    def __init__(self) -> None:
        # TODO: we don't record attributes here, because we don't know how to
        # handle them for now.
        self.async_func = set()
        self.sync_func = set()

    def update(self, info: _ClassInfo) -> None:
        """Merge the names recorded in `info` into this instance."""
        for target, source in (
            (self.async_func, info.async_func),
            (self.sync_func, info.sync_func),
        ):
            target.update(source)

    def detect(self, attrs: dict) -> None:
        """Classify the callables found in `attrs` (name -> member).

        Callables carrying a truthy `_is_async` flag are recorded as async.
        Any other callable is recorded as sync when it carries a truthy
        `_is_sync` flag or when its name is public (no leading underscore).
        Non-callable members are ignored.
        """
        for member_name, member in attrs.items():
            if not callable(member):
                continue
            if getattr(member, "_is_async", False):
                # explicitly marked as async
                self.async_func.add(member_name)
                continue
            marked_sync = getattr(member, "_is_sync", False)
            if marked_sync or not member_name.startswith("_"):
                # explicitly marked as sync, or simply public
                self.sync_func.add(member_name)
class RpcObject(ABC):
    """A proxy object which represents an object located in an Rpc server.

    Note:
        When `to_dist` is called on an object or when using the `to_dist`
        parameter to create an object, the object is moved to an Rpc server,
        and an `RpcObject` instance is left behind in the local process.
        The `RpcObject` will automatically forward all public method
        invocations to the original object in the rpc server.
    """

    def __init__(  # pylint: disable=R0912
        self,
        cls: type,
        oid: str,
        host: str,
        port: int,
        connect_existing: bool = False,
        max_pool_size: int = 8192,
        max_expire_time: int = 7200,
        max_timeout_seconds: int = 5,
        local_mode: bool = True,
        retry_strategy: Union[RetryBase, dict] = _DEAFULT_RETRY_STRATEGY,
        configs: Union[dict, None] = None,
    ) -> None:
        """Initialize the rpc object.

        Args:
            cls (`type`): The class of the object in the rpc server.
            oid (`str`): The id of the object in the rpc server.
            host (`str`): The host of the rpc server.
            port (`int`): The port of the rpc server. When it is `None`, a
                server allocated by the studio is used if one is alive,
                otherwise a local server is launched.
            connect_existing (`bool`, defaults to `False`):
                Set to `True`, if the object is already running on the
                server.
            max_pool_size (`int`, defaults to `8192`):
                Max number of task results that the server can accommodate.
            max_expire_time (`int`, defaults to `7200`):
                Max expire time for task results.
            max_timeout_seconds (`int`, defaults to `5`):
                Max timeout seconds for the rpc call.
            local_mode (`bool`, defaults to `True`):
                Whether the started gRPC server only listens to local
                requests.
            retry_strategy (`Union[RetryBase, dict]`, defaults to
                `_DEAFULT_RETRY_STRATEGY`):
                The retry strategy for async rpc call.
            configs (`dict`, defaults to `None`):
                The configs for the agent. Generated by `RpcMeta`. Don't use
                this arg manually.

        Raises:
            AgentServerNotAliveError: If a server is launched locally but
                does not respond afterwards.
        """
        self.host = host
        self.port = port
        self._oid = oid
        self._cls = cls
        self.connect_existing = connect_existing
        # NOTE(review): this executor is not used anywhere in this class;
        # kept because external code may rely on the attribute.
        self.executor = ThreadPoolExecutor(max_workers=1)
        if isinstance(retry_strategy, RetryBase):
            self.retry_strategy = retry_strategy
        else:
            self.retry_strategy = RetryBase.load_dict(retry_strategy)

        from ..studio._client import _studio_client

        # No port given: first try to borrow a live server from the studio.
        if self.port is None and _studio_client.active:
            server = _studio_client.alloc_server()
            if "host" in server:
                if RpcClient(
                    host=server["host"],
                    port=server["port"],
                ).is_alive():
                    self.host = server["host"]
                    self.port = server["port"]

        # Still no port: launch a server from this process.
        launch_server = self.port is None
        self.server_launcher = None
        if launch_server:
            from ..server import RpcAgentServerLauncher

            # check studio first
            self.host = "localhost"
            studio_url = None
            if _studio_client.active:
                studio_url = _studio_client.studio_url
            self.server_launcher = RpcAgentServerLauncher(
                host=self.host,
                port=self.port,
                capacity=2,
                max_pool_size=max_pool_size,
                max_expire_time=max_expire_time,
                max_timeout_seconds=max_timeout_seconds,
                local_mode=local_mode,
                custom_agent_classes=[cls],
                studio_url=studio_url,  # type: ignore[arg-type]
            )
            self._launch_server()
        else:
            self.client = RpcClient(self.host, self.port)

        if not connect_existing:
            self.create(configs)
            if launch_server:
                # For a self-launched server, wait for the creation result
                # right away.
                self._check_created()
        else:
            self._creating_stub = None

    def create(self, configs: dict) -> None:
        """Create the object on the rpc server.

        The request runs in a sub-thread; its pending result is stored in
        `self._creating_stub` and consumed by `_check_created`.
        """
        self._creating_stub = _call_func_in_thread(
            self.client.create_agent,
            configs,
            self._oid,
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Forward a call on the proxy to `__call__` of the remote object."""
        self._check_created()
        if "__call__" in self._cls._info.async_func:
            return self._async_func("__call__")(*args, **kwargs)
        return self._call_func(
            "__call__",
            args={
                "args": args,
                "kwargs": kwargs,
            },
        )

    def __getitem__(self, item: str) -> Any:
        """Forward `proxy[item]` to `__getitem__` of the remote object."""
        self._check_created()
        return self._call_func("__getitem__", {"args": (item,)})

    def _launch_server(self) -> None:
        """Launch a rpc server and update the port and the client.

        Raises:
            AgentServerNotAliveError: If the launched server is not alive.
        """
        self.server_launcher.launch()
        self.port = self.server_launcher.port
        self.client = RpcClient(
            host=self.host,
            port=self.port,
        )
        if not self.client.is_alive():
            raise AgentServerNotAliveError(self.host, self.port)

    def stop(self) -> None:
        """Shut down the rpc server launched by this proxy, if any."""
        # `getattr` with a default keeps this safe on a proxy whose
        # `__init__` failed before `server_launcher` was assigned (`__del__`
        # still runs on such instances).
        launcher = getattr(self, "server_launcher", None)
        if launcher is not None:
            launcher.shutdown()

    def _check_created(self) -> None:
        """Check if the object is created on the rpc server.

        Blocks until the pending creation request (if any) has finished.

        Raises:
            AgentCreationError: If the creation request returned anything
                other than `True` or an exception instance.
            Exception: The exception returned or raised by the creation
                request.
        """
        stub = getattr(self, "_creating_stub", None)
        if stub is None:
            return
        response = stub.result()
        if response is not True:
            if isinstance(response, Exception):
                raise response
            raise AgentCreationError(self.host, self.port)
        self._creating_stub = None

    def _call_func(self, func_name: str, args: dict) -> Any:
        """Call a function in rpc server.

        Args:
            func_name (`str`): Name of the remote function or attribute.
            args (`dict`): Call arguments; pickled before being sent.

        Returns:
            `Any`: The unpickled response of the server.
        """
        # SECURITY NOTE: the response is unpickled as-is, which can execute
        # arbitrary code; only connect to rpc servers you trust.
        return pickle.loads(
            self.client.call_agent_func(
                agent_id=self._oid,
                func_name=func_name,
                value=pickle.dumps(args),
            ),
        )

    def _async_func(self, name: str) -> Callable:
        """Build a wrapper that runs remote function `name` in a sub-thread
        and returns an `AsyncResult` immediately."""

        def async_wrapper(*args, **kwargs) -> Any:  # type: ignore[no-untyped-def]
            return AsyncResult(
                host=self.host,
                port=self.port,
                stub=_call_func_in_thread(
                    self._call_func,
                    func_name=name,
                    args={"args": args, "kwargs": kwargs},
                ),
                retry=self.retry_strategy,
            )

        return async_wrapper

    def _sync_func(self, name: str) -> Callable:
        """Build a wrapper that calls remote function `name` and blocks until
        its result is available."""

        def sync_wrapper(*args, **kwargs) -> Any:  # type: ignore[no-untyped-def]
            return self._call_func(
                func_name=name,
                args={"args": args, "kwargs": kwargs},
            )

        return sync_wrapper

    def __getattr__(self, name: str) -> Any:
        """Resolve attributes that are not found locally on the proxy.

        Returns a wrapper for names registered as async/sync functions of
        the remote class; any other name is fetched from the server as an
        attribute value.

        Raises:
            AttributeError: If the proxy itself is not fully initialized.
        """
        # `__getattr__` only runs when normal lookup fails. If the proxy's
        # own state is missing (e.g. `__init__` raised early and `__del__`
        # is running), the lookups below would re-enter `__getattr__` and
        # recurse without end, so fail fast instead.
        state = self.__dict__
        if "_cls" not in state or "_creating_stub" not in state:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'",
            )
        self._check_created()
        if name in self._cls._info.async_func:
            # for async functions
            return self._async_func(name)
        if name in self._cls._info.sync_func:
            # for sync functions
            return self._sync_func(name)
        # for attributes
        return self._call_func(
            func_name=name,
            args={},
        )

    def __del__(self) -> None:
        self.stop()

    def __deepcopy__(self, memo: dict) -> Any:
        """For deepcopy.

        The copy is a new proxy connected to the same remote object
        (`connect_existing=True`); the remote object is not duplicated.
        """
        if id(self) in memo:
            return memo[id(self)]

        clone = RpcObject(
            cls=self._cls,
            oid=self._oid,
            host=self.host,
            port=self.port,
            connect_existing=True,
        )
        memo[id(self)] = clone
        return clone

    def __reduce__(self) -> tuple:
        """For pickling: rebuild as a proxy connected to the existing remote
        object, after making sure the remote object has been created."""
        self._check_created()
        return (
            RpcObject,
            (
                self._cls,
                self._oid,
                self.host,
                self.port,
                True,
            ),
        )