# Source code for agentscope.server.async_result_pool

# -*- coding: utf-8 -*-
"""A pool used to store the async result."""
import threading
from abc import ABC, abstractmethod

try:
    import redis
    import expiringdict
except ImportError as import_error:
    from agentscope.utils.common import ImportErrorReporter

    redis = ImportErrorReporter(import_error, "distribute")
    expiringdict = ImportErrorReporter(import_error, "distribute")


class AsyncResultPool(ABC):
    """Abstract interface of a pool that stores async results.

    Implementations hand out a key with `prepare`, accept the result for
    that key through `set`, and deliver it to consumers through `get`.
    """

    @abstractmethod
    def prepare(self) -> int:
        """Reserve a slot for an upcoming async result.

        Returns:
            `int`: The key that identifies the reserved slot.
        """

    @abstractmethod
    def set(self, key: int, value: bytes) -> None:
        """Store a value in the pool.

        Args:
            key (`int`): The key under which the value is stored.
            value (`bytes`): The value to store.
        """

    @abstractmethod
    def get(self, key: int, timeout: int = 5) -> bytes:
        """Fetch a value from the pool.

        Args:
            key (`int`): The key of the requested value.
            timeout (`int`): How many seconds to wait for the value.

        Returns:
            `bytes`: The stored value.

        Raises:
            `TimeoutError`: If the value is not available within
                `timeout` seconds.
        """
class LocalPool(AsyncResultPool):
    """Local pool for storing results.

    A slot created by `prepare` initially holds a `threading.Condition`
    placeholder. `set` replaces the placeholder with the real value and
    notifies every thread waiting on it in `get`.
    """

    def __init__(self, max_len: int, max_expire: int) -> None:
        """Create the backing expiring dict and the id counter.

        Args:
            max_len (`int`): Forwarded to `ExpiringDict` as `max_len`.
            max_expire (`int`): Forwarded to `ExpiringDict` as
                `max_age_seconds`.
        """
        self.pool = expiringdict.ExpiringDict(
            max_len=max_len,
            max_age_seconds=max_expire,
        )
        self.object_id_cnt = 0
        self.object_id_lock = threading.Lock()

    def _get_object_id(self) -> int:
        """Return the next id; the counter is incremented under a lock."""
        with self.object_id_lock:
            self.object_id_cnt += 1
            return self.object_id_cnt

    def prepare(self) -> int:
        """Reserve a slot and return its key.

        Returns:
            `int`: The key of the new slot, which holds a
            `threading.Condition` until `set` is called for it.
        """
        oid = self._get_object_id()
        self.pool[oid] = threading.Condition()
        return oid

    def set(self, key: int, value: bytes) -> None:
        """Store `value` under `key` and wake up waiting readers.

        Args:
            key (`int`): The key of the slot.
            value (`bytes`): The value to store.
        """
        placeholder = self.pool.get(key)
        if not isinstance(placeholder, threading.Condition):
            # The slot is missing (never prepared or already evicted) or
            # was filled before, so nobody can be waiting on a condition.
            # Just store the value instead of failing.
            self.pool[key] = value
            return
        with placeholder:
            # Publish while holding the condition's lock so that a reader
            # checking its predicate never interleaves with the store.
            self.pool[key] = value
            placeholder.notify_all()

    def get(self, key: int, timeout: int = 5) -> bytes:
        """Get the value with timeout.

        Args:
            key (`int`): The key of the value.
            timeout (`int`): The timeout seconds to wait for the value.

        Returns:
            `bytes`: The stored value. For a key that has no slot, the
            result of `self.pool.get(key)` is returned as is (`None` for
            a dict-like pool).

        Raises:
            `TimeoutError`: When the slot is still unfilled after
                `timeout` seconds.
        """
        slot = self.pool.get(key)
        if not isinstance(slot, threading.Condition):
            return slot

        def _filled() -> bool:
            return not isinstance(self.pool.get(key), threading.Condition)

        with slot:
            # `wait_for` evaluates the predicate before blocking, so a
            # `set` that ran between the lookup above and this point is
            # noticed immediately rather than after the full timeout.
            slot.wait_for(_filled, timeout=timeout)

        value = self.pool.get(key)
        if isinstance(value, threading.Condition):
            raise TimeoutError(
                f"Waiting timeout for async result of task[{key}]",
            )
        return value
