Source code for agentscope.rpc.rpc_object

# -*- coding: utf-8 -*-
"""A proxy object which represent a object located in a rpc server."""
from __future__ import annotations
from typing import Any, Callable, Union
from abc import ABC
from inspect import getmembers, isfunction
from types import FunctionType
from concurrent.futures import ThreadPoolExecutor, Future
import threading

try:
    import cloudpickle as pickle
except ImportError as e:
    from agentscope.utils.common import ImportErrorReporter

    pickle = ImportErrorReporter(e, "distribute")

from .rpc_client import RpcClient
from .rpc_async import AsyncResult
from .retry_strategy import RetryBase, _DEAFULT_RETRY_STRATEGY
from ..exception import AgentCreationError, AgentServerNotAliveError


def get_public_methods(cls: type) -> list[str]:
    """Collect the names of the public plain functions reachable on ``cls``.

    Members are enumerated with :func:`inspect.getmembers`, so the returned
    names come back sorted alphabetically. Any name beginning with an
    underscore is treated as non-public and left out.

    Args:
        cls (`type`): The class to inspect.

    Returns:
        `list[str]`: The public function names.
    """
    public_names = []
    for member_name, member in getmembers(cls, predicate=isfunction):
        if not isinstance(member, FunctionType):
            continue
        if member_name.startswith("_"):
            continue
        public_names.append(member_name)
    return public_names
