agentscope.session._redis_session 源代码

# -*- coding: utf-8 -*-
"""The Redis session class."""
import json
from typing import Any, TYPE_CHECKING

from ._session_base import SessionBase
from .._logging import logger
from ..module import StateModule

if TYPE_CHECKING:
    from redis.asyncio import ConnectionPool, Redis
else:
    ConnectionPool = Any
    Redis = Any


class RedisSession(SessionBase):
    """The Redis session class.

    The states of all given state modules are serialized together as one
    JSON object and stored under a single Redis string key per
    ``(user_id, session_id)`` pair.
    """

    # Redis key pattern (without prefix) for storing the state of a
    # specific session.
    SESSION_KEY = "user_id:{user_id}:session:{session_id}:state"

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str | None = None,
        connection_pool: ConnectionPool | None = None,
        key_ttl: int | None = None,
        key_prefix: str = "",
        **kwargs: Any,
    ) -> None:
        """Initialize the Redis session class by connecting to Redis.

        Args:
            host (`str`, defaults to `"localhost"`):
                Redis server host.
            port (`int`, defaults to `6379`):
                Redis server port.
            db (`int`, defaults to `0`):
                Redis database index.
            password (`str | None`, optional):
                Redis password if required.
            connection_pool (`ConnectionPool | None`, optional):
                Optional Redis connection pool. When given, redis-py takes
                the connection settings from the pool, so `host`, `port`,
                `db`, `password` and `decode_responses` are not applied.
                Replies may then be `bytes`, which `load_session_state`
                handles.
            key_ttl (`int | None`, optional):
                Expire time in seconds for each session state key. It must
                be a positive integer. If provided, the expiration is
                refreshed on every save/load (sliding TTL). If `None`, the
                session state does not expire.
            key_prefix (`str`, default to `""`):
                Optional Redis key prefix prepended to every key generated
                by this storage. Useful for isolating keys across
                apps/environments (e.g. `"prod:"`, `"staging:"`,
                `"myapp:"`).
            **kwargs (`Any`):
                Additional keyword arguments passed to the redis client.

        Raises:
            `ImportError`:
                If the optional `redis` package is not installed.
            `ValueError`:
                If `key_ttl` is not `None` and not positive.
        """
        # `redis` is an optional dependency, so import it lazily.
        try:
            import redis.asyncio as redis
        except ImportError as e:
            raise ImportError(
                "The 'redis' package is required for RedisSession. "
                "Please install it via 'pip install redis[async]'.",
            ) from e

        # Redis SET/GETEX only accept a positive expire time. Fail fast here
        # instead of erroring on every save/load later on.
        if key_ttl is not None and key_ttl <= 0:
            raise ValueError(
                "key_ttl must be a positive number of seconds or None, "
                f"got {key_ttl!r}.",
            )

        self.key_ttl = key_ttl
        self.key_prefix = key_prefix or ""
        self._client = redis.Redis(
            host=host,
            port=port,
            db=db,
            password=password,
            connection_pool=connection_pool,
            decode_responses=True,
            **kwargs,
        )

    def get_client(self) -> Redis:
        """Get the underlying Redis client.

        Returns:
            `Redis`:
                The Redis client instance.
        """
        return self._client

    def _get_session_key(
        self,
        session_id: str,
        user_id: str,
    ) -> str:
        """Get the Redis key to store a given session state.

        The key is `key_prefix` followed by `SESSION_KEY` formatted with the
        given user and session ids.
        """
        return self.key_prefix + self.SESSION_KEY.format(
            user_id=user_id,
            session_id=session_id,
        )

    async def save_session_state(
        self,
        session_id: str,
        user_id: str = "default_user",
        **state_modules_mapping: StateModule,
    ) -> None:
        """Save the state dictionary to Redis.

        The `state_dict()` of every given module is collected into one dict
        keyed by module name and written as a JSON string with `SET`. This
        overwrites any previous value for the session. The key expires after
        `key_ttl` seconds when a TTL is configured; otherwise it is written
        without an expiry.

        Args:
            session_id (`str`):
                The session id.
            user_id (`str`, default to `"default_user"`):
                The user ID for the storage.
            **state_modules_mapping (`dict[str, StateModule]`):
                A dictionary mapping of state module names to their
                instances.
        """
        state_dicts = {
            name: state_module.state_dict()
            for name, state_module in state_modules_mapping.items()
        }
        key = self._get_session_key(session_id, user_id=user_id)
        value = json.dumps(state_dicts, ensure_ascii=False)
        await self._client.set(key, value, ex=self.key_ttl)
        logger.info("Save session state to redis key %s successfully.", key)

    async def load_session_state(
        self,
        session_id: str,
        user_id: str = "default_user",
        allow_not_exist: bool = True,
        **state_modules_mapping: StateModule,
    ) -> None:
        """Load the state dictionary from Redis.

        Modules whose name is not present in the stored state are left
        untouched. When `key_ttl` is configured, the key's expiration is
        refreshed as part of the read.

        Args:
            session_id (`str`):
                The session id.
            user_id (`str`, default to `"default_user"`):
                The user ID for the storage.
            allow_not_exist (`bool`, defaults to `True`):
                Whether to allow the session to not exist.
            **state_modules_mapping (`dict[str, StateModule]`):
                The mapping of state modules to be loaded.

        Raises:
            `ValueError`:
                If the key does not exist and `allow_not_exist` is `False`,
                or if the stored value is not a JSON object.
        """
        key = self._get_session_key(session_id, user_id=user_id)

        # Use GETEX to get and refresh TTL in a single atomic operation
        if self.key_ttl is not None:
            data = await self._client.getex(key, ex=self.key_ttl)
        else:
            data = await self._client.get(key)

        if data is None:
            if allow_not_exist:
                logger.info(
                    "Session key %s does not exist in Redis. Skip loading "
                    "session state.",
                    key,
                )
                return
            raise ValueError(
                f"Failed to load session state because redis key {key} "
                "does not exist.",
            )

        # Replies are bytes when the client was built from a user-supplied
        # pool that does not decode responses.
        if isinstance(data, (bytes, bytearray)):
            data = data.decode("utf-8")

        states = json.loads(data)
        # `save_session_state` always writes a JSON object. Anything else
        # means the key holds foreign or corrupted data.
        if not isinstance(states, dict):
            raise ValueError(
                f"Failed to load session state because redis key {key} "
                "does not contain a JSON object.",
            )

        for name, state_module in state_modules_mapping.items():
            if name in states:
                state_module.load_state_dict(states[name])
        logger.info("Load session state from redis key %s successfully.", key)

    async def close(self) -> None:
        """Close the Redis client connection.

        Newer redis-py releases deprecate `Redis.close()` in favour of
        `Redis.aclose()`. Use `aclose()` when available to avoid the
        deprecation warning, and fall back to `close()` on older versions.
        """
        aclose = getattr(self._client, "aclose", None)
        if aclose is not None:
            await aclose()
        else:
            await self._client.close()

    async def __aenter__(self) -> "RedisSession":
        """Enter the async context manager.

        Returns:
            `RedisSession`:
                The current `RedisSession` instance.
        """
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any,
    ) -> None:
        """Exit the async context manager and close the connection.

        Args:
            exc_type (`type[BaseException] | None`):
                The type of the exception.
            exc_value (`BaseException | None`):
                The exception instance.
            traceback (`Any`):
                The traceback.
        """
        await self.close()