# Source code for agentscope.rpc.rpc_object

# -*- coding: utf-8 -*-
"""A proxy object which represent a object located in a rpc server."""
from __future__ import annotations
from typing import Any, Callable, Union
from abc import ABC
from inspect import getmembers, isfunction
from types import FunctionType
from concurrent.futures import ThreadPoolExecutor, Future
import threading

    import cloudpickle as pickle
except ImportError as e:
    from agentscope.utils.common import ImportErrorReporter

    pickle = ImportErrorReporter(e, "distribute")

from .rpc_client import RpcClient
from .rpc_async import AsyncResult
from .retry_strategy import RetryBase, _DEFAULT_RETRY_STRATEGY
from ..exception import AgentCreationError, AgentServerNotAliveError

def get_public_methods(cls: type) -> list[str]:
    """Get all public methods of the given class.

    Args:
        cls (`type`): The class to inspect.

    Returns:
        `list[str]`: Names of the plain-function members of `cls` whose name
        does not start with an underscore, in the (name-sorted) order
        produced by `inspect.getmembers`.
    """
    # `predicate=isfunction` already restricts the members to plain Python
    # functions, so no further type check is needed here.
    return [
        name
        for name, _ in getmembers(cls, predicate=isfunction)
        if not name.startswith("_")
    ]
def _call_func_in_thread(func: Callable, *args: Any, **kwargs: Any) -> Any:
    """Call a function in a sub-thread.

    Args:
        func (`Callable`): The function to execute.
        *args (`Any`): Positional arguments forwarded to `func`.
        **kwargs (`Any`): Keyword arguments forwarded to `func`.

    Returns:
        `Future`: A future that receives the return value of `func`, or the
        `Exception` it raised.
    """
    future = Future()

    def wrapper(*args: Any, **kwargs: Any) -> None:
        # Route both outcomes through the future so that the caller observes
        # them from `future.result()` / `future.exception()`.
        try:
            result = func(*args, **kwargs)
            future.set_result(result)
        except Exception as ex:
            future.set_exception(ex)

    thread = threading.Thread(target=wrapper, args=args, kwargs=kwargs)
    thread.start()
    return future


class _ClassInfo:
    """Record of the async and sync function names detected for a class."""

    def __init__(self) -> None:
        self.async_func = set()
        self.sync_func = set()
        # TODO: we don't record attributes here, because we don't know how to
        # handle them for now.

    def update(self, info: "_ClassInfo") -> None:
        """Update the class info with the given info."""
        self.async_func.update(info.async_func)
        self.sync_func.update(info.sync_func)

    def detect(self, attrs: dict) -> None:
        """Detect the public async/sync method in the given attrs.

        Args:
            attrs (`dict`): Mapping from attribute name to attribute value.
                Non-callable values are ignored.
        """
        for key, value in attrs.items():
            if not callable(value):
                continue
            if getattr(value, "_is_async", False):
                # add all functions with @async_func to the async_func set
                self.async_func.add(key)
            elif getattr(value, "_is_sync", False) or not key.startswith(
                "_",
            ):
                # add all other public functions to the sync_func set
                self.sync_func.add(key)
