# -*- coding: utf-8 -*-
""" Client of rpc agent server """

import threading
import json
import os
from typing import Optional, Sequence, Union, Generator
from loguru import logger

from ..message import Msg
from ..serialize import deserialize

try:
    import dill
    import grpc
    from grpc import RpcError
    from google.protobuf.empty_pb2 import Empty
    from agentscope.rpc.rpc_agent_pb2_grpc import RpcAgentStub
    import agentscope.rpc.rpc_agent_pb2 as agent_pb2
except ImportError as import_error:
    from agentscope.utils.common import ImportErrorReporter

    dill = ImportErrorReporter(import_error, "distribute")
    grpc = ImportErrorReporter(import_error, "distribute")
    agent_pb2 = ImportErrorReporter(import_error, "distribute")
    RpcAgentStub = ImportErrorReporter(import_error, "distribute")
    RpcError = ImportError

from ..utils.common import _generate_id_from_seed
from ..exception import AgentServerNotAliveError
from ..constants import _DEFAULT_RPC_OPTIONS
from ..exception import AgentCallError
from ..manager import FileManager

class RpcAgentClient:
    """A client of Rpc agent server.

    Every public method opens a short-lived insecure gRPC channel to
    ``host:port``, performs a single call through `RpcAgentStub`, and
    closes the channel again when the `with` block exits.
    """

    def __init__(
        self,
        host: str,
        port: int,
        agent_id: Optional[str] = None,
    ) -> None:
        """Init a rpc agent client

        Args:
            host (`str`): The hostname of the rpc agent server which
                the client is connected.
            port (`int`): The port of the rpc agent server which the
                client is connected.
            agent_id (`str`, optional): The agent id of the agent being
                called. Defaults to None.
        """
        self.host = host
        self.port = port
        self.agent_id = agent_id

    def _channel(self, options: Optional[list] = None) -> "grpc.Channel":
        """Open an insecure gRPC channel to ``self.host:self.port``.

        Args:
            options (`list`, optional): Channel options forwarded to
                `grpc.insecure_channel`. Defaults to None.

        Returns:
            `grpc.Channel`: The channel; use it as a context manager so it
            is closed after the call.
        """
        return grpc.insecure_channel(
            f"{self.host}:{self.port}",
            options=options,
        )

    def call_agent_func(
        self,
        func_name: str,
        value: Optional[str] = None,
        timeout: int = 300,
    ) -> str:
        """Call the specific function of an agent running on the server.

        Args:
            func_name (`str`): The name of the function being called.
            value (`str`, optional): The serialized function input value.
                Defaults to None.
            timeout (`int`, optional): The timeout for the RPC call in
                seconds. Defaults to 300.

        Returns:
            str: serialized return data.

        Raises:
            AgentServerNotAliveError: If the call fails and the server
                does not pass the liveness check afterwards.
            Exception: The original error if the call fails while the
                server is still alive.
        """
        try:
            with self._channel() as channel:
                stub = RpcAgentStub(channel)
                result_msg = stub.call_agent_func(
                    agent_pb2.RpcMsg(
                        value=value,
                        target_func=func_name,
                        agent_id=self.agent_id,
                    ),
                    timeout=timeout,
                )
                return result_msg.message
        except Exception as e:
            # check the server and raise a more reasonable error
            if not self.is_alive():
                raise AgentServerNotAliveError(
                    host=self.host,
                    port=self.port,
                    message=str(e),
                ) from e
            raise e

    def is_alive(self) -> bool:
        """Check if the agent server is alive.

        Any failure (connection error, 5 second timeout, or a non-ok
        status) is logged and reported as ``False``; nothing is raised.

        Returns:
            bool: Indicate whether the server is alive.
        """
        try:
            with self._channel() as channel:
                stub = RpcAgentStub(channel)
                status = stub.is_alive(Empty(), timeout=5)
                if not status.ok:
                    # funnelled into the except-branch below
                    raise AgentServerNotAliveError(
                        host=self.host,
                        port=self.port,
                    )
                return status.ok
        except Exception:
            logger.info(
                f"Agent server [{self.host}:{self.port}] not alive.",
            )
            return False

    def stop(self) -> None:
        """Stop the agent server.

        Errors are logged rather than raised.
        """
        try:
            with self._channel() as channel:
                stub = RpcAgentStub(channel)
                logger.info(
                    f"Stopping agent server at [{self.host}:{self.port}].",
                )
                resp = stub.stop(Empty(), timeout=5)
                if resp.ok:
                    logger.info(
                        f"Agent server at [{self.host}:{self.port}] stopped.",
                    )
                else:
                    logger.error(
                        f"Fail to stop the agent server: {resp.message}",
                    )
        except Exception as e:
            logger.error(
                f"Fail to stop the agent server: {e}",
            )

    def create_agent(
        self,
        agent_configs: dict,
        agent_id: Optional[str] = None,
    ) -> bool:
        """Create a new agent for this client.

        Args:
            agent_configs (`dict`): Init configs of the agent, generated by
                `_AgentMeta`.
            agent_id (`str`, optional): agent_id of the created agent.
                Defaults to None, in which case this client's `agent_id`
                is used.

        Returns:
            bool: Indicate whether the creation is successful

        Raises:
            AgentServerNotAliveError: If the call fails and the server
                does not pass the liveness check afterwards.
        """
        try:
            with self._channel() as channel:
                stub = RpcAgentStub(channel)
                status = stub.create_agent(
                    agent_pb2.CreateAgentRequest(
                        agent_id=(
                            self.agent_id if agent_id is None else agent_id
                        ),
                        # the init configs are sent serialized with dill
                        agent_init_args=dill.dumps(agent_configs),
                    ),
                )
                if not status.ok:
                    logger.error(
                        f"Error when creating agent: {status.message}",
                    )
                return status.ok
        except Exception as e:
            # check the server and raise a more reasonable error
            if not self.is_alive():
                raise AgentServerNotAliveError(
                    host=self.host,
                    port=self.port,
                    message=str(e),
                ) from e
            raise e

    def delete_agent(
        self,
        agent_id: Optional[str] = None,
    ) -> bool:
        """
        Delete agents with the specific agent_id.

        Args:
            agent_id (`str`): id of the agent to be deleted.

        Returns:
            bool: Indicate whether the deletion is successful
        """
        with self._channel() as channel:
            stub = RpcAgentStub(channel)
            status = stub.delete_agent(
                agent_pb2.StringMsg(value=agent_id),
            )
            if not status.ok:
                logger.error(f"Error when deleting agent: {status.message}")
            return status.ok

    def delete_all_agent(self) -> bool:
        """Delete all agents on the server.

        Returns:
            bool: Indicate whether the deletion is successful
        """
        with self._channel() as channel:
            stub = RpcAgentStub(channel)
            status = stub.delete_all_agents(Empty())
            if not status.ok:
                logger.error(f"Error when delete all agents: {status.message}")
            return status.ok

    def clone_agent(self, agent_id: str) -> Optional[str]:
        """Clone a new agent instance from the origin instance.

        Args:
            agent_id (`str`): The agent_id of the agent to be cloned.

        Returns:
            Optional[str]: The `agent_id` of the generated agent, or None
            if the server reports a failure.
        """
        with self._channel() as channel:
            stub = RpcAgentStub(channel)
            resp = stub.clone_agent(
                agent_pb2.StringMsg(value=agent_id),
            )
            if not resp.ok:
                logger.error(
                    f"Error when clone agent [{agent_id}]: {resp.message}",
                )
                return None
            return resp.message

    def update_placeholder(self, task_id: int) -> str:
        """Update the placeholder value.

        Args:
            task_id (`int`): `task_id` of the PlaceholderMessage.

        Returns:
            str: Serialized message value.

        Raises:
            AgentCallError: If the server reports that the update failed.
        """
        with self._channel(options=_DEFAULT_RPC_OPTIONS) as channel:
            stub = RpcAgentStub(channel)
            resp = stub.update_placeholder(
                agent_pb2.UpdatePlaceholderRequest(task_id=task_id),
            )
            if not resp.ok:
                raise AgentCallError(
                    host=self.host,
                    port=self.port,
                    message=f"Failed to update placeholder: {resp.message}",
                )
            return resp.message

    def get_agent_list(self) -> Sequence[dict]:
        """
        Get the summary of all agents on the server as a list.

        Returns:
            Sequence[dict]: list of agent summary information; empty if
            the server reports an error.
        """
        with self._channel() as channel:
            stub = RpcAgentStub(channel)
            resp = stub.get_agent_list(Empty())
            if not resp.ok:
                logger.error(f"Error when get agent list: {resp.message}")
                return []
            # the payload is a JSON list whose items are themselves JSON
            # strings, hence the two levels of decoding
            return [
                json.loads(agent_str) for agent_str in json.loads(resp.message)
            ]

    def get_server_info(self) -> dict:
        """Get the agent server resource usage information.

        Returns:
            dict: The decoded server information, or an empty dict if the
            call fails for any reason (errors are logged, not raised).
        """
        try:
            with self._channel() as channel:
                stub = RpcAgentStub(channel)
                resp = stub.get_server_info(Empty())
                if not resp.ok:
                    logger.error(f"Error in get_server_info: {resp.message}")
                    return {}
                return json.loads(resp.message)
        except Exception as e:
            logger.error(f"Error in get_server_info: {e}")
            return {}

    def set_model_configs(
        self,
        model_configs: Union[dict, list[dict]],
    ) -> bool:
        """Set the model configs of the server.

        Args:
            model_configs (`Union[dict, list[dict]]`): JSON-serializable
                model configs sent to the server.

        Returns:
            bool: Whether the server accepted the configs.
        """
        with self._channel() as channel:
            stub = RpcAgentStub(channel)
            resp = stub.set_model_configs(
                agent_pb2.StringMsg(value=json.dumps(model_configs)),
            )
            if not resp.ok:
                logger.error(f"Error in set_model_configs: {resp.message}")
                return False
            return True

    def get_agent_memory(self, agent_id: str) -> Union[list[Msg], Msg]:
        """Get the memory content of the specific agent.

        Args:
            agent_id (`str`): The id of the agent whose memory is fetched.

        Returns:
            Union[list[Msg], Msg]: The deserialized memory.
        """
        with self._channel() as channel:
            stub = RpcAgentStub(channel)
            resp = stub.get_agent_memory(
                agent_pb2.StringMsg(value=agent_id),
            )
            if not resp.ok:
                # NOTE(review): on failure the message is still passed to
                # `deserialize`; confirm that is intended for error texts.
                logger.error(f"Error in get_agent_memory: {resp.message}")
            return deserialize(resp.message)

    def download_file(self, path: str) -> str:
        """Download a file from a remote server to the local machine.

        Args:
            path (`str`): The path of the file to be downloaded. Note that
                it is the path on the remote server.

        Returns:
            `str`: The path of the downloaded file. Note that it is the path
                on the local machine.
        """
        file_manager = FileManager.get_instance()
        local_filename = (
            f"{_generate_id_from_seed(path, 5)}_{os.path.basename(path)}"
        )

        def _generator() -> Generator[bytes, None, None]:
            # Stream the file chunk by chunk; the channel stays open until
            # the consumer has exhausted the generator.
            with self._channel() as channel:
                for resp in RpcAgentStub(channel).download_file(
                    agent_pb2.StringMsg(value=path),
                ):
                    yield resp.data

        return file_manager.save_file(_generator(), local_filename)