class RedisPool(AsyncResultPool):
    """Redis pool for storing results.

    Every result is stored under its numeric key. In addition, a list
    named `TASK_QUEUE_PREFIX + key` receives a token when the result is
    set, so readers can block on that list until the result arrives.
    """

    INCR_KEY = "as_obj_id"
    TASK_QUEUE_PREFIX = "as_task_"

    def __init__(
        self,
        url: str,
        max_expire: int,
    ) -> None:
        """
        Init redis pool.

        Args:
            url (`str`): The url of the redis server.
            max_expire (`int`): The expire time in seconds applied to each
                stored result and to its notification list.

        Raises:
            `ConnectionError`: When the client cannot be created or the
                server does not answer the ping.
        """
        try:
            self.pool = redis.from_url(url)
            self.pool.ping()
        except Exception as e:
            raise ConnectionError(
                f"Redis server at [{url}] is not available.",
            ) from e
        self.max_expire = max_expire

    def _get_object_id(self) -> int:
        """Return the next id by incrementing the `INCR_KEY` counter."""
        return self.pool.incr(RedisPool.INCR_KEY)

    def prepare(self) -> int:
        """Reserve a key for an upcoming result.

        Returns:
            `int`: The new key.
        """
        return self._get_object_id()

    def set(self, key: int, value: bytes) -> None:
        """Store `value` under `key` and publish a wake-up token.

        Args:
            key (`int`): The key of the value.
            value (`bytes`): The value to be set.
        """
        qkey = RedisPool.TASK_QUEUE_PREFIX + str(key)
        self.pool.set(key, value, ex=self.max_expire)
        self.pool.rpush(qkey, key)
        self.pool.expire(qkey, self.max_expire)

    def get(self, key: int, timeout: int = 5) -> bytes:
        """Get a value, blocking until it is available or timing out.

        Args:
            key (`int`): The key of the value.
            timeout (`int`): Forwarded to `blpop` as the blocking timeout
                in seconds. NOTE(review): Redis documents a timeout of 0
                as "block indefinitely" - confirm callers never pass 0.

        Returns:
            `bytes`: The stored value.

        Raises:
            `TimeoutError`: When no token arrives within `timeout`, or
                when the token arrives but the value cannot be read.
        """
        result = self.pool.get(key)
        # Compare against None so that an empty payload (b"") is served
        # directly instead of being mistaken for a missing result.
        if result is not None:
            return result

        qkey = RedisPool.TASK_QUEUE_PREFIX + str(key)
        popped = self.pool.blpop(keys=qkey, timeout=timeout)
        if popped is None:
            raise TimeoutError(
                f"Waiting timeout for async result of task[{key}]",
            )

        # Hand the token back so other readers blocked on the same key
        # are released as well. Popping the last element makes Redis
        # delete the list together with its TTL, so the expire has to be
        # applied again or the re-created list would never be removed.
        self.pool.rpush(qkey, key)
        self.pool.expire(qkey, self.max_expire)

        if int(popped[1]) != key:
            raise TimeoutError(f"Async Result of task[{key}] not found.")
        res = self.pool.get(key)
        if res is None:
            raise TimeoutError(
                f"Async Result of task[{key}] not found.",
            )
        return res
def get_pool(
    pool_type: str = "local",
    max_expire: int = 7200,
    max_len: int = 8192,
    redis_url: str = "redis://localhost:6379",
) -> AsyncResultPool:
    """Build the async result pool selected by `pool_type`.

    Args:
        pool_type (`str`): `redis` selects a `RedisPool`; any other value
            (including the default `local`) selects a `LocalPool`.
        max_expire (`int`): The max expire time handed to the pool.
        max_len (`int`): The max length of the pool; only passed to
            `LocalPool`.
        redis_url (`str`): The address of the redis server; only passed
            to `RedisPool`.

    Returns:
        `AsyncResultPool`: The newly created pool.
    """
    wants_redis = pool_type == "redis"
    if not wants_redis:
        return LocalPool(max_len=max_len, max_expire=max_expire)
    return RedisPool(url=redis_url, max_expire=max_expire)