# -*- coding: utf-8 -*-
"""The Redis session class."""
import json
from typing import Any, TYPE_CHECKING
from ._session_base import SessionBase
from .._logging import logger
from ..module import StateModule
if TYPE_CHECKING:
    # Imported only for static type checkers, so that `redis` stays an
    # optional dependency at runtime (it is imported lazily in __init__).
    from redis.asyncio import ConnectionPool, Redis
else:
    # Runtime placeholders: the annotations that use these names are
    # evaluated when the class is defined, so the names must exist even
    # when `redis` is not installed.
    ConnectionPool = Any
    Redis = Any
class RedisSession(SessionBase):
    """The Redis session class.

    The state dictionaries of the given state modules are serialized into a
    single JSON string and stored under one Redis key per
    (``user_id``, ``session_id``) pair. The key is built from ``key_prefix``
    followed by :attr:`SESSION_KEY`.

    The instance can be used as an async context manager; leaving the
    ``async with`` block closes the underlying client.
    """

    SESSION_KEY = "user_id:{user_id}:session:{session_id}:state"
    """Redis key pattern (without prefix) for storing sessions state for
    a specific session.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str | None = None,
        connection_pool: ConnectionPool | None = None,
        key_ttl: int | None = None,
        key_prefix: str = "",
        **kwargs: Any,
    ) -> None:
        """Initialize the Redis session class by creating a Redis client.

        Args:
            host (`str`, defaults to `"localhost"`):
                Redis server host.
            port (`int`, defaults to `6379`):
                Redis server port.
            db (`int`, defaults to `0`):
                Redis database index.
            password (`str | None`, optional):
                Redis password if required.
            connection_pool (`ConnectionPool | None`, optional):
                Optional Redis connection pool.
            key_ttl (`int | None`, optional):
                Expire time in seconds for each session state key. If
                provided, the expiration will be refreshed on every
                save/load (sliding TTL). If `None`, the session state will
                not expire.
            key_prefix (`str`, default to `""`):
                Optional Redis key prefix prepended to every key generated
                by this storage. Useful for isolating keys across
                apps/environments (e.g. `"prod:"`, `"staging:"`,
                `"myapp:"`).
            **kwargs (`Any`):
                Additional keyword arguments passed to redis client.

        Raises:
            `ImportError`:
                If the `redis` package is not installed.
        """
        # Import lazily so that the module can be imported without redis.
        try:
            import redis.asyncio as redis
        except ImportError as e:
            raise ImportError(
                "The 'redis' package is required for RedisSession. "
                "Please install it via 'pip install redis[async]'.",
            ) from e

        self.key_ttl = key_ttl
        # Normalize a falsy prefix (e.g. `None`) to an empty string so that
        # key concatenation in `_get_session_key` always works.
        self.key_prefix = key_prefix or ""

        self._client = redis.Redis(
            host=host,
            port=port,
            db=db,
            password=password,
            connection_pool=connection_pool,
            decode_responses=True,
            **kwargs,
        )

    def get_client(self) -> Redis:
        """Get the underlying Redis client.

        Returns:
            `Redis`:
                The Redis client instance.
        """
        return self._client

    def _get_session_key(
        self,
        session_id: str,
        user_id: str,
    ) -> str:
        """Get the Redis key to store a given session state.

        Args:
            session_id (`str`):
                The session id.
            user_id (`str`):
                The user ID for the storage.

        Returns:
            `str`:
                `key_prefix` followed by `SESSION_KEY` formatted with the
                given ids.
        """
        return self.key_prefix + self.SESSION_KEY.format(
            user_id=user_id,
            session_id=session_id,
        )

    async def save_session_state(
        self,
        session_id: str,
        user_id: str = "default_user",
        **state_modules_mapping: StateModule,
    ) -> None:
        """Save the state dictionary to Redis.

        The `state_dict()` of every given module is collected into one
        dictionary keyed by module name, dumped to JSON and written with a
        single `SET`. When `key_ttl` is set, the key expiration is
        (re)applied by the same call.

        Args:
            session_id (`str`):
                The session id.
            user_id (`str`, default to `"default_user"`):
                The user ID for the storage.
            **state_modules_mapping (`dict[str, StateModule]`):
                A dictionary mapping of state module names to their
                instances.
        """
        state_dicts = {
            name: state_module.state_dict()
            for name, state_module in state_modules_mapping.items()
        }

        key = self._get_session_key(session_id, user_id=user_id)
        # Keep non-ASCII characters readable instead of \u-escaping them.
        value = json.dumps(state_dicts, ensure_ascii=False)

        # `ex=None` stores the key without expiration.
        await self._client.set(key, value, ex=self.key_ttl)

        logger.info("Save session state to redis key %s successfully.", key)

    async def load_session_state(
        self,
        session_id: str,
        user_id: str = "default_user",
        allow_not_exist: bool = True,
        **state_modules_mapping: StateModule,
    ) -> None:
        """Load the state dictionary from Redis.

        Only modules whose name appears in the stored data are updated;
        the other given modules are left untouched.

        Args:
            session_id (`str`):
                The session id.
            user_id (`str`, default to `"default_user"`):
                The user ID for the storage.
            allow_not_exist (`bool`, defaults to `True`):
                Whether to allow the session to not exist.
            **state_modules_mapping (`dict[str, StateModule]`):
                The mapping of state modules to be loaded.

        Raises:
            `ValueError`:
                If the session key does not exist and `allow_not_exist`
                is `False`.
        """
        key = self._get_session_key(session_id, user_id=user_id)

        # Use GETEX to get and refresh TTL in a single atomic operation
        if self.key_ttl is not None:
            data = await self._client.getex(key, ex=self.key_ttl)
        else:
            data = await self._client.get(key)

        if data is None:
            if allow_not_exist:
                logger.info(
                    "Session key %s does not exist in Redis. Skip loading "
                    "session state.",
                    key,
                )
                return

            raise ValueError(
                f"Failed to load session state because redis key {key} "
                "does not exist.",
            )

        # The client is created with `decode_responses=True`, but decode
        # defensively in case raw bytes are returned anyway.
        # NOTE(review): presumably this happens when a caller-supplied
        # connection pool uses its own decoding settings - confirm.
        if isinstance(data, (bytes, bytearray)):
            data = data.decode("utf-8")

        states = json.loads(data)

        for name, state_module in state_modules_mapping.items():
            if name in states:
                state_module.load_state_dict(states[name])

        logger.info("Load session state from redis key %s successfully.", key)

    async def close(self) -> None:
        """Close the Redis client connection.

        Prefers the client's `aclose()` coroutine when it is available,
        because redis-py deprecated the asyncio `close()` method in favour
        of `aclose()`; older clients without `aclose()` still go through
        `close()`.
        """
        closer = getattr(self._client, "aclose", None)
        if closer is None:
            closer = self._client.close
        await closer()

    async def __aenter__(self) -> "RedisSession":
        """Enter the async context manager.

        Returns:
            `RedisSession`:
                The current `RedisSession` instance.
        """
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any,
    ) -> None:
        """Exit the async context manager and close the connection.

        Returns `None`, so an exception raised inside the `async with`
        block is not suppressed.

        Args:
            exc_type (`type[BaseException] | None`):
                The type of the exception.
            exc_value (`BaseException | None`):
                The exception instance.
            traceback (`Any`):
                The traceback.
        """
        await self.close()