def _call_func_in_thread(func: Callable, *args: Any, **kwargs: Any) -> Any: """Call a function in a sub-thread.""" future = Future() def wrapper(*args: Any, **kwargs: Any) -> None: try: result = func(*args, **kwargs) future.set_result(result) except Exception as ex: future.set_exception(ex) thread = threading.Thread(target=wrapper, args=args, kwargs=kwargs) thread.start() return future class _ClassInfo: def __init__(self) -> None: self.async_func = set() self.sync_func = set() # TODO: we don't record attributes here, because we don't know how to # handle them for now. def update(self, info: _ClassInfo) -> None: """Update the class info with the given info.""" self.async_func.update(info.async_func) self.sync_func.update(info.sync_func) def detect(self, attrs: dict) -> None: """Detect the public async/sync method in the given attrs.""" for key, value in attrs.items(): if callable(value): if getattr(value, "_is_async", False): # add all functions with @async_func to the async_func set self.async_func.add(key) elif getattr(value, "_is_sync", False) or not key.startswith( "_", ): # add all other public functions to the sync_func set self.sync_func.add(key)
class RpcObject(ABC):
    """A proxy object which represents an object located in an Rpc server.

    Note:
        When `to_dist` is called on an object or when using the `to_dist`
        parameter to create an object, the object is moved to an Rpc server,
        and an `RpcObject` instance is left behind in the local process.
        The `RpcObject` will automatically forward all public method
        invocations to the original object in the rpc server.
    """

    # Attributes owned by the proxy itself (all assigned in `__init__`).
    # `__getattr__` must never forward these to the remote object: they are
    # only missing when `__init__` failed part-way (e.g. when `__del__` runs
    # on a half-built proxy), and forwarding them would call
    # `_check_created`, which reads `_creating_stub`, which re-enters
    # `__getattr__` until RecursionError.
    _PROXY_OWN_ATTRS = frozenset(
        {
            "host",
            "port",
            "_oid",
            "_cls",
            "connect_existing",
            "executor",
            "retry_strategy",
            "server_launcher",
            "client",
            "_creating_stub",
        },
    )

    def __init__(  # pylint: disable=R0912
        self,
        cls: type,
        oid: str,
        host: str,
        port: int,
        connect_existing: bool = False,
        max_pool_size: int = 8192,
        max_expire_time: int = 7200,
        max_timeout_seconds: int = 5,
        local_mode: bool = True,
        retry_strategy: Union[RetryBase, dict] = _DEAFULT_RETRY_STRATEGY,
        configs: dict = None,
    ) -> None:
        """Initialize the rpc object.

        If ``port`` is `None` (and no alive server could be allocated from
        the studio), a local rpc server is launched on "localhost" and the
        object is created there.

        Args:
            cls (`type`):
                The class of the object in the rpc server.
            oid (`str`):
                The id of the object in the rpc server.
            host (`str`):
                The host of the rpc server.
            port (`int`):
                The port of the rpc server.
            connect_existing (`bool`, defaults to `False`):
                Set to `True`, if the object is already running on the
                server.
            max_pool_size (`int`, defaults to `8192`):
                Max number of task results that the server can accommodate.
            max_expire_time (`int`, defaults to `7200`):
                Max expire time for task results.
            max_timeout_seconds (`int`, defaults to `5`):
                Max timeout seconds for the rpc call.
            local_mode (`bool`, defaults to `True`):
                Whether the started gRPC server only listens to local
                requests.
            retry_strategy (`Union[RetryBase, dict]`, defaults to
            `_DEAFULT_RETRY_STRATEGY`):
                The retry strategy for async rpc call.
            configs (`dict`, defaults to `None`):
                The configs for the agent. Generated by `RpcMeta`. Don't use
                this arg manually.
        """
        self.host = host
        self.port = port
        self._oid = oid
        self._cls = cls
        self.connect_existing = connect_existing
        self.executor = ThreadPoolExecutor(max_workers=1)
        if isinstance(retry_strategy, RetryBase):
            self.retry_strategy = retry_strategy
        else:
            self.retry_strategy = RetryBase.load_dict(retry_strategy)

        from ..studio._client import _studio_client

        if self.port is None and _studio_client.active:
            server = _studio_client.alloc_server()
            if "host" in server:
                if RpcClient(
                    host=server["host"],
                    port=server["port"],
                ).is_alive():
                    self.host = server["host"]
                    self.port = server["port"]
        launch_server = self.port is None
        self.server_launcher = None
        if launch_server:
            from ..server import RpcAgentServerLauncher

            # check studio first
            self.host = "localhost"
            studio_url = None
            if _studio_client.active:
                studio_url = _studio_client.studio_url
            self.server_launcher = RpcAgentServerLauncher(
                host=self.host,
                port=self.port,
                capacity=2,
                max_pool_size=max_pool_size,
                max_expire_time=max_expire_time,
                max_timeout_seconds=max_timeout_seconds,
                local_mode=local_mode,
                custom_agent_classes=[cls],
                studio_url=studio_url,  # type: ignore[arg-type]
            )
            self._launch_server()
        else:
            self.client = RpcClient(self.host, self.port)
        if not connect_existing:
            self.create(configs)
            if launch_server:
                # A freshly launched local server: wait for creation now.
                self._check_created()
        else:
            self._creating_stub = None

    def create(self, configs: dict) -> None:
        """Create the object on the rpc server.

        The creation request runs in a sub-thread; its future is stored in
        `_creating_stub` and awaited by `_check_created`.
        """
        self._creating_stub = _call_func_in_thread(
            self.client.create_agent,
            configs,
            self._oid,
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        self._check_created()
        if "__call__" in self._cls._info.async_func:
            return self._async_func("__call__")(*args, **kwargs)
        else:
            return self._call_func(
                "__call__",
                args={
                    "args": args,
                    "kwargs": kwargs,
                },
            )

    def __getitem__(self, item: str) -> Any:
        self._check_created()
        return self._call_func("__getitem__", {"args": (item,)})

    def _launch_server(self) -> None:
        """Launch a rpc server and update the port and the client.

        Raises:
            `AgentServerNotAliveError`: If the launched server does not
                respond to `is_alive`.
        """
        self.server_launcher.launch()
        self.port = self.server_launcher.port
        self.client = RpcClient(
            host=self.host,
            port=self.port,
        )
        if not self.client.is_alive():
            raise AgentServerNotAliveError(self.host, self.port)

    def stop(self) -> None:
        """Shut down the rpc server launched by this object, if any."""
        # `getattr` with a default keeps this safe when `__init__` failed
        # before `server_launcher` was assigned (this runs from `__del__`).
        server_launcher = getattr(self, "server_launcher", None)
        if server_launcher is not None:
            server_launcher.shutdown()

    def _check_created(self) -> None:
        """Check if the object is created on the rpc server.

        Blocks on the pending creation future, if any.

        Raises:
            `Exception`: The exception returned by the creation request, if
                it returned one.
            `AgentCreationError`: If the creation request returned anything
                other than `True`.
        """
        if self._creating_stub is not None:
            response = self._creating_stub.result()
            if response is not True:
                if isinstance(response, Exception):
                    raise response
                raise AgentCreationError(self.host, self.port)
            self._creating_stub = None

    def _call_func(self, func_name: str, args: dict) -> Any:
        """Call a function in rpc server.

        ``args`` is pickled before sending and the reply is unpickled.
        """
        return pickle.loads(
            self.client.call_agent_func(
                agent_id=self._oid,
                func_name=func_name,
                value=pickle.dumps(args),
            ),
        )

    def _async_func(self, name: str) -> Callable:
        """Return a wrapper that calls ``name`` remotely, returning an
        `AsyncResult` instead of blocking."""

        def async_wrapper(*args, **kwargs) -> Any:  # type: ignore[no-untyped-def]
            return AsyncResult(
                host=self.host,
                port=self.port,
                stub=_call_func_in_thread(
                    self._call_func,
                    func_name=name,
                    args={"args": args, "kwargs": kwargs},
                ),
                retry=self.retry_strategy,
            )

        return async_wrapper

    def _sync_func(self, name: str) -> Callable:
        """Return a wrapper that calls ``name`` remotely and blocks for the
        result."""

        def sync_wrapper(*args, **kwargs) -> Any:  # type: ignore[no-untyped-def]
            return self._call_func(
                func_name=name,
                args={"args": args, "kwargs": kwargs},
            )

        return sync_wrapper

    def __getattr__(self, name: str) -> Callable:
        """Forward an unknown attribute access to the remote object.

        Names in ``self._cls._info.async_func`` yield an async wrapper, names
        in ``self._cls._info.sync_func`` yield a blocking wrapper, and any
        other name is fetched from the remote object as a value.

        Raises:
            `AttributeError`: If ``name`` is one of the proxy's own
                attributes that has not been set (half-initialized proxy).
        """
        if name in self._PROXY_OWN_ATTRS:
            # Only reachable when `__init__` did not finish; forwarding would
            # recurse forever (see `_PROXY_OWN_ATTRS`).
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}",
            )
        self._check_created()
        if name in self._cls._info.async_func:
            # for async functions
            return self._async_func(name)
        elif name in self._cls._info.sync_func:
            # for sync functions
            return self._sync_func(name)
        else:
            # for attributes
            return self._call_func(
                func_name=name,
                args={},
            )

    def __del__(self) -> None:
        self.stop()

    def __deepcopy__(self, memo: dict) -> Any:
        """For deepcopy.

        The clone connects to the same remote object
        (``connect_existing=True``) instead of creating a new one.
        """
        if id(self) in memo:
            return memo[id(self)]

        clone = RpcObject(
            cls=self._cls,
            oid=self._oid,
            host=self.host,
            port=self.port,
            connect_existing=True,
        )
        memo[id(self)] = clone
        return clone

    def __reduce__(self) -> tuple:
        # Unpickling rebuilds a proxy connected to the existing remote object.
        self._check_created()
        return (
            RpcObject,
            (
                self._cls,
                self._oid,
                self.host,
                self.port,
                True,
            ),
        )