class ResponseStub:
    """A stub used to save the response of a rpc call in a sub-thread.

    One thread stores the result with `set_response`; any number of other
    threads may block in `get_response` until the result is available.
    """

    def __init__(self) -> None:
        # `None` means "no response yet"; readers wait until it changes.
        self.response: Optional[str] = None
        self.condition = threading.Condition()

    def set_response(self, response: str) -> None:
        """Set the message and wake up every thread waiting for it.

        Args:
            response (`str`): The response of the rpc call.
        """
        with self.condition:
            self.response = response
            self.condition.notify_all()

    def get_response(self) -> str:
        """Get the message, blocking until `set_response` has been called.

        Returns:
            `str`: The stored response.
        """
        with self.condition:
            # wait_for re-checks the predicate, so spurious wake-ups and
            # responses set before this call are both handled.
            self.condition.wait_for(lambda: self.response is not None)
            return self.response
def call_in_thread(
    client: RpcAgentClient,
    value: str,
    func_name: str,
) -> ResponseStub:
    """Call rpc function in a sub-thread.

    The call runs on a newly started thread. If it raises, the error is
    logged and its string form is stored as the response, so a caller
    waiting in `ResponseStub.get_response` is never left blocked.

    Args:
        client (`RpcAgentClient`): The rpc client.
        value (`str`): The value of the request.
        func_name (`str`): The name of the function being called.

    Returns:
        `ResponseStub`: A stub to get the response.
    """
    stub = ResponseStub()

    def wrapper() -> None:
        try:
            resp = client.call_agent_func(
                func_name=func_name,
                value=value,
            )
        except Exception as e:  # pylint: disable=broad-except
            # Catch everything, not only RpcError: call_agent_func may raise
            # AgentServerNotAliveError or re-raise arbitrary errors, and an
            # exception escaping this thread would leave get_response()
            # waiting forever.
            logger.error(f"Fail to call {func_name} in thread: {e}")
            stub.set_response(str(e))
        else:
            stub.set_response(resp)  # type: ignore[arg-type]

    thread = threading.Thread(target=wrapper)
    thread.start()
    return stub