class RpcObject(ABC):
    """A proxy object which represents an object located in an Rpc server.

    Note:
        When `to_dist` is called on an object or when using the `to_dist`
        parameter to create an object, the object is moved to an Rpc server,
        and an `RpcObject` instance is left behind in the local process.
        The `RpcObject` will automatically forward all public method
        invocations to the original object in the rpc server.
    """

    # Names of the attributes that live on the proxy itself. `__getattr__`
    # is only reached when normal lookup fails, so hitting it for one of
    # these names means the proxy is only partially initialised; forwarding
    # the lookup would recurse forever through `_check_created`.
    _LOCAL_ATTRS = frozenset(
        {
            "host",
            "port",
            "_oid",
            "_cls",
            "connect_existing",
            "executor",
            "retry_strategy",
            "server_launcher",
            "client",
            "_creating_stub",
        },
    )

    def __init__(  # pylint: disable=R0912
        self,
        cls: type,
        oid: str,
        host: str,
        port: int,
        connect_existing: bool = False,
        max_pool_size: int = 8192,
        max_expire_time: int = 7200,
        max_timeout_seconds: int = 5,
        local_mode: bool = True,
        retry_strategy: Union[RetryBase, dict] = _DEFAULT_RETRY_STRATEGY,
        configs: dict = None,
    ) -> None:
        """Initialize the rpc object.

        Args:
            cls (`type`): The class of the object in the rpc server.
            oid (`str`): The id of the object in the rpc server.
            host (`str`): The host of the rpc server.
            port (`int`): The port of the rpc server.
            connect_existing (`bool`, defaults to `False`):
                Set to `True`, if the object is already running on the
                server.
            max_pool_size (`int`, defaults to `8192`):
                Max number of task results that the server can accommodate.
            max_expire_time (`int`, defaults to `7200`):
                Max expire time for task results.
            max_timeout_seconds (`int`, defaults to `5`):
                Max timeout seconds for the rpc call.
            local_mode (`bool`, defaults to `True`):
                Whether the started gRPC server only listens to local
                requests.
            retry_strategy (`Union[RetryBase, dict]`, defaults to
                `_DEFAULT_RETRY_STRATEGY`):
                The retry strategy for async rpc call.
            configs (`dict`, defaults to `None`):
                The configs for the agent. Generated by `RpcMeta`. Don't use
                this arg manually.
        """
        self.host = host
        self.port = port
        self._oid = oid
        self._cls = cls
        self.connect_existing = connect_existing
        self.executor = ThreadPoolExecutor(max_workers=1)
        if isinstance(retry_strategy, RetryBase):
            self.retry_strategy = retry_strategy
        else:
            self.retry_strategy = RetryBase.load_dict(retry_strategy)

        # NOTE(review): the module path of this import was lost in the
        # extracted source; `..studio._client` is assumed - confirm.
        from ..studio._client import _studio_client

        # NOTE(review): the second half of this condition was lost in the
        # extracted source; `_studio_client.active` is assumed - confirm.
        if self.port is None and _studio_client.active:
            server = _studio_client.alloc_server()
            if "host" in server:
                if RpcClient(
                    host=server["host"],
                    port=server["port"],
                ).is_alive():
                    self.host = server["host"]
                    self.port = server["port"]

        # Without a port there is no server to connect to, so start one.
        launch_server = self.port is None
        self.server_launcher = None
        if launch_server:
            from ..server import RpcAgentServerLauncher

            # check studio first
            self.host = "localhost"
            studio_url = None
            # NOTE(review): this condition was lost in the extracted source;
            # `_studio_client.active` is assumed - confirm.
            if _studio_client.active:
                studio_url = _studio_client.studio_url
            self.server_launcher = RpcAgentServerLauncher(
                host=self.host,
                port=self.port,
                capacity=2,
                max_pool_size=max_pool_size,
                max_expire_time=max_expire_time,
                max_timeout_seconds=max_timeout_seconds,
                local_mode=local_mode,
                custom_agent_classes=[cls],
                studio_url=studio_url,  # type: ignore[arg-type]
            )
            self._launch_server()
        else:
            self.client = RpcClient(self.host, self.port)

        if not connect_existing:
            self.create(configs)
            if launch_server:
                self._check_created()
        else:
            self._creating_stub = None

    def create(self, configs: dict) -> None:
        """Create the object on the rpc server.

        The request is issued in a sub-thread; its future is stored in
        `self._creating_stub` and consumed by `_check_created`.
        """
        self._creating_stub = _call_func_in_thread(
            self.client.create_agent,
            configs,
            self._oid,
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Forward a call on the proxy to `__call__` of the remote object."""
        self._check_created()
        if "__call__" in self._cls._info.async_func:
            return self._async_func("__call__")(*args, **kwargs)
        return self._call_func(
            "__call__",
            args={
                "args": args,
                "kwargs": kwargs,
            },
        )

    def __getitem__(self, item: str) -> Any:
        """Forward item access to `__getitem__` of the remote object."""
        self._check_created()
        return self._call_func("__getitem__", {"args": (item,)})

    def _launch_server(self) -> None:
        """Launch a rpc server and update the port and the client.

        Raises:
            AgentServerNotAliveError: If the launched server does not report
                itself as alive.
        """
        self.server_launcher.launch()
        self.port = self.server_launcher.port
        self.client = RpcClient(
            host=self.host,
            port=self.port,
        )
        if not self.client.is_alive():
            raise AgentServerNotAliveError(self.host, self.port)

    def stop(self) -> None:
        """Stop the RpcAgent and the rpc server."""
        # `getattr` with a default keeps `stop` (and therefore `__del__`)
        # safe on an instance whose `__init__` did not get as far as
        # assigning `server_launcher`.
        launcher = getattr(self, "server_launcher", None)
        if launcher is not None:
            launcher.shutdown()

    def _check_created(self) -> None:
        """Check if the object is created on the rpc server.

        Waits for the pending creation request, if there is one.

        Raises:
            Exception: The response itself, if the creation request returned
                an exception instance.
            AgentCreationError: If the response is anything else other than
                `True`.
        """
        if self._creating_stub is not None:
            response = self._creating_stub.result()
            if response is not True:
                if isinstance(response, Exception):
                    raise response
                raise AgentCreationError(self.host, self.port)
            self._creating_stub = None

    def _call_func(self, func_name: str, args: dict) -> Any:
        """Call a function in rpc server.

        Args:
            func_name (`str`): Name of the function (or attribute) to call
                on the remote object.
            args (`dict`): The arguments, serialized with `pickle.dumps`
                before being sent.

        Returns:
            `Any`: The unpickled reply of the server.
        """
        # NOTE: the reply is unpickled as-is; only connect to rpc servers
        # that are trusted.
        return pickle.loads(
            self.client.call_agent_func(
                agent_id=self._oid,
                func_name=func_name,
                value=pickle.dumps(args),
            ),
        )

    def _async_func(self, name: str) -> Callable:
        """Build a wrapper that runs the remote function `name` in a
        sub-thread and returns an `AsyncResult` for it."""

        def async_wrapper(*args, **kwargs) -> Any:  # type: ignore[no-untyped-def]
            return AsyncResult(
                host=self.host,
                port=self.port,
                stub=_call_func_in_thread(
                    self._call_func,
                    func_name=name,
                    args={"args": args, "kwargs": kwargs},
                ),
                retry=self.retry_strategy,
            )

        return async_wrapper

    def _sync_func(self, name: str) -> Callable:
        """Build a wrapper that calls the remote function `name` and returns
        its result directly."""

        def sync_wrapper(*args, **kwargs) -> Any:  # type: ignore[no-untyped-def]
            return self._call_func(
                func_name=name,
                args={"args": args, "kwargs": kwargs},
            )

        return sync_wrapper

    def __getattr__(self, name: str) -> Callable:
        """Resolve names that are not found on the proxy itself.

        Names registered in `self._cls._info` as async or sync functions are
        returned as callables that forward to the rpc server; any other name
        is requested from the server through `_call_func` with empty
        arguments.

        Raises:
            AttributeError: If `name` is one of the proxy's own state
                attributes, which means the proxy is not fully initialised.
        """
        if name in self._LOCAL_ATTRS:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'",
            )
        self._check_created()
        if name in self._cls._info.async_func:
            # for async functions
            return self._async_func(name)
        if name in self._cls._info.sync_func:
            # for sync functions
            return self._sync_func(name)
        # for attributes
        return self._call_func(
            func_name=name,
            args={},
        )

    def __del__(self) -> None:
        self.stop()

    def __deepcopy__(self, memo: dict) -> Any:
        """For deepcopy.

        The clone is a new `RpcObject` connected to the same remote object
        (`connect_existing=True`).
        """
        if id(self) in memo:
            return memo[id(self)]
        clone = RpcObject(
            cls=self._cls,
            oid=self._oid,
            host=self.host,
            port=self.port,
            connect_existing=True,
        )
        memo[id(self)] = clone
        return clone

    def __reduce__(self) -> tuple:
        """For pickle: rebuild as an `RpcObject` that connects to the
        existing remote object."""
        self._check_created()
        return (
            RpcObject,
            (
                self._cls,
                self._oid,
                self.host,
                self.port,
                True,
            ),